1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ 17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ 18 19 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" 20 #include "tensorflow/compiler/xla/client/local_client.h" 21 #include "tensorflow/core/common_runtime/device.h" 22 #include "tensorflow/core/common_runtime/device_mgr.h" 23 #include "tensorflow/core/common_runtime/function.h" 24 #include "tensorflow/core/framework/function.h" 25 #include "tensorflow/core/platform/env.h" 26 #include "tensorflow/core/platform/mutex.h" 27 #include "tensorflow/core/platform/notification.h" 28 #include "tensorflow/core/platform/thread_annotations.h" 29 #include "tensorflow/core/public/version.h" 30 31 namespace tensorflow { 32 33 class XlaContext; 34 35 // The XlaCompiler class is responsible for compilation of a self-contained 36 // subgraph of a TensorFlow computation using the XLA linear algebra runtime. 37 // It does a symbolic execution of the graph starting from specific input 38 // shapes, using a JIT device to convert operators into XLA computations. 39 // 40 // XlaCompiler is typically invoked from an `_XlaLaunch` operator once the 41 // shapes of all input parameters to the computation are known. 
This is 42 // because the symbolic execution requires known shapes for all operations. 43 // 44 // XlaCompiler compiles Tensorflow graphs that received inputs via _Arg nodes, 45 // and return outputs via _Retval nodes. 46 // 47 // The XlaCompiler requires one Argument struct for each _Arg index, that 48 // describes each argument. Arguments can be compile-time constants 49 // (kind kConstant), run-time parameters (kind kParameter), or resources 50 // (kind kResource). 51 // 52 // Only kParameter and initialized kResource arguments become runtime parameters 53 // to the generated XLA computation. The XLA computation will have run-time 54 // parameters in the following order: 55 // +---------------------+-----------------------------------------+ 56 // | kParameter values | Initial values of kResource arguments | 57 // +---------------------+-----------------------------------------+ 58 // Within each block, the arguments are arranged by the _Arg index from which 59 // they were derived. 60 // 61 // The run-time outputs of the XLA computation are arranged in the following 62 // order: 63 // +------------------+-----------------------------------------+ 64 // | _Retval values | Updated values of kResource arguments | 65 // +------------------+-----------------------------------------+ 66 // _Retval values are ordered by _Retval index, whereas kResource values are 67 // ordered by the original _Arg position of the variable. 68 // 69 // In both inputs and outputs, kResource values are placed the end. When 70 // emitting While loop bodies, we must ensure that the loop body has 71 // identical input and output signatures. By moving variable values 72 // to the end of the argument list and using the 73 // `return_updated_values_for_all_variables` option, we can ensure that the 74 // input and output values of resources appear at the same positions. 75 // 76 // Resources are passed as parameters or returned as resource updates in 77 // "packed" form. 
// kStack resources are packed as (array, size of stack) XLA tuples.
// kTensorArray resources without gradients are packed as the array that
// backs the TensorArray. If gradients are present (`tensor_array_gradients`),
// the packed representation is a (array, gradient0, gradient1, ...) tuple,
// where gradient_k is the value of the k-th gradient in the
// `tensor_array_gradients` ordered set.
class XlaCompiler {
 public:
  // Describes how to derive the value of each _Arg node in the graph/function
  // being compiled. There must be one Argument for each _Arg index.
  struct Argument {
    enum Kind {
      // Default value; not a valid kind.
      kInvalid,

      // Argument is a compile-time constant. No associated runtime parameter.
      kConstant,

      // Argument is a Variable, TensorArray, or Stack resource. Has an
      // associated runtime parameter iff `initialized` is true.
      kResource,

      // Argument is a run-time parameter.
      kParameter,
    };

    Kind kind = kInvalid;

    // The type of the argument. If the argument is a resource, this
    // is the type of the variable's value, not DT_RESOURCE.
    DataType type;

    // The shape of the argument. For:
    // * a parameter: the shape of the parameter.
    // * a constant: ignored; the shape given by constant_value is used
    //   instead.
    // * an uninitialized resource: ignored. We don't yet know the shape of an
    //   uninitialized resource (otherwise we would have initialized it!)
    // * an initialized variable: the shape of the variable's value.
    // * an initialized TensorArray or Stack resource: the shape of an entry in
    //   the TensorArray/Stack. Note this is the size of a single entry, not the
    //   XLA data structure that represents the complete stack/array.
    TensorShape shape;

    // The value of the argument, if it is a compile-time constant. Must be a
    // host-memory tensor.
    Tensor constant_value;

    // The name of this argument, used for debugging.
    string name;

