/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_

#include <stack>

#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_argument.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

class XlaContext;

// The XlaCompiler class is responsible for compilation of a self-contained
// subgraph of a TensorFlow computation using the XLA linear algebra runtime.
// It performs a symbolic execution of the graph starting from specific input
// shapes, using a JIT device to convert operators into XLA computations.
//
// XlaCompiler is typically invoked from an `XlaLaunch` operator once the
// shapes of all input parameters to the computation are known. This is
// because the symbolic execution requires known shapes for all operations.
//
// XlaCompiler compiles TensorFlow graphs that receive inputs via _Arg nodes
// and return outputs via _Retval nodes.
//
// The XlaCompiler requires one Argument struct for each _Arg index,
// describing that argument. Arguments can be compile-time constants
// (kind kConstant), run-time parameters (kind kParameter), or resources
// (kind kResource).
//
// Only kParameter and initialized kResource arguments become runtime
// parameters to the generated XLA computation.
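//
// For illustration, a minimal sketch (not part of this header) of how the
// three argument kinds might be described. Field names follow the XlaArgument
// struct aliased below as XlaCompiler::Argument; the concrete types and
// shapes are made up for the example:
//
//   std::vector<XlaCompiler::Argument> args(3);
//   args[0].kind = XlaCompiler::Argument::kConstant;   // folded into the graph
//   args[0].type = DT_INT32;
//   args[0].constant_value = Tensor(DT_INT32, TensorShape({}));  // scalar
//   args[1].kind = XlaCompiler::Argument::kParameter;  // runtime parameter
//   args[1].type = DT_FLOAT;
//   args[1].shape = TensorShape({2, 3});
//   args[2].kind = XlaCompiler::Argument::kResource;   // runtime parameter,
//   args[2].type = DT_FLOAT;                           // if initialized
//   args[2].shape = TensorShape({2, 3});
//
// Only args[1] and (if initialized) args[2] appear as parameters of the
// generated XLA computation; args[0] does not.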
//
// The run-time outputs of the XLA computation are arranged in the following
// order:
//   +------------------+-----------------------------------------+
//   |  _Retval values  |  Updated values of kResource arguments  |
//   +------------------+-----------------------------------------+
// _Retval values are ordered by _Retval index, whereas kResource values are
// ordered by the original _Arg position of the variable.
//
// If a shape representation function is provided as part of
// XlaCompiler::CompileOptions, kParameter arguments and return values to an
// entry computation will be reshaped in accordance with the shape function.
// Arguments and return values to a non-entry computation are not reshaped.
// Variable resource arguments are passed and returned in reshaped form, even
// for non-entry computations. This feature allows TensorFlow to keep
// on-device tensors with a different shape from their representation inside
// the XLA computation.
//
// In computation outputs, updated kResource values are placed at the end.
// When emitting While loop bodies, we must ensure that the loop body has
// identical input and output signatures. By passing variable values at the
// end of the argument list and using the
// `return_updated_values_for_all_resources` option, we can ensure that the
// input and output values of resources appear at the same positions.
//
// Resources are passed as parameters or returned as resource updates in
// "packed" form.
// kStack resources are packed as (array, size of stack) XLA tuples.
// kTensorArray resources without gradients are packed as the array that backs
// the TensorArray. If gradients are present (`tensor_array_gradients`), the
// packed representation is an (array, gradient0, gradient1, ...) tuple, where
// gradient_k is the value of the k-th gradient in the
// `tensor_array_gradients` ordered set.
class XlaCompiler {
 public:
  using Argument = ::tensorflow::XlaArgument;

  // Options pertaining to an individual call to CompileGraph() or
  // CompileFunction().
  struct CompileOptions {
    // If `use_tuple_arg` is true, a single tuple parameter will be used for
    // all arguments; if false, each argument gets its own parameter.
    bool use_tuple_arg = false;

    // If 'return_updated_values_for_all_resources' is true, then updated
    // values of all resource arguments will be included in the
    // 'resource_updates' of the computation, even if the resource was not
    // modified by the computation. Used when compiling loop bodies to ensure
    // the input and output signatures match.
    bool return_updated_values_for_all_resources = false;

    // If 'always_return_tuple' is true, then the output of a computation will
    // always be a tuple. Otherwise, a single-element output will not be
    // wrapped in a tuple.
    bool always_return_tuple = true;

    // True when compiling the entry computation, false for subcomputations
    // (while, call, etc.).
    bool is_entry_computation = true;

    // True when we should add XLA input & output to the graph/function.
    bool add_token_input_output = false;

    // Resource updates are converted into inputs/outputs of the XLA
    // computation. The two buffers are aliased with each other if this option
    // is true.
    bool alias_resource_update = false;
  };
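
  // For illustration, a minimal sketch (not part of this header) of how these
  // options are typically combined when compiling the body of a While loop,
  // so that the body's input and output signatures match:
  //
  //   XlaCompiler::CompileOptions body_options;
  //   body_options.use_tuple_arg = true;
  //   body_options.return_updated_values_for_all_resources = true;
  //   body_options.is_entry_computation = false;
  //   // always_return_tuple keeps its default (true), so even a single
  //   // output is wrapped in a tuple, matching the tuple-shaped input.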

  using OutputDescription = ::tensorflow::XlaOutputDescription;

  using ResourceUpdate = ::tensorflow::XlaResourceUpdate;

  using CompilationResult = ::tensorflow::XlaCompilationResult;

  typedef std::function<StatusOr<xla::Shape>(const TensorShape&, DataType,
                                             bool)>
      ShapeRepresentationFn;

  struct Options {
    // Name of the compilation device to use. It must be set by the caller.
    // The default empty value is invalid.
    DeviceType device_type = DeviceType("");

    // The device to use during compilation to execute instructions on, for
    // example for auto-tuning.
    // Valid values are defined by `xla::Backend::devices_ordinal_supported()`.
    // -1 indicates the default device should be used.
    int device_ordinal = -1;

    xla::Client* client = nullptr;

    // Function library in which to find function definitions. Must be
    // non-null.
    const FunctionLibraryDefinition* flib_def = nullptr;

    // The graph def version to be compiled.
    int graph_def_version = TF_GRAPH_DEF_VERSION;

    // If 'allow_cpu_custom_calls' is true, kernels may make use of
    // CustomCall() for CPU.
    bool allow_cpu_custom_calls = false;

    // If both this and 'allow_cpu_custom_calls' are true, then tf.fake_quant_*
    // ops will be emitted as custom calls to a 'fake_quant_with_min_max_vars'
    // function accepting the input, min, max, num_bits, and narrow_range
    // values as runtime arguments.
    bool custom_fake_quant_op_calls = false;

    // If set, variables are represented to XLA with the shape given by this
    // shape function. Variables are reshaped to this shape on write, and
    // reshaped to their original shape on read.
    ShapeRepresentationFn shape_representation_fn;

    // If not nullptr, populate_resource_manager is called with the
    // compilation device's resource manager when the compilation
    // device is created, and can be used to create metadata objects
    // that can be accessed by XLA op kernels.
    std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;

    // If not nullptr, this memory allocator can be used by the compiler for
    // temporary allocations it might want to make during compilation.
    //
    // For example, the compiler may want to try out different algorithms and
    // choose the fastest one, and it might run those algorithms over buffers
    // created using this allocator.
    //
    // The compiler can function correctly without an explicit allocator given
    // here, but on some devices (notably, GPUs), TensorFlow tends to eagerly
    // allocate most or all available memory on the device, leaving none for
    // the compiler to access, unless it can use TensorFlow's allocator.
    // This must be a shared_ptr, as it is passed all the way down to the
    // cluster compilation. This allows asynchronous compilation to hold a
    // reference until the compilation is finished.
    std::shared_ptr<se::DeviceMemoryAllocator> device_allocator;

    // Alias input and output buffers for parameters that are passed through
    // XLA modules without being changed.
    bool alias_passthrough_params = false;

    // Enable detailed logging of compilation metadata.
    bool detailed_logging = true;
  };

  explicit XlaCompiler(Options options);

  ~XlaCompiler();
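
  // For illustration, a minimal sketch (not part of this header) of setting
  // up Options and compiling a graph. The `client`, `flib_def`, `graph`, and
  // `args` values are assumed to be supplied by the caller:
  //
  //   XlaCompiler::Options options;
  //   options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
  //   options.client = client;      // xla::Client*, owned by the caller
  //   options.flib_def = flib_def;  // must be non-null
  //
  //   XlaCompiler compiler(options);
  //   XlaCompiler::CompilationResult result;
  //   TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(),
  //                                            "example_graph",
  //                                            std::move(graph), args,
  //                                            &result));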

  // Helper function to populate an XlaCompiler::Argument from an XlaResource.
  static void PopulateArgumentFromResource(const XlaResource& resource,
                                           Argument* arg);

  // Compiles the function given by `fn_name_attrs` into an
  // xla::XlaComputation, using `args` to describe the function's arguments.
  Status CompileFunction(const CompileOptions& options,
                         const NameAttrList& fn_name_attrs,
                         absl::Span<const Argument> args,
                         CompilationResult* result);

  // Compiles a tensorflow::Graph into an xla::XlaComputation.
  // Similar to CompileFunction, but takes a Graph as input rather than a
  // function.
  Status CompileGraph(const CompileOptions& options, string const& name,
                      std::unique_ptr<Graph> graph,
                      absl::Span<const Argument> args,
                      CompilationResult* result);

  // Returns the shape of the XLA parameter for an argument 'arg'.
  // See the class comment for more details about the argument-passing
  // convention.
  Status XLAShapeForArgument(
      const Argument& arg, bool is_entry_computation,
      const absl::optional<xla::HloSharding>& arg_sharding,
      xla::Shape* xla_shape) const;

  // Retrieves the channel handle associated with `key`. Allocates
  // a new channel handle if none exists.
  // Channel handles can be used to communicate between different
  // computations. Computations that communicate should be compiled with the
  // same XlaCompiler.
  Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);

  // Retrieves the host-to-device channel handle associated with `key`.
  // Allocates a new channel handle if none exists.
  Status GetHostToDeviceChannelHandle(const string& key,
                                      xla::ChannelHandle* channel);

  // Retrieves the device-to-host channel handle associated with `key`.
  // Allocates a new channel handle if none exists.
  Status GetDeviceToHostChannelHandle(const string& key,
                                      xla::ChannelHandle* channel);

  // Sets the shapes and types for the device-to-host transfer associated with
  // 'key'.
  Status SetDeviceToHostMetadata(const string& key,
                                 absl::Span<const DataType> types,
                                 absl::Span<const TensorShape> shapes);

  // Gets the shapes of the device-to-host transfer associated with 'key'.
  Status GetDeviceToHostShapes(const string& key,
                               std::vector<TensorShape>* shapes) const;

  // Sets the shapes and types for the host-to-device transfer associated with
  // 'key'.
  Status SetHostToDeviceMetadata(const string& key,
                                 absl::Span<const DataType> types,
                                 absl::Span<const TensorShape> shapes);
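
  // For illustration, a minimal sketch (not part of this header) of
  // describing a device-to-host transfer and fetching its channel handle. The
  // key "host_transfer_0" is an arbitrary example identifier:
  //
  //   TF_RETURN_IF_ERROR(compiler.SetDeviceToHostMetadata(
  //       "host_transfer_0", {DT_FLOAT}, {TensorShape({128})}));
  //   xla::ChannelHandle channel;
  //   TF_RETURN_IF_ERROR(
  //       compiler.GetDeviceToHostChannelHandle("host_transfer_0", &channel));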

  // In order to avoid deadlocks from dependencies in host computations, it
  // can be necessary to enforce a partial order on the execution of
  // HostCompute Ops. In particular it may be necessary to constrain the
  // SendToHost for one HostCompute to run before blocking on the RecvAtHost
  // for another HostCompute. The compiler maintains a mapping from
  // 'host_compute_name' to handle, where the handle is an 'output' of the
  // HostCompute Op corresponding to 'host_compute_name'. Another HostCompute
  // Op that needs to be sequenced later can add the handle as an 'input' to
  // enforce the constraints. 'host_compute_name' can be any string the client
  // wishes to use to identify a given HostCompute Op as long as the names are
  // unique within the compilation.
  Status GetHostComputeControlDependency(const string& host_compute_name,
                                         xla::XlaOp* handle);
  Status SetHostComputeControlDependency(const string& host_compute_name,
                                         const xla::XlaOp& handle);

  const Options& options() const { return options_; }
  xla::Client* client() const { return options_.client; }
  FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }

  void PushNodeTokenMapping();
  Status PopNodeTokenMapping();
  Status SetNodeToken(const string& node_name, const xla::XlaOp& op);
  StatusOr<xla::XlaOp> GetNodeToken(const string& node_name);

  // Sets the function body `fbody` to the one registered as `function`.
  Status FindFunctionBody(const NameAttrList& function,
                          const FunctionBody** fbody,
                          const ConfigProto** config_proto = nullptr);

 private:
  // Returns the optimized graph object in this function body.
  std::unique_ptr<Graph> GetGraph(const FunctionBody* fbody);

  // Builds XLA computations for each of the arguments to the computation.
  // `args` are the arguments to the computation.
  Status BuildArguments(const Graph& graph,
                        const std::vector<XlaCompiler::Argument>& args,
                        bool use_tuple_arg, xla::XlaBuilder* builder,
                        XlaContext* context,
                        const std::map<int, xla::OpSharding>& arg_shardings,
                        std::vector<XlaExpression>* arg_expressions,
                        std::vector<int>* input_to_args,
                        std::vector<xla::Shape>* input_shapes,
                        bool is_entry_computation);

  // The graph compiler needs to know how to get an optimized graph from a
  // function body.
  friend class GraphCompiler;
  friend class XlaCompilerTest;

  Options options_;

  // Status set to non-OK in the constructor if initialization fails.
  Status initialization_status_;

  // Returns the next step sequence number.
  int64 NextStepId();

  // Internal sequence number for steps executed on the compilation device.
  int64 next_step_id_;

  XlaCompilationDevice* device_;  // Owned by device_mgr_
  StaticDeviceMgr device_mgr_;

  // To avoid copying the client's function library, use a local function
  // library and runtime for functions created as part of the functionalize
  // control flow transformation.
  std::unique_ptr<FunctionLibraryDefinition> local_flib_def_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> local_pflr_;

  FunctionLibraryRuntime* local_flib_runtime_;  // owned by local_pflr_.
  FunctionLibraryRuntime* flib_runtime_;        // owned by pflr_.

  struct SignatureHash {
    uint64 operator()(
        const std::pair<string, std::vector<Argument>>& signature) const;
  };

  std::unordered_map<std::pair<string, std::vector<Argument>>,
                     CompilationResult, SignatureHash>
      cache_;

  std::unordered_map<string, xla::ChannelHandle> channels_;

  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_sends_;
  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_recvs_;

  std::unordered_map<string, xla::XlaOp> host_compute_control_output_;

  // This is used to store <node name, token output> mappings. Side-effecting
  // ops call SetNodeToken() to record their token output, so that later
  // side-effecting ops can use GetNodeToken() to retrieve it and use it as a
  // token input.
  //
  // It's a stack because we need one such mapping for each level of nested
  // CompileGraph() call. In CompileGraph(), we push a new mapping onto the
  // stack and pop it before returning.
  std::stack<std::map<string, xla::XlaOp>> node_token_mapping_stack_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_