/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_

#include <stack>

#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_argument.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

class XlaContext;

// The XlaCompiler class is responsible for compilation of a self-contained
// subgraph of a TensorFlow computation using the XLA linear algebra runtime.
// It performs a symbolic execution of the graph starting from specific input
// shapes, using a JIT device to convert operators into XLA computations.
//
// XlaCompiler is typically invoked from an `XlaLaunch` operator once the
// shapes of all input parameters to the computation are known; the symbolic
// execution requires known shapes for all operations.
//
// XlaCompiler compiles TensorFlow graphs that receive inputs via _Arg nodes
// and return outputs via _Retval nodes.
//
// The XlaCompiler requires one Argument struct for each _Arg index,
// describing that argument. Arguments can be compile-time constants
// (kind kConstant), run-time parameters (kind kParameter), or resources
// (kind kResource).
//
// Only kParameter and initialized kResource arguments become runtime
// parameters to the generated XLA computation.
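//
// For illustration only (a sketch, not part of this header), the three kinds
// might be set up as follows; field names follow XlaArgument in
// xla_argument.h:
//
//   XlaCompiler::Argument c;  // compile-time constant
//   c.kind = XlaCompiler::Argument::kConstant;
//   c.type = DT_INT32;
//   c.constant_value = Tensor(DT_INT32, TensorShape({}));  // value filled in
//
//   XlaCompiler::Argument p;  // run-time parameter
//   p.kind = XlaCompiler::Argument::kParameter;
//   p.type = DT_FLOAT;
//   p.shape = TensorShape({2, 3});
//
//   XlaCompiler::Argument r;  // resource (e.g. a variable)
//   r.kind = XlaCompiler::Argument::kResource;
//   r.resource_kind = XlaResource::kVariable;
//   r.initialized = true;
//   r.type = DT_FLOAT;
//   r.shape = TensorShape({2, 3});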
//
// The run-time outputs of the XLA computation are arranged in the following
// order:
//   +------------------+-----------------------------------------+
//   |  _Retval values  |  Updated values of kResource arguments  |
//   +------------------+-----------------------------------------+
// _Retval values are ordered by _Retval index, whereas kResource values are
// ordered by the original _Arg position of the variable.
//
// If a shape representation function is provided as part of
// XlaCompiler::CompileOptions, kParameter arguments and return values of an
// entry computation are reshaped in accordance with the shape function.
// Arguments and return values of a non-entry computation are not reshaped.
// Variable resource arguments are passed and returned in reshaped form, even
// for non-entry computations. This feature allows TensorFlow to keep
// on-device tensors with a different shape from their representation inside
// the XLA computation.
//
// In computation outputs, updated kResource values are placed at the end.
// When emitting While loop bodies, we must ensure that the loop body has
// identical input and output signatures. By passing variable values at the
// end of the argument list and using the
// `return_updated_values_for_all_resources` option, we can ensure that the
// input and output values of resources appear at the same positions.
//
// Resources are passed as parameters or returned as resource updates in
// "packed" form.
// kStack resources are packed as (array, size of stack) XLA tuples.
// kTensorArray resources without gradients are packed as the array that
// backs the TensorArray. If gradients are present (`tensor_array_gradients`),
// the packed representation is an (array, gradient0, gradient1, ...) tuple,
// where gradient_k is the value of the k-th gradient in the
// `tensor_array_gradients` ordered set.
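//
// A minimal end-to-end sketch (illustrative only; assumes a valid xla::Client
// `client`, a FunctionLibraryDefinition `flib_def`, and an already-built
// `graph` with one _Arg and one _Retval):
//
//   XlaCompiler::Options options;
//   options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
//   options.client = client;
//   options.flib_def = flib_def;
//   XlaCompiler compiler(options);
//
//   std::vector<XlaCompiler::Argument> args(1);
//   args[0].kind = XlaCompiler::Argument::kParameter;
//   args[0].type = DT_FLOAT;
//   args[0].shape = TensorShape({2, 3});
//
//   XlaCompiler::CompilationResult result;
//   TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(),
//                                            "example", std::move(graph),
//                                            args, &result));
//   // On success, result.computation holds the generated
//   // xla::XlaComputation.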
class XlaCompiler {
 public:
  using Argument = ::tensorflow::XlaArgument;

  // Options pertaining to an individual call to CompileGraph() or
  // CompileFunction().
  struct CompileOptions {
    // If `use_tuple_arg` is true, a single tuple parameter will be used for
    // all arguments; if false, each argument gets its own parameter.
    bool use_tuple_arg = false;

    // If 'return_updated_values_for_all_resources' is true, then updated
    // values of all resource arguments will be included in the
    // 'resource_updates' of the computation, even if the resource was not
    // modified by the computation. Used when compiling loop bodies to ensure
    // the input and output signatures match.
    bool return_updated_values_for_all_resources = false;

    // If 'always_return_tuple' is true, then the output of a computation will
    // always be a tuple. Otherwise, a single-element output will not be
    // wrapped in a tuple.
    bool always_return_tuple = true;

    // True when compiling the entry computation, false for subcomputations
    // (while, call, etc.).
    bool is_entry_computation = true;

    // True when we should add XLA input & output to the graph/function.
    bool add_token_input_output = false;

    // Resource updates are converted into inputs / outputs of the XLA
    // computation. The two buffers are aliased with each other if this option
    // is true.
    bool alias_resource_update = false;
  };

  using OutputDescription = ::tensorflow::XlaOutputDescription;

  using ResourceUpdate = ::tensorflow::XlaResourceUpdate;

  using CompilationResult = ::tensorflow::XlaCompilationResult;

  typedef std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType,
                                                  bool)>
      ShapeRepresentationFn;
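
  // An illustrative ShapeRepresentationFn (a sketch, not part of this
  // header): map every tensor to the default XLA shape for its type and
  // shape, ignoring the trailing bool argument. TensorShapeToXLAShape is
  // assumed to come from tensorflow/compiler/tf2xla/shape_util.h.
  //
  //   XlaCompiler::ShapeRepresentationFn fn =
  //       [](const TensorShape& shape, DataType dtype,
  //          bool /*fast_mem*/) -> xla::StatusOr<xla::Shape> {
  //     xla::Shape xla_shape;
  //     TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
  //     return xla_shape;
  //   };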
  struct Options {
    // Name of the compilation device to use. It must be set by the caller.
    // The default empty value is invalid.
    DeviceType device_type = DeviceType("");

    // The device to use during compilation to execute instructions on, for
    // example for auto-tuning.
    // Valid values are defined by `xla::Backend::devices_ordinal_supported()`.
    // -1 indicates the default device should be used.
    int device_ordinal = -1;

    xla::Client* client = nullptr;

    // Function library in which to find function definitions. Must be
    // non-null.
    const FunctionLibraryDefinition* flib_def = nullptr;

    // The graph def version to be compiled.
    int graph_def_version = TF_GRAPH_DEF_VERSION;

    // If 'allow_cpu_custom_calls' is true, kernels may make use of
    // CustomCall() for CPU.
    bool allow_cpu_custom_calls = false;

    // If both this and 'allow_cpu_custom_calls' are true then tf.fake_quant_*
    // ops will be emitted as custom calls to a 'fake_quant_with_min_max_vars'
    // function accepting the input, min, max, num_bits, and narrow_range
    // values as runtime arguments.
    bool custom_fake_quant_op_calls = false;

    // If set, variables are represented to XLA as the shape given by this
    // shape function. Variables are reshaped to this shape on write, and
    // reshaped to their original shape on read.
    ShapeRepresentationFn shape_representation_fn;

    // If not nullptr, populate_resource_manager is called with the
    // compilation device's resource manager when the compilation
    // device is created, and can be used to create metadata objects
    // that can be accessed by XLA op kernels.
    std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;

    // If not nullptr, this memory allocator can be used by the compiler for
    // temporary allocations it might want to make during compilation.
    //
    // For example, the compiler may want to try out different algorithms and
    // choose the fastest one, and it might run those algorithms over buffers
    // created using this allocator.
    //
    // The compiler can function correctly without an explicit allocator given
    // here, but on some devices (notably, GPUs), TensorFlow tends to eagerly
    // allocate most or all available memory on the device, leaving none for
    // the compiler to access, unless it can use TensorFlow's allocator.
    se::DeviceMemoryAllocator* device_allocator = nullptr;

    // Alias input and output buffers for parameters that are passed through
    // XLA modules without being changed.
    bool alias_passthrough_params = false;

    // Enable detailed logging of compilation metadata.
    bool detailed_logging = true;
  };

  explicit XlaCompiler(Options options);

  ~XlaCompiler();

  // Helper function to populate an XlaCompiler::Argument from an XlaResource.
  static void PopulateArgumentFromResource(const XlaResource& resource,
                                           Argument* arg);

  // Compiles a function into an XLA computation.
  Status CompileFunction(const CompileOptions& options,
                         const NameAttrList& fn_name_attrs,
                         absl::Span<const Argument> args,
                         CompilationResult* result);

  // Compiles a tensorflow::Graph into an xla::XlaComputation.
  // Similar to CompileFunction, but takes a Graph as input rather than a
  // function.
  Status CompileGraph(
      const CompileOptions& options, string const& name,
      std::unique_ptr<Graph> graph, absl::Span<const Argument> args,
      CompilationResult* result);

  // Returns the shape of the XLA parameter for an argument 'arg'.
  // See the class comment for more details about the argument passing
  // convention.
  Status XLAShapeForArgument(
      const Argument& arg, bool is_entry_computation,
      const absl::optional<xla::HloSharding>& arg_sharding,
      xla::Shape* xla_shape) const;

  // Retrieves the channel handle associated with `key`. Allocates
  // a new channel handle if none exists.
  // Channel handles can be used to communicate between different
  // computations. Computations that communicate should be compiled with the
  // same XlaCompiler.
  Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);

  // Retrieves the host-to-device channel handle associated with `key`.
  // Allocates a new channel handle if none exists.
  Status GetHostToDeviceChannelHandle(const string& key,
                                      xla::ChannelHandle* channel);

  // Retrieves the device-to-host channel handle associated with `key`.
  // Allocates a new channel handle if none exists.
  Status GetDeviceToHostChannelHandle(const string& key,
                                      xla::ChannelHandle* channel);
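
  // Illustrative use of channel handles (a sketch; `compiler` and the XLA
  // builders are assumptions, not part of this header): both sides resolve
  // the same key to the same channel, so a value sent in one computation can
  // be received in the other.
  //
  //   xla::ChannelHandle channel;
  //   TF_RETURN_IF_ERROR(compiler.GetChannelHandle("exchange", &channel));
  //   // In computation A:  xla::Send(value, channel);
  //   // In computation B:  xla::XlaOp received =
  //   //                        xla::Recv(&builder_b, shape, channel);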
  // Sets the shapes and types for the device-to-host transfer associated
  // with 'key'.
  Status SetDeviceToHostMetadata(const string& key,
                                 absl::Span<const DataType> types,
                                 absl::Span<const TensorShape> shapes);

  // Gets the shapes of the device-to-host transfer associated with 'key'.
  Status GetDeviceToHostShapes(const string& key,
                               std::vector<TensorShape>* shapes) const;

  // Sets the shapes and types for the host-to-device transfer associated
  // with 'key'.
  Status SetHostToDeviceMetadata(const string& key,
                                 absl::Span<const DataType> types,
                                 absl::Span<const TensorShape> shapes);

  // In order to avoid deadlocks from dependencies in host computations, it
  // can be necessary to enforce a partial order on the execution of
  // HostCompute Ops. In particular it may be necessary to constrain the
  // SendToHost for one HostCompute to run before blocking on the RecvAtHost
  // for another HostCompute. The compiler maintains a mapping from
  // 'host_compute_name' to handle, where the handle is an 'output' of the
  // HostCompute Op corresponding to 'host_compute_name'. Another HostCompute
  // Op that needs to be sequenced later can add the handle as an 'input' to
  // enforce the constraints. 'host_compute_name' can be any string the
  // client wishes to use to identify a given HostCompute Op as long as the
  // names are unique within the compilation.
  Status GetHostComputeControlDependency(const string& host_compute_name,
                                         xla::XlaOp* handle);
  Status SetHostComputeControlDependency(const string& host_compute_name,
                                         const xla::XlaOp& handle);

  const Options& options() const { return options_; }
  xla::Client* client() const { return options_.client; }
  FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }

  void PushNodeTokenMapping();
  Status PopNodeTokenMapping();
  Status SetNodeToken(const string& node_name, const xla::XlaOp& op);
  xla::StatusOr<xla::XlaOp> GetNodeToken(const string& node_name);
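
  // Intended calling pattern for the node-token mapping (a sketch; the names
  // are illustrative, not part of this header):
  //
  //   compiler.PushNodeTokenMapping();
  //   TF_RETURN_IF_ERROR(compiler.SetNodeToken("send_op", token_out));
  //   ...
  //   TF_ASSIGN_OR_RETURN(xla::XlaOp token_in,
  //                       compiler.GetNodeToken("send_op"));
  //   TF_RETURN_IF_ERROR(compiler.PopNodeTokenMapping());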
  // Sets the function body `fbody` to the one registered as `function`.
  Status FindFunctionBody(const NameAttrList& function,
                          const FunctionBody** fbody,
                          const ConfigProto** config_proto = nullptr);

 private:
  // Returns the optimized graph object in this function body.
  std::unique_ptr<Graph> GetGraph(const FunctionBody* fbody);

  // Builds XLA computations for each of the arguments to the computation.
  // `args` are the arguments to the computation.
  Status BuildArguments(const Graph& graph,
                        const std::vector<XlaCompiler::Argument>& args,
                        bool use_tuple_arg, xla::XlaBuilder* builder,
                        XlaContext* context,
                        const std::map<int, xla::OpSharding>& arg_shardings,
                        std::vector<XlaExpression>* arg_expressions,
                        std::vector<int>* input_to_args,
                        std::vector<xla::Shape>* input_shapes,
                        bool is_entry_computation);

  // The graph compiler needs to know how to get an optimized graph from a
  // function body.
  friend class GraphCompiler;
  friend class XlaCompilerTest;

  Options options_;

  // Status set to non-OK in the constructor if initialization fails.
  Status initialization_status_;

  // Returns the next step sequence number.
  int64 NextStepId();

  // Internal sequence number for steps executed on the compilation device.
  int64 next_step_id_;

  XlaCompilationDevice* device_;  // Owned by device_mgr_.
  StaticDeviceMgr device_mgr_;

  // To avoid copying the client's function library, use a local function
  // library and runtime for functions created as part of the functionalize
  // control flow transformation.
  std::unique_ptr<FunctionLibraryDefinition> local_flib_def_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> local_pflr_;

  FunctionLibraryRuntime* local_flib_runtime_;  // owned by local_pflr_.
  FunctionLibraryRuntime* flib_runtime_;        // owned by pflr_.

  struct SignatureHash {
    uint64 operator()(
        const std::pair<string, std::vector<Argument>>& signature) const;
  };

  std::unordered_map<std::pair<string, std::vector<Argument>>,
                     CompilationResult, SignatureHash>
      cache_;

  std::unordered_map<string, xla::ChannelHandle> channels_;

  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_sends_;
  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_recvs_;

  std::unordered_map<string, xla::XlaOp> host_compute_control_output_;

  // This is used to store a <node name, token output> mapping. Side-effecting
  // ops call SetNodeToken() to record their token outputs, so later
  // side-effecting ops can use GetNodeToken() to retrieve a token and use it
  // as a token input.
  //
  // It's a stack because we need a mapping like this for each level of nested
  // CompileGraph() call. In CompileGraph(), we push a new mapping onto the
  // stack, and pop the mapping before returning.
  std::stack<std::map<string, xla::XlaOp>> node_token_mapping_stack_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_