    // For a kResource, what kind of resource is it?
    XlaResource::Kind resource_kind = XlaResource::kInvalid;

    // For a kResource, has this resource been initialized?
    bool initialized = false;

    // For a TensorArray or Stack resource, what is the array's declared size?
    // (Used for lazy initialization.)
    int64 tensor_array_size = -1;

    // TensorArray resource parameters are passed as (array, gradient array 0,
    // ..., gradient array k), where the gradient arrays are in the same order
    // as `tensor_array_gradients`.
    std::set<string> tensor_array_gradients;

    // Structural equality over all fields; used as part of the compilation
    // cache key (see SignatureHash / cache_ below).
    bool operator==(const Argument& other) const;
  };

  // Options pertaining to an individual call to CompileGraph() or
  // CompileFunction().
  struct CompileOptions {
    // If `use_tuple_arg` is true, a single tuple parameter will be used for all
    // arguments; if false, each argument gets its own parameter.
    bool use_tuple_arg = false;

    // If 'return_updated_values_for_all_resources' is true, then updated
    // values of all resource arguments will be included in the
    // 'resource_updates' of the computation, even if the resource was not
    // modified by the computation. Used when compiling loop bodies to ensure
    // the input and output signatures match.
    bool return_updated_values_for_all_resources = false;

    // If 'resolve_compile_time_constants' is true, then outputs of a
    // computation that are known to be compile-time constants will be returned
    // as Tensors at compile-time, rather than as run-time outputs of the
    // computation.
    bool resolve_compile_time_constants = true;

    // True when compiling the entry computation, false for subcomputations
    // (while, call, etc.)
    bool is_entry_computation = true;
  };

  // Describes one output (_Retval) of the compiled computation.
  struct OutputDescription {
    // Type and shape of the output.
    DataType type;
    TensorShape shape;

    // Constant output value, if known to be constant at JIT compilation time.
    // 'Tensor' is in host memory.
    bool is_constant = false;
    Tensor constant_value;
  };

  // Describes a variable write side effect of the computation.
  struct ResourceUpdate {
    // Index of the input that contains the variable resource to write to.
    int input_index;

    // Type and shape of the tensor to be written back.
    // The `shape` field has the same meaning as the Argument::shape field.
    DataType type;
    TensorShape shape;

    // Was the value of the variable modified by the computation?
    // (Always true, unless `return_updated_values_for_all_resources` is true.)
    bool modified;

    // If the resource is a TensorArray, the set of gradients read or written.
    std::set<string> tensor_array_gradients_accessed;
  };

  // The outcome of a successful compilation: the built XLA computation
  // together with the metadata needed to feed its inputs and interpret its
  // outputs. See the class comment for the parameter/output ordering
  // convention.
  struct CompilationResult {
    // Vector that maps from the parameters of the XLA computation to their
    // original argument positions. To handle compile-time constant inputs and
    // resources, the parameters to the XLA computation may be a subset of the
    // original arguments, and are not necessarily in the same order.
    std::vector<int> input_mapping;

    // Input shapes of the computation.
    std::vector<xla::Shape> xla_input_shapes;

    // Output shape in XLA format. The output shape is always a tuple.
    xla::Shape xla_output_shape;

    // TensorFlow shapes of outputs, together with the values of any
    // constant arguments. Vector indexed by Tensorflow _Retval number,
    // containing both constant and non-constant results.
    std::vector<OutputDescription> outputs;

    // Resources whose values were updated by the computation, ordered
    // by return value position. Resource updates follow the non-constant
    // results in the outputs of XLA computation.
    std::vector<ResourceUpdate> resource_updates;

    // The XLA computation built from the tensorflow subgraph.
    std::shared_ptr<xla::Computation> computation;
  };

  // Per-compiler options, supplied once to the XlaCompiler constructor.
  struct Options {
    // Name of the compilation device to use. Needs to be live only during
    // XlaCompiler's constructor.
    const DeviceType* device_type = nullptr;

    // The XLA client with which computations are built.
    xla::Client* client = nullptr;

    // Function library in which to find function definitions. Must be non-null.
    const FunctionLibraryDefinition* flib_def = nullptr;

    // The graph def version to be compiled.
    int graph_def_version = TF_GRAPH_DEF_VERSION;

    // If 'allow_cpu_custom_calls' is true, kernels may make use of CustomCall()
    // for CPU.
    bool allow_cpu_custom_calls = false;

    // If set, variables are represented to XLA as the shape given by this
    // shape function. Variables are reshaped to this shape on write, and
    // reshaped to their original shape on read.
    std::function<TensorShape(const TensorShape&, DataType)>
        variable_representation_shape_fn;

