/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_

#include <stack>

#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

class XlaContext;

// The XlaCompiler class is responsible for compilation of a self-contained
// subgraph of a TensorFlow computation using the XLA linear algebra runtime.
// It performs a symbolic execution of the graph starting from specific input
// shapes, using a JIT device to convert operators into XLA computations.
//
// XlaCompiler is typically invoked from an `XlaLaunch` operator once the
// shapes of all input parameters to the computation are known. This is
// because the symbolic execution requires known shapes for all operations.
//
// XlaCompiler compiles TensorFlow graphs that receive inputs via _Arg nodes
// and return outputs via _Retval nodes.
//
// The XlaCompiler requires one Argument struct for each _Arg index, describing
// that argument. Arguments can be compile-time constants (kind kConstant),
// run-time parameters (kind kParameter), or resources (kind kResource).
//
// Only kParameter and initialized kResource arguments become runtime
// parameters to the generated XLA computation.
//
// The run-time outputs of the XLA computation are arranged in the following
// order:
//   +------------------+-----------------------------------------+
//   |  _Retval values  |  Updated values of kResource arguments  |
//   +------------------+-----------------------------------------+
// _Retval values are ordered by _Retval index, whereas kResource values are
// ordered by the original _Arg position of the variable.
//
// If a shape representation function is provided as part of
// XlaCompiler::Options, kParameter arguments and return values to an
// entry computation will be reshaped in accordance with the shape function.
// Arguments and return values to a non-entry computation are not reshaped.
// Variable resource arguments are passed and returned in reshaped form, even
// for non-entry computations. This feature allows TensorFlow to keep on-device
// tensors with a different shape from their representation inside the XLA
// computation.
//
// In computation outputs, updated kResource values are placed at the end. When
// emitting While loop bodies, we must ensure that the loop body has
// identical input and output signatures. By passing variable values
// at the end of the argument list and using the
// `return_updated_values_for_all_resources` option, we can ensure that the
// input and output values of resources appear at the same positions.
//
// Resources are passed as parameters or returned as resource updates in
// "packed" form.
// kStack resources are packed as (array, size of stack) XLA tuples.
// kTensorArray resources without gradients are packed as the array that
// backs the TensorArray. If gradients are present (`tensor_array_gradients`),
// the packed representation is an (array, gradient0, gradient1, ...) tuple,
// where gradient_k is the value of the k-th gradient in the
// `tensor_array_gradients` ordered set.
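//
// A minimal usage sketch (illustrative only; the graph construction, the
// XLA_CPU_JIT device type, and the local-client lookup below are assumptions
// made for the example, not requirements of this header):
//
//   XlaCompiler::Options options;
//   options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
//   options.client = xla::ClientLibrary::LocalClientOrDie();
//   options.flib_def = &graph->flib_def();
//   XlaCompiler compiler(options);
//
//   std::vector<XlaCompiler::Argument> args(1);
//   args[0].kind = XlaCompiler::Argument::kParameter;
//   args[0].type = DT_FLOAT;
//   args[0].shape = TensorShape({2, 3});
//   args[0].name = "x";
//
//   XlaCompiler::CompilationResult result;
//   TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(),
//                                            "example", std::move(graph),
//                                            args, /*user_aliases=*/{},
//                                            &result));
//
// The compiled computation is then available as `result.computation`; its
// runtime parameters are described by `result.input_mapping` and
// `result.xla_input_shapes`.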
class XlaCompiler {
 public:
  // Describes how to derive the value of each _Arg node in the graph/function
  // being compiled. There must be one Argument for each _Arg index.
  struct Argument {
    enum Kind {
      // Default value; not a valid kind.
      kInvalid,

      // Argument is a compile-time constant. No associated runtime parameter.
      kConstant,

      // Argument is a Variable, TensorArray, or Stack resource. Has an
      // associated runtime parameter iff `initialized` is true.
      kResource,

      // Argument is a run-time parameter.
      kParameter,

      // Argument is an XLA token.
      kToken,
    };

    Kind kind = kInvalid;

    // The type of the argument. If the argument is a resource, this
    // is the type of the variable's value, not DT_RESOURCE.
    DataType type = DT_INVALID;

    // The shape of the argument. For:
    // * a parameter: the shape of the parameter. We allow setting the xla
    //   shape if known. This helps avoid conversions to and from TensorShape.
    // * a constant: ignored; the shape given by constant_value is used
    //   instead.
    // * an uninitialized resource: ignored. We don't yet know the shape of an
    //   uninitialized resource (otherwise we would have initialized it!)
    // * an initialized variable: the shape of the variable's value.
    // * an initialized TensorArray or Stack resource: the shape of an entry in
    //   the TensorArray/Stack. Note this is the size of a single entry, not
    //   the XLA data structure that represents the complete stack/array.
    absl::variant<TensorShape, xla::Shape> shape;

    // The value of the argument, if it is a compile-time constant. Must be a
    // host-memory tensor.
    Tensor constant_value;

    // The name of this argument, used for debugging.
    string name;

    // For a kResource, what kind of resource is it?
    XlaResource::Kind resource_kind = XlaResource::kInvalid;

    // For a kResource, has this resource been initialized?
    bool initialized = false;

    // For a TensorArray or Stack resource, what is the array's declared size?
    // (Used for lazy initialization.)
    int64 max_array_size = -1;

    // TensorArray resource parameters are passed as (array, gradient array 0,
    // ..., gradient array k), where the gradient arrays are in the same order
    // as `tensor_array_gradients`.
    std::set<string> tensor_array_gradients;

    // Map from dynamic dimensions to argument numbers. Empty if there are no
    // dynamic shapes.
    std::map<int32, int32> dynamic_dim_to_arg_num_map;
    bool is_pad_arg = false;

    bool operator==(const Argument& other) const;

    // Returns a human-readable summary of the argument.
    string HumanString() const;

    // Returns the dimension sizes for either TensorShape or xla::Shape.
    std::vector<int64> DimensionSizes() const;

    // Returns the human-readable string for either TensorShape or xla::Shape.
    string ShapeHumanString() const;
  };
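
  // For instance (an illustrative sketch only; the shape and name are made
  // up), an initialized float variable resource could be described as:
  //
  //   XlaCompiler::Argument arg;
  //   arg.kind = XlaCompiler::Argument::kResource;
  //   arg.resource_kind = XlaResource::kVariable;
  //   arg.initialized = true;
  //   arg.type = DT_FLOAT;                   // type of the variable's value
  //   arg.shape = TensorShape({1024, 256});  // shape of the variable's value
  //   arg.name = "weights";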
  // Options pertaining to an individual call to CompileGraph() or
  // CompileFunction().
  struct CompileOptions {
    // If `use_tuple_arg` is true, a single tuple parameter will be used for
    // all arguments; if false, each argument gets its own parameter.
    bool use_tuple_arg = false;

    // If 'return_updated_values_for_all_resources' is true, then updated
    // values of all resource arguments will be included in the
    // 'resource_updates' of the computation, even if the resource was not
    // modified by the computation. Used when compiling loop bodies to ensure
    // the input and output signatures match.
    bool return_updated_values_for_all_resources = false;

    // If 'resolve_compile_time_constants' is true, then outputs of a
    // computation that are known to be compile-time constants will be returned
    // as Tensors at compile-time, rather than as run-time outputs of the
    // computation.
    bool resolve_compile_time_constants = true;

    // If 'always_return_tuple' is true, then the output of a computation will
    // always be a tuple. Otherwise, a single-element output will not be
    // wrapped in a tuple.
    bool always_return_tuple = true;

    // True when compiling the entry computation, false for subcomputations
    // (while, call, etc.)
    bool is_entry_computation = true;

    // True when we should add an XLA token input & output to the
    // graph/function.
    bool add_token_input_output = false;
  };
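
  // As a sketch (not prescribed usage; `compiler`, `body_fn`, and `body_args`
  // are assumptions made for the example), a While-loop body might be compiled
  // with options along these lines so that its input and output signatures
  // match:
  //
  //   XlaCompiler::CompileOptions body_options;
  //   body_options.use_tuple_arg = true;  // XLA While bodies take one tuple.
  //   body_options.return_updated_values_for_all_resources = true;
  //   body_options.resolve_compile_time_constants = false;
  //   body_options.is_entry_computation = false;
  //   XlaCompiler::CompilationResult body_result;
  //   TF_RETURN_IF_ERROR(compiler.CompileFunction(body_options, body_fn,
  //                                               body_args, &body_result));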
  struct OutputDescription {
    // Type and shape of the output. The shape is the unflattened shape.
    // When `type` is DT_RESOURCE, `shape` is the shape of the resource
    // variable's value.
    DataType type;
    TensorShape shape;

    // Constant output value, if known to be constant at JIT compilation time.
    // 'Tensor' is in host memory.
    bool is_constant = false;
    Tensor constant_value;

    // When this output is a resource, i.e. `type == DT_RESOURCE`, this is
    // the index of the input that contains the resource.
    int input_index;
  };

  // Describes a variable write side effect of the computation.
  struct ResourceUpdate {
    // Index of the input that contains the variable resource to write to.
    int input_index;

    // Type and shape of the tensor to be written back.
    // The `shape` field has the same meaning as the Argument::shape field.
    DataType type;
    TensorShape shape;

    // Was the value of the variable modified by the computation?
    // (Always true, unless `return_updated_values_for_all_resources` is true.)
    bool modified;

    // If the resource is a TensorArray, the set of gradients read or written.
    std::set<string> tensor_array_gradients_accessed;
  };

  struct CompilationResult {
    // Vector that maps from the parameters of the XLA computation to their
    // original argument positions. To handle compile-time constant inputs, the
    // parameters to the XLA computation may be a subset of the original
    // arguments. The relative ordering of parameters is maintained.
    std::vector<int> input_mapping;

    // Input shapes of the computation. If we are flattening inputs, these are
    // the flattened shapes.
    std::vector<xla::Shape> xla_input_shapes;

    // Output shape in XLA format. The output shape is always a tuple. If we
    // are flattening outputs, this is the flattened shape.
    xla::Shape xla_output_shape;

    // TensorFlow shapes of outputs, together with the values of any
    // constant arguments. Vector indexed by TensorFlow _Retval number,
    // containing both constant and non-constant results.
    std::vector<OutputDescription> outputs;

    // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their
    // matching RecvAtHost/SendFromHost Ops in the outer graph.
    tf2xla::HostComputeMetadata host_compute_metadata;

    // Resources whose values were updated by the computation, ordered
    // by return value position (which is the same as the order the resources
    // were passed as arguments). Resource updates follow the non-constant
    // results in the outputs of the XLA computation.
    std::vector<ResourceUpdate> resource_updates;

    // The XLA computation built from the TensorFlow subgraph.
    std::shared_ptr<xla::XlaComputation> computation;
  };

  typedef std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType)>
      ShapeRepresentationFn;
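
  // As an illustration (a sketch, assuming the TensorShapeToXLAShape helper
  // from tensorflow/compiler/tf2xla/shape_util.h), an identity-like shape
  // representation function could be written as:
  //
  //   XlaCompiler::ShapeRepresentationFn shape_fn =
  //       [](const TensorShape& shape,
  //          DataType dtype) -> xla::StatusOr<xla::Shape> {
  //     xla::Shape xla_shape;
  //     TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
  //     return xla_shape;
  //   };
  //
  // A real implementation would typically also choose a device-friendly layout
  // or pad dimensions before returning the shape.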
  struct Options {
    // Name of the compilation device to use. It must be set by the caller.
    // The default empty value is invalid.
    DeviceType device_type = DeviceType("");

    // The device to use during compilation to execute instructions on, for
    // example for auto-tuning.
    // Valid values are defined by `xla::Backend::device_ordinal_supported()`.
    // -1 indicates the default device should be used.
    int device_ordinal = -1;

    xla::Client* client = nullptr;

    // Function library in which to find function definitions. Must be
    // non-null.
    const FunctionLibraryDefinition* flib_def = nullptr;

    // The graph def version to be compiled.
    int graph_def_version = TF_GRAPH_DEF_VERSION;

    // If 'allow_cpu_custom_calls' is true, kernels may make use of
    // CustomCall() for CPU.
    bool allow_cpu_custom_calls = false;

    // If set, variables are represented to XLA as the shape given by this
    // shape function. Variables are reshaped to this shape on write, and
    // reshaped back to their original shape on read.
    ShapeRepresentationFn shape_representation_fn;

    // If not nullptr, populate_resource_manager is called with the
    // compilation device's resource manager when the compilation
    // device is created, and can be used to create metadata objects
    // that can be accessed by XLA op kernels.
    std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;

    // If not nullptr, this memory allocator can be used by the compiler for
    // temporary allocations it might want to make during compilation.
    //
    // For example, the compiler may want to try out different algorithms and
    // choose the fastest one, and it might run those algorithms over buffers
    // created using this allocator.
    //
    // The compiler can function correctly without an explicit allocator given
    // here, but on some devices (notably, GPUs), TensorFlow tends to eagerly
    // allocate most or all available memory on the device, leaving none for
    // the compiler to access, unless it can use TensorFlow's allocator.
    xla::DeviceMemoryAllocator* device_allocator = nullptr;
  };

  explicit XlaCompiler(Options options);

  ~XlaCompiler();

  // Compiles the TensorFlow function given by `fn_name_attrs` into an
  // xla::XlaComputation, which is returned in `result`.
  Status CompileFunction(const CompileOptions& options,
                         const NameAttrList& fn_name_attrs,
                         absl::Span<const Argument> args,
                         CompilationResult* result);

  // Compiles a tensorflow::Graph into an xla::XlaComputation.
  // Similar to CompileFunction, but takes a Graph as input rather than a
  // function.
  Status CompileGraph(
      const CompileOptions& options, string const& name,
      std::unique_ptr<Graph> graph, absl::Span<const Argument> args,
      absl::Span<const xla::XlaBuilder::InputOutputAlias> user_aliases,
      CompilationResult* result);

  // Compiles a single Op, given by `node_def`, into an
  // xla::XlaComputation. Similar to CompileFunction but takes a single Op as
  // input.
  Status CompileSingleOp(const CompileOptions& options,
                         const NodeDef& node_def,
                         absl::Span<const Argument> args,
                         absl::Span<const DataType> result_types,
                         CompilationResult* result);

  // Returns the shape of the XLA parameter for an argument 'arg'.
  // See the class comment for more details about the argument passing
  // convention.
  Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation,
                             xla::Shape* xla_shape) const;

  // Retrieves the channel handle associated with `key`. Allocates
  // a new channel handle if none exists.
  // Channel handles can be used to communicate between different
  // computations. Computations that communicate should be compiled with the
  // same XlaCompiler.
  Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);

  // Retrieves the host-to-device channel handle associated with `key`.
  // Allocates a new channel handle if none exists.
  Status GetHostToDeviceChannelHandle(const string& key,
                                      xla::ChannelHandle* channel);

  // Retrieves the device-to-host channel handle associated with `key`.
  // Allocates a new channel handle if none exists.
  Status GetDeviceToHostChannelHandle(const string& key,
                                      xla::ChannelHandle* channel);

  // Sets the shapes and types for the device-to-host transfer associated with
  // 'key'.
  Status SetDeviceToHostMetadata(const string& key,
                                 absl::Span<const DataType> types,
                                 absl::Span<const TensorShape> shapes);

  // Gets the shapes of the device-to-host transfer associated with 'key'.
  Status GetDeviceToHostShapes(const string& key,
                               std::vector<TensorShape>* shapes) const;

  // Sets the shapes and types for the host-to-device transfer associated with
  // 'key'.
  Status SetHostToDeviceMetadata(const string& key,
                                 absl::Span<const DataType> types,
                                 absl::Span<const TensorShape> shapes);
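
  // For example (an illustrative sketch; the key, types, and shapes are made
  // up), a device-to-host transfer could be registered and later queried as:
  //
  //   TF_RETURN_IF_ERROR(compiler.SetDeviceToHostMetadata(
  //       "host_compute_1", {DT_FLOAT, DT_INT32},
  //       {TensorShape({128}), TensorShape({})}));
  //   std::vector<TensorShape> shapes;
  //   TF_RETURN_IF_ERROR(
  //       compiler.GetDeviceToHostShapes("host_compute_1", &shapes));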
  // In order to avoid deadlocks from dependencies in host computations, it can
  // be necessary to enforce a partial order on the execution of HostCompute
  // Ops. In particular it may be necessary to constrain the SendToHost for one
  // HostCompute to run before blocking on the RecvAtHost for another
  // HostCompute. The compiler maintains a mapping from 'host_compute_name' to
  // handle, where the handle is an 'output' of the HostCompute Op
  // corresponding to 'host_compute_name'. Another HostCompute Op that needs to
  // be sequenced later can add the handle as an 'input' to enforce the
  // constraints. 'host_compute_name' can be any string the client wishes to
  // use to identify a given HostCompute Op as long as the names are unique
  // within the compilation.
  Status GetHostComputeControlDependency(const string& host_compute_name,
                                         xla::XlaOp* handle);
  Status SetHostComputeControlDependency(const string& host_compute_name,
                                         const xla::XlaOp& handle);

  const Options& options() const { return options_; }
  xla::Client* client() const { return options_.client; }
  FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }

  void PushNodeTokenMapping();
  Status PopNodeTokenMapping();
  Status SetNodeToken(const string& node_name, const xla::XlaOp& op);
  xla::StatusOr<xla::XlaOp> GetNodeToken(const string& node_name);

  // Sets `*fbody` to point to the function body registered as `function`.
  Status FindFunctionBody(const NameAttrList& function,
                          const FunctionBody** fbody);
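
  // A sketch of how side-effecting kernels might use the node-token mapping
  // above (illustrative only; the node name and the `send_token` value are
  // assumptions made for the example). CompileGraph() pushes a fresh mapping
  // before compiling a graph and pops it afterwards; kernels only record and
  // look up tokens:
  //
  //   TF_RETURN_IF_ERROR(compiler.SetNodeToken("send_x", send_token));
  //   ...
  //   TF_ASSIGN_OR_RETURN(xla::XlaOp token, compiler.GetNodeToken("send_x"));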
 private:
  // Returns the optimized graph object in this function body.
  std::unique_ptr<Graph> GetGraph(const FunctionBody* fbody);

  // Builds XLA computations for each of the arguments to the computation.
  // `args` are the arguments to the computation.
  Status BuildArguments(const Graph& graph,
                        const std::vector<XlaCompiler::Argument>& args,
                        bool use_tuple_arg, xla::XlaBuilder* builder,
                        XlaContext* context,
                        const std::map<int, int>& arg_cores,
                        std::vector<XlaExpression>* arg_expressions,
                        std::vector<int>* input_to_args,
                        std::vector<xla::Shape>* input_shapes,
                        bool is_entry_computation);

  // The graph compiler needs to know how to get an optimized graph from a
  // function body.
  friend class GraphCompiler;
  friend class XlaCompilerTest;

  Options options_;

  // Status set to non-OK in the constructor if initialization fails.
  Status initialization_status_;

  // Returns the next step sequence number.
  int64 NextStepId();

  // Internal sequence number for steps executed on the compilation device.
  int64 next_step_id_;

  XlaCompilationDevice* device_;  // Owned by device_mgr_
  DeviceMgr device_mgr_;

  // To avoid copying the client's function library, use a local function
  // library and runtime for functions created as part of the functionalize
  // control flow transformation.
  std::unique_ptr<FunctionLibraryDefinition> local_flib_def_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> local_pflr_;

  FunctionLibraryRuntime* local_flib_runtime_;  // owned by local_pflr_.
  FunctionLibraryRuntime* flib_runtime_;        // owned by pflr_.

  struct SignatureHash {
    uint64 operator()(
        const std::pair<string, std::vector<Argument>>& signature) const;
  };

  std::unordered_map<std::pair<string, std::vector<Argument>>,
                     CompilationResult, SignatureHash>
      cache_;

  std::unordered_map<string, xla::ChannelHandle> channels_;

  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_sends_;
  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_recvs_;

  std::unordered_map<string, xla::XlaOp> host_compute_control_output_;

  // This is used to store the <node name, token output> mapping. Side-effecting
  // ops call SetNodeToken() to record their token outputs, so later
  // side-effecting ops can use GetNodeToken() to retrieve them and use them as
  // token inputs.
  //
  // It's a stack because we need a mapping like this for each level of nested
  // CompileGraph() call. In CompileGraph(), we push a new mapping onto the
  // stack, and pop the mapping before returning.
  std::stack<std::map<string, xla::XlaOp>> node_token_mapping_stack_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_