• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
18 
19 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
20 #include "tensorflow/compiler/xla/client/local_client.h"
21 #include "tensorflow/core/common_runtime/device.h"
22 #include "tensorflow/core/common_runtime/device_mgr.h"
23 #include "tensorflow/core/common_runtime/function.h"
24 #include "tensorflow/core/framework/function.h"
25 #include "tensorflow/core/platform/env.h"
26 #include "tensorflow/core/platform/mutex.h"
27 #include "tensorflow/core/platform/notification.h"
28 #include "tensorflow/core/platform/thread_annotations.h"
29 #include "tensorflow/core/public/version.h"
30 
31 namespace tensorflow {
32 
33 class XlaContext;
34 
35 // The XlaCompiler class is responsible for compilation of a self-contained
36 // subgraph of a TensorFlow computation using the XLA linear algebra runtime.
37 // It does a symbolic execution of the graph starting from specific input
38 // shapes, using a JIT device to convert operators into XLA computations.
39 //
40 // XlaCompiler is typically invoked from an `_XlaLaunch` operator once the
41 // shapes of all input parameters to the computation are known. This is
42 // because the symbolic execution requires known shapes for all operations.
43 //
44 // XlaCompiler compiles Tensorflow graphs that receive inputs via _Arg nodes
45 // and return outputs via _Retval nodes.
46 //
47 // The XlaCompiler requires one Argument struct for each _Arg index, that
48 // describes each argument. Arguments can be compile-time constants
49 // (kind kConstant), run-time parameters (kind kParameter), or resources
50 // (kind kResource).
51 //
52 // Only kParameter and initialized kResource arguments become runtime parameters
53 // to the generated XLA computation. The XLA computation will have run-time
54 // parameters in the following order:
55 //   +---------------------+-----------------------------------------+
56 //   |  kParameter values  |  Initial values of kResource arguments  |
57 //   +---------------------+-----------------------------------------+
58 // Within each block, the arguments are arranged by the _Arg index from which
59 // they were derived.
60 //
61 // The run-time outputs of the XLA computation are arranged in the following
62 // order:
63 //   +------------------+-----------------------------------------+
64 //   |  _Retval values  |  Updated values of kResource arguments  |
65 //   +------------------+-----------------------------------------+
66 // _Retval values are ordered by _Retval index, whereas kResource values are
67 // ordered by the original _Arg position of the variable.
68 //
69 // In both inputs and outputs, kResource values are placed at the end. When
70 // emitting While loop bodies, we must ensure that the loop body has
71 // identical input and output signatures. By moving variable values
72 // to the end of the argument list and using the
73 // `return_updated_values_for_all_variables` option, we can ensure that the
74 // input and output values of resources appear at the same positions.
75 //
76 // Resources are passed as parameters or returned as resource updates in
77 // "packed" form.
78 // kStack resources are packed as (array, size of stack) XLA tuples.
79 // kTensorArray resources without gradients are packed as the array that
80 // backs the TensorArray. If gradients are present (`tensor_array_gradients`),
81 // the packed representation is a (array, gradient0, gradient1, ...) tuple,
82 // where gradient_k is the value of the k-th gradient in the
83 // `tensor_array_gradients` ordered set.
class XlaCompiler {
 public:
  // Describes how to derive the value of each _Arg node in the graph/function
  // being compiled. There must be one Argument for each _Arg index.
  struct Argument {
    enum Kind {
      // Default value; not a valid kind.
      kInvalid,

      // Argument is a compile-time constant. No associated runtime parameter.
      kConstant,

      // Argument is a Variable, TensorArray, or Stack resource. Has an
      // associated runtime parameter iff `initialized` is true.
      kResource,

      // Argument is a run-time parameter.
      kParameter,
    };

    Kind kind = kInvalid;

    // The type of the argument. If the argument is a resource, this
    // is the type of the variable's value, not DT_RESOURCE.
    DataType type;

    // The shape of the argument. For:
    // * a parameter: the shape of the parameter.
    // * a constant: ignored; the shape given by constant_value is used
    //     instead.
    // * an uninitialized resource: ignored. We don't yet know the shape of an
    //     uninitialized resource (otherwise we would have initialized it!)
    // * an initialized variable: the shape of the variable's value.
    // * an initialized TensorArray or Stack resource: the shape of an entry in
    //   the TensorArray/Stack. Note this is the size of a single entry, not the
    //   XLA data structure that represents the complete stack/array.
    TensorShape shape;

    // The value of the argument, if it is a compile-time constant. Must be a
    // host-memory tensor.
    Tensor constant_value;

    // The name of this argument, used for debugging.
    string name;

    // For a kResource, what kind of resource is it?
    XlaResource::Kind resource_kind = XlaResource::kInvalid;

    // For a kResource, has this resource been initialized?
    bool initialized = false;

    // For a TensorArray or Stack resource, what is the array's declared size?
    // (Used for lazy initialization.)
    int64 tensor_array_size = -1;

    // TensorArray resource parameters are passed as (array, gradient array 0,
    // ..., gradient array k), where the gradient arrays are in the same order
    // as `tensor_array_gradients`.
    std::set<string> tensor_array_gradients;

    bool operator==(const Argument& other) const;
  };

  // Options pertaining to an individual call to CompileGraph() or
  // CompileFunction().
  struct CompileOptions {
    // If `use_tuple_arg` is true, a single tuple parameter will be used for all
    // arguments; if false, each argument gets its own parameter.
    bool use_tuple_arg = false;

    // If 'return_updated_values_for_all_resources' is true, then updated
    // values of all resource arguments will be included in the
    // 'resource_updates' of the computation, even if the resource was not
    // modified by the computation. Used when compiling loop bodies to ensure
    // the input and output signatures match.
    bool return_updated_values_for_all_resources = false;

    // If 'resolve_compile_time_constants' is true, then outputs of a
    // computation that are known to be compile-time constants will be returned
    // as Tensors at compile-time, rather than as run-time outputs of the
    // computation.
    bool resolve_compile_time_constants = true;

    // True when compiling the entry computation, false for subcomputations
    // (while, call, etc.)
    bool is_entry_computation = true;
  };

  // Describes a single output of the compiled computation, indexed by _Retval
  // number (see `CompilationResult::outputs`).
  struct OutputDescription {
    // Type and shape of the output.
    DataType type;
    TensorShape shape;

    // Constant output value, if known to be constant at JIT compilation time.
    // 'Tensor' is in host memory.
    bool is_constant = false;
    Tensor constant_value;
  };

  // Describes a variable write side effect of the computation.
  struct ResourceUpdate {
    // Index of the input that contains the variable resource to write to.
    int input_index;

    // Type and shape of the tensor to be written back.
    // The `shape` field has the same meaning as the Argument::shape field.
    DataType type;
    TensorShape shape;

    // Was the value of the variable modified by the computation?
    // (Always true, unless `return_updated_values_for_all_resources` is true.)
    bool modified;

    // If the resource is a TensorArray, the set of gradients read or written.
    std::set<string> tensor_array_gradients_accessed;
  };

  // The outcome of a single graph or function compilation.
  struct CompilationResult {
    // Vector that maps from the parameters of the XLA computation to their
    // original argument positions. To handle compile-time constant inputs and
    // resources, the parameters to the XLA computation may be a subset of the
    // original arguments, and are not necessarily in the same order.
    std::vector<int> input_mapping;

    // Input shapes of the computation.
    std::vector<xla::Shape> xla_input_shapes;

    // Output shape in XLA format. The output shape is always a tuple.
    xla::Shape xla_output_shape;

    // TensorFlow shapes of outputs, together with the values of any
    // constant arguments. Vector indexed by Tensorflow _Retval number,
    // containing both constant and non-constant results.
    std::vector<OutputDescription> outputs;

    // Resources whose values were updated by the computation, ordered
    // by return value position. Resource updates follow the non-constant
    // results in the outputs of XLA computation.
    std::vector<ResourceUpdate> resource_updates;

    // The XLA computation built from the tensorflow subgraph.
    std::shared_ptr<xla::Computation> computation;
  };

  // Per-compiler configuration, fixed for the lifetime of an XlaCompiler
  // instance (contrast with the per-call CompileOptions above).
  struct Options {
    // Name of the compilation device to use. Needs to be live only during
    // XlaCompiler's constructor.
    const DeviceType* device_type = nullptr;

    xla::Client* client = nullptr;

    // Function library in which to find function definitions. Must be non-null.
    const FunctionLibraryDefinition* flib_def = nullptr;

    // The graph def version to be compiled.
    int graph_def_version = TF_GRAPH_DEF_VERSION;

    // If 'allow_cpu_custom_calls' is true, kernels may make use of CustomCall()
    // for CPU.
    bool allow_cpu_custom_calls = false;

    // If set, variables are represented to XLA using the shape given by this
    // shape function. Variables are reshaped to this shape on write, and
    // reshaped to their original shape on read.
    std::function<TensorShape(const TensorShape&, DataType)>
        variable_representation_shape_fn;

    // If not nullptr, populate_resource_manager is called with the
    // compilation device's resource manager when the compilation
    // device is created, and can be used to create metadata objects
    // that can be accessed by XLA op kernels.
    std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;

    // If not nullptr, this memory allocator can be used by the compiler for
    // temporary allocations it might want to make during compilation.
    //
    // For example, the compiler may want to try out different algorithms and
    // choose the fastest one, and it might run those algorithms over buffers
    // created using this allocator.
    //
    // The compiler can function correctly without an explicit allocator given
    // here, but on some devices (notably, GPUs), TensorFlow tends to eagerly
    // allocate most or all available memory on the device, leaving none for the
    // compiler to access, unless it can use TensorFlow's allocator.
    xla::DeviceMemoryAllocator* device_allocator = nullptr;
  };

  explicit XlaCompiler(Options options);

  ~XlaCompiler();

  // Compiles the function named by `fn_name_attrs` into an XLA computation,
  // writing the outcome to `result`. `args` describes the function's _Arg
  // nodes, in _Arg index order (see the Argument struct above).
  Status CompileFunction(const CompileOptions& options,
                         const NameAttrList& fn_name_attrs,
                         std::vector<Argument> args, CompilationResult* result);

  // Compiles a tensorflow::Graph into an xla::Computation.
  // Similar to CompileFunction, but takes a Graph as input rather than a
  // function.
  Status CompileGraph(const CompileOptions& options, string const& name,
                      std::unique_ptr<Graph> graph,
                      const std::vector<Argument>& args,
                      CompilationResult* result);

  // Returns the shape of the XLA parameter for an argument 'arg'.
  // See the class comment for more details about the argument passing
  // convention.
  Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape);

  // Retrieves the channel handle associated with `key`. Allocates
  // a new channel handle if none exists.
  // Channel handles can be used to communicate between different
  // computations. Computations that communicate should be compiled with the
  // same XlaCompiler.
  Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);

  const Options& options() const { return options_; }
  xla::Client* client() const { return options_.client; }
  FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }

 private:
  // Sets the function body `fbody` to the one registered as `function`.
  Status FindFunctionBody(const NameAttrList& function,
                          const FunctionBody** fbody);

  // Returns the optimized graph object in this function body.
  std::unique_ptr<Graph> GetGraph(const FunctionBody* fbody);

  // Builds XLA computations for each of the arguments to the computation.
  // `args` are the arguments to the computation.
  Status BuildArguments(const Graph& graph,
                        const std::vector<XlaCompiler::Argument>& args,
                        bool use_tuple_arg, xla::ComputationBuilder* builder,
                        XlaContext* context, std::vector<int>* arg_cores,
                        std::vector<XlaExpression>* arg_expressions,
                        std::vector<int>* input_mapping,
                        std::vector<xla::Shape>* input_shapes,
                        bool is_entry_computation);

  // Graph compiler needs to know how to get an optimized graph from a function
  // body.
  friend class GraphCompiler;
  friend class XlaCompilerTest;

  Options options_;

  // Status set to non-OK in the constructor if initialization fails.
  Status initialization_status_;

  // Returns the next step sequence number.
  int64 NextStepId();

  // Internal sequence number for steps executed on the compilation device.
  int64 next_step_id_;

  XlaCompilationDevice* device_;  // Owned by device_mgr_
  DeviceMgr device_mgr_;

  // To avoid copying the client's function library, use a local function
  // library and runtime for functions created as part of the functionalize
  // control flow transformation.
  std::unique_ptr<FunctionLibraryDefinition> local_flib_def_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> local_pflr_;

  FunctionLibraryRuntime* local_flib_runtime_;  // owned by local_pflr_.
  FunctionLibraryRuntime* flib_runtime_;        // owned by pflr_.

  // Hash functor for the compilation cache key: a (function name,
  // argument list) signature.
  struct SignatureHash {
    uint64 operator()(
        const std::pair<string, std::vector<Argument>>& signature) const;
  };

  // Cache of compilation results, keyed by (function name, arguments)
  // signature.
  std::unordered_map<std::pair<string, std::vector<Argument>>,
                     CompilationResult, SignatureHash>
      cache_;

  // Channel handles previously allocated by GetChannelHandle, keyed by `key`.
  std::unordered_map<string, xla::ChannelHandle> channels_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
};
364 
365 }  // namespace tensorflow
366 
367 #endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
368