    // If not nullptr, populate_resource_manager is called with the
    // compilation device's resource manager when the compilation
    // device is created, and can be used to create metadata objects
    // that can be accessed by XLA op kernels.
    std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;

    // If not nullptr, this memory allocator can be used by the compiler for
    // temporary allocations it might want to make during compilation.
    //
    // For example, the compiler may want to try out different algorithms and
    // choose the fastest one, and it might run those algorithms over buffers
    // created using this allocator.
    //
    // The compiler can function correctly without an explicit allocator given
    // here, but on some devices (notably, GPUs), TensorFlow tends to eagerly
    // allocate most or all available memory on the device, leaving none for the
    // compiler to access, unless it can use TensorFlow's allocator.
    xla::DeviceMemoryAllocator* device_allocator = nullptr;
  };

  explicit XlaCompiler(Options options);

  ~XlaCompiler();

  // Compiles a Tensorflow function identified by `fn_name_attrs` into an XLA
  // computation. `args` describes the arguments to the function, one per _Arg
  // index. On success, populates `result` with the compiled computation and
  // its input/output metadata.
  Status CompileFunction(const CompileOptions& options,
                         const NameAttrList& fn_name_attrs,
                         std::vector<Argument> args, CompilationResult* result);

  // Compiles a tensorflow::Graph into an xla::Computation.
  // Similar to CompileFunction, but takes a Graph as input rather than a
  // function.
  Status CompileGraph(const CompileOptions& options, string const& name,
                      std::unique_ptr<Graph> graph,
                      const std::vector<Argument>& args,
                      CompilationResult* result);

  // Returns the shape of the XLA parameter for an argument 'arg'.
  // See the class comment for more details about the argument passing
  // convention.
  Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape);

  // Retrieves the channel handle associated with `key`. Allocates
  // a new channel handle if none exists.
  // Channel handles can be used to communicate between different
  // computations. Computations that communicate should be compiled with the
  // same XlaCompiler.
  Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);

  // The options this compiler was constructed with.
  const Options& options() const { return options_; }
  // Convenience accessor for options().client.
  xla::Client* client() const { return options_.client; }
  // Function library runtime used during compilation (owned by pflr_).
  FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }

 private:
  // Sets the function body `fbody` to the one registered as `function`.
  Status FindFunctionBody(const NameAttrList& function,
                          const FunctionBody** fbody);

  // Returns the optimized graph object in this function body.
  std::unique_ptr<Graph> GetGraph(const FunctionBody* fbody);

  // Builds XLA computations for each of the arguments to the computation.
  // `args` are the arguments to the computation.
  Status BuildArguments(const Graph& graph,
                        const std::vector<XlaCompiler::Argument>& args,
                        bool use_tuple_arg, xla::ComputationBuilder* builder,
                        XlaContext* context, std::vector<int>* arg_cores,
                        std::vector<XlaExpression>* arg_expressions,
                        std::vector<int>* input_mapping,
                        std::vector<xla::Shape>* input_shapes,
                        bool is_entry_computation);

  // Graph compiler needs to know how to get an optimized graph from a function
  // body.
  friend class GraphCompiler;
  friend class XlaCompilerTest;

  Options options_;

  // Status set to non-OK in the constructor if initialization fails.
  Status initialization_status_;

  // Returns the next step sequence number.
  int64 NextStepId();

  // Internal sequence number for steps executed on the compilation device.
  int64 next_step_id_;

  XlaCompilationDevice* device_;  // Owned by device_mgr_
  DeviceMgr device_mgr_;

  // To avoid copying the client's function library, use a local function
  // library and runtime for functions created as part of the functionalize
  // control flow transformation.
  std::unique_ptr<FunctionLibraryDefinition> local_flib_def_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> local_pflr_;

  FunctionLibraryRuntime* local_flib_runtime_;  // owned by local_pflr_.
  FunctionLibraryRuntime* flib_runtime_;        // owned by pflr_.

  // Hashes a (function name, arguments) compilation signature; used as the
  // hasher for cache_ below.
  struct SignatureHash {
    uint64 operator()(
        const std::pair<string, std::vector<Argument>>& signature) const;
  };

  // Cache of compilation results, keyed by (function name, arguments), so
  // that recompiling the same function with the same arguments is free.
  std::unordered_map<std::pair<string, std::vector<Argument>>,
                     CompilationResult, SignatureHash>
      cache_;

  // Channel handles allocated by GetChannelHandle(), keyed by `key`.
  std::unordered_map<string, xla::ChannelHandle> channels_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_