/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_

#include <unordered_map>

// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
#endif  // IS_MOBILE_PLATFORM

namespace tensorflow {

class FunctionArgsInterface {
 public:
  virtual ~FunctionArgsInterface() {}

  virtual bool HasRemoteOrPackedInputs() const = 0;

  virtual Status GetLocalArg(const FunctionArgIndex& index,
                             Tensor* val) const = 0;

  virtual std::vector<Tensor> GetLocalTensors() const = 0;

#if !defined(IS_MOBILE_PLATFORM)
  virtual Status GetRemoteArg(const FunctionArgIndex& index,
                              eager::RemoteTensorHandle* val) const {
    return errors::Unimplemented(
        "Serializing a remote argument is not implemented.");
  }
#endif  // IS_MOBILE_PLATFORM
};

// A class that stores all the FunctionLibraryRuntime objects, one per device.
class ProcessFunctionLibraryRuntime {
 public:
  // Creates FunctionLibraryRuntime objects for each device in the provided
  // DeviceMgr. The caller needs to make sure that device_mgr, lib_def and
  // parent (if provided) outlive this object.
  ProcessFunctionLibraryRuntime(
      const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
      int graph_def_version, const FunctionLibraryDefinition* lib_def,
      const OptimizerOptions& optimizer_options,
      thread::ThreadPool* thread_pool = nullptr,
      DistributedFunctionLibraryRuntime* parent = nullptr,
      const SessionMetadata* session_metadata = nullptr,
      Rendezvous::Factory rendezvous_factory = Rendezvous::Factory());

  ~ProcessFunctionLibraryRuntime() {
    // Deleting the FunctionLibraryRuntime map deletes the function handles
    // registered in it, which may call ReleaseHandle on this class again to
    // release their sub-functions. These circular calls may cause a segfault,
    // since flr_map_ may have already been deleted. We explicitly release
    // flr_map_ here and check it in ReleaseHandle to avoid this.
    flr_map_.reset();
  }

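  // Example construction (a minimal sketch for illustration only; the device
  // name, `fdef_lib`, and the use of StaticDeviceMgr and DeviceFactory are
  // assumptions, not requirements of this class):
  //
  //   std::vector<std::unique_ptr<Device>> devices;
  //   devices.push_back(DeviceFactory::NewDevice(
  //       "CPU", SessionOptions(), "/job:localhost/replica:0/task:0"));
  //   StaticDeviceMgr device_mgr(std::move(devices));
  //   FunctionLibraryDefinition lib_def(OpRegistry::Global(), fdef_lib);
  //   ProcessFunctionLibraryRuntime pflr(
  //       &device_mgr, Env::Default(), /*config=*/nullptr,
  //       TF_GRAPH_DEF_VERSION, &lib_def, OptimizerOptions());
  //   FunctionLibraryRuntime* flr =
  //       pflr.GetFLR("/job:localhost/replica:0/task:0/device:CPU:0");
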
  // Sends `tensors_to_send` from `source_device` to `target_device` using
  // `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the
  // Rendezvous. `device_context` should be the DeviceContext of the device
  // doing the sending. `alloc_attrs` should either be empty or be the size of
  // `tensors_to_send` and indicates how the input tensors are allocated. The
  // method takes a reference on each of the `tensors_to_send` and does not
  // block.
  static Status SendTensors(const string& source_device,
                            const string& target_device,
                            const string& key_prefix, int64_t src_incarnation,
                            gtl::ArraySlice<Tensor> tensors_to_send,
                            DeviceContext* device_context,
                            const std::vector<AllocatorAttributes>& alloc_attrs,
                            RendezvousInterface* rendezvous);

  // Receives `received_tensors` from `target_device` (originally sent from
  // `source_device`) using `rendezvous`. Uses `key_prefix` to construct the
  // keys to be retrieved. `device_context` should be for the device receiving
  // the tensors. `alloc_attrs` indicates how to allocate the received
  // tensors and should either be empty or `num_tensors` in size. The method
  // does not block; it calls `done` once `num_tensors` have been fetched.
  static void ReceiveTensorsAsync(
      const string& source_device, const string& target_device,
      const string& key_prefix, int64_t src_incarnation, int64_t num_tensors,
      DeviceContext* device_context,
      const std::vector<AllocatorAttributes>& alloc_attrs,
      RendezvousInterface* rendezvous, std::vector<Tensor>* received_tensors,
      StatusCallback done);

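  // Example pairing of SendTensors and ReceiveTensorsAsync (a minimal sketch;
  // `rendezvous`, `src_incarnation`, and the two device contexts are assumed
  // to be provided by the caller and are not part of this header):
  //
  //   Tensor t(DT_FLOAT, TensorShape({2}));
  //   t.flat<float>().setZero();
  //   Status s = ProcessFunctionLibraryRuntime::SendTensors(
  //       "/job:a/replica:0/task:0/device:CPU:0",
  //       "/job:a/replica:0/task:0/device:CPU:1", "key_prefix",
  //       src_incarnation, {t}, src_device_context, /*alloc_attrs=*/{},
  //       rendezvous);
  //   std::vector<Tensor> received;
  //   ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
  //       "/job:a/replica:0/task:0/device:CPU:0",
  //       "/job:a/replica:0/task:0/device:CPU:1", "key_prefix",
  //       src_incarnation, /*num_tensors=*/1, dst_device_context,
  //       /*alloc_attrs=*/{}, rendezvous, &received,
  //       [](const Status& status) { TF_CHECK_OK(status); });
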
  static const char kDefaultFLRDevice[];
  // Returns the FunctionLibraryRuntime for the corresponding device_name.
  FunctionLibraryRuntime* GetFLR(const string& device_name) const;

  // Returns the return types for the function identified by handle `h`.
  Status GetRetTypes(FunctionLibraryRuntime::Handle h,
                     DataTypeVector* ret_types);

  // Returns the device incarnation for the given device_name.
  Status GetDeviceIncarnation(const string& device_name,
                              int64* incarnation) const;

  // For a given canonicalized key signature of the function instantiated
  // on device `device_name` and a `local_handle`, creates a handle and returns
  // that value. Uses framework/function.h::Canonicalize to canonicalize the
  // function signature.
  FunctionLibraryRuntime::Handle AddHandle(
      const string& function_key, const string& device_name,
      FunctionLibraryRuntime::LocalHandle local_handle);

  // Returns a handle if found for the given key, else returns kInvalidHandle.
  FunctionLibraryRuntime::Handle GetHandle(const string& function_key) const;

  // For the given handle instantiated on device `device_name`, returns the
  // local index of the instantiation of that function. If the function was not
  // instantiated on `device_name` or the function is multi-device, returns
  // kInvalidLocalHandle.
  //
  // If `include_multi_device` is true and `handle` is a multi-device function
  // with a single component that is placed on `device_name`, then this method
  // will return the local handle for that component.
  FunctionLibraryRuntime::LocalHandle GetHandleOnDevice(
      const string& device_name, FunctionLibraryRuntime::Handle handle,
      bool include_multi_device = false) const;

  // Fills `output_devices` with the devices on which the results will
  // be produced. If some output is produced on CPU, the corresponding Device*
  // is set to nullptr. If some output is DT_RESOURCE, the corresponding Device*
  // is set to the device backing the resource.
  // REQUIRES: `handle` identifies a multi-device function.
  Status GetOutputDevices(FunctionLibraryRuntime::Handle handle,
                          std::vector<Device*>* output_devices) const;

  // Returns true if the function with handle `handle` was instantiated on
  // device `device_name`. Returns false for multi-device functions.
  bool IsInstantiatedOnDevice(const string& device_name,
                              FunctionLibraryRuntime::Handle handle) const;

  // Instantiates the function. See framework/function.h for more details.
  // Allows function_name to be instantiated on different devices as specified
  // in attrs.
  Status Instantiate(const string& function_name, AttrSlice attrs,
                     const FunctionLibraryRuntime::InstantiateOptions& options,
                     FunctionLibraryRuntime::Handle* handle);

  // Returns whether the function represented by the given handle needs to
  // execute across processes.
  Status IsCrossProcess(FunctionLibraryRuntime::Handle handle,
                        bool* is_cross_process) const;

  // TODO(iga): Reword
  // Pins each arg that emits a `DT_RESOURCE` tensor to the device on which the
  // corresponding resource lives. This ensures that the Placer assigns ops
  // that access these resources to the appropriate devices.
  static Status PinArgsAndRets(const std::vector<string>& input_devices,
                               const std::vector<string>& output_devices,
                               const DeviceSet& device_set,
                               const std::vector<Node*>& arg_nodes,
                               const std::vector<Node*>& ret_nodes,
                               const FunctionLibraryDefinition* lib_def,
                               Device* default_device);

  // Delegates to the local FLR that owns the state corresponding to `handle`
  // and tells it to release that state. If the `handle` isn't needed at all,
  // the local FLR might call RemoveHandle on this object to get rid of the
  // state owned by the Proc FLR.
  // For multi-device functions, calls ReleaseHandle on the local FLR of each
  // component function that is part of this multi-device function.
  // Each local FLR might call RemoveHandle on this object.
  Status ReleaseHandle(FunctionLibraryRuntime::Handle handle);

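  // Example instantiate/run/release lifecycle (a minimal sketch; `pflr`, the
  // input tensor `x`, and the function "XTimesTwo" with its "T" attr are
  // assumptions for illustration; it also assumes a rendezvous is available,
  // either via the options or via the rendezvous factory):
  //
  //   AttrValueMap attr_values;
  //   attr_values["T"].set_type(DT_FLOAT);
  //   FunctionLibraryRuntime::InstantiateOptions inst_opts;
  //   inst_opts.target = "/job:localhost/replica:0/task:0/device:CPU:0";
  //   FunctionLibraryRuntime::Handle handle;
  //   TF_RETURN_IF_ERROR(pflr.Instantiate(
  //       "XTimesTwo", AttrSlice(&attr_values), inst_opts, &handle));
  //   std::vector<Tensor> rets;
  //   TF_RETURN_IF_ERROR(pflr.RunSync(FunctionLibraryRuntime::Options(),
  //                                   handle, {x}, &rets));
  //   TF_RETURN_IF_ERROR(pflr.ReleaseHandle(handle));
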
  // Runs the function with the given `handle`. The function could have been
  // instantiated on any device. More details in framework/function.h.
  void Run(const FunctionLibraryRuntime::Options& opts,
           FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
           std::vector<Tensor>* rets,
           FunctionLibraryRuntime::DoneCallback done) const;
  void Run(const FunctionLibraryRuntime::Options& opts,
           FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame,
           FunctionLibraryRuntime::DoneCallback done) const;

  void Run(const FunctionLibraryRuntime::Options& opts,
           FunctionLibraryRuntime::Handle handle,
           const FunctionArgsInterface& args, std::vector<FunctionRet>* rets,
           FunctionLibraryRuntime::DoneCallback done) const;

  Status RunSync(const FunctionLibraryRuntime::Options& opts,
                 FunctionLibraryRuntime::Handle handle,
                 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets) const;
  Status RunSync(const FunctionLibraryRuntime::Options& opts,
                 FunctionLibraryRuntime::Handle handle,
                 CallFrameInterface* frame) const;

  const DeviceMgr* device_mgr() { return device_mgr_; }

  const std::shared_ptr<DeviceSet> device_set() const {
    tf_shared_lock l(mu_);
    return device_set_;
  }

  // Initializes the set of local and remote devices and the corresponding
  // FLRs for op device selection.
  void InitializeDeviceAndFlr();

  const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; }

  const FunctionLibraryDefinition* GetFunctionLibraryDefinition() const {
    return lib_def_;
  }

  // Adds a CompositeDevice to `device_set_`.
  void AddCompositeDevice(CompositeDevice* d) TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    device_set_->AddDevice(d);
    composite_devices_.push_back(d);
  }

 protected:
  friend class FunctionLibraryRuntimeImpl;

  struct InternalArgs {
    std::vector<FunctionArg> args;
#if !defined(IS_MOBILE_PLATFORM)
    // Holds the RemoteTensorHandles referred to by `args`.
    std::vector<std::unique_ptr<eager::RemoteTensorHandle>> remote_args;
#endif  // IS_MOBILE_PLATFORM
  };

  // Structure to keep track of how a component function (a single-device
  // piece of a multi-device function) fits into the multi-device function.
  struct ComponentFunctionData {
    // The handle for the instantiated component function.
    FunctionLibraryRuntime::Handle handle;
    // arg_indices.size() is the number of arguments to the component function.
    // The i-th argument of the component function comes from the
    // `arg_indices[i]`-th argument of the multi-device function.
    std::vector<FunctionArgIndex> arg_indices;
    // ret_indices.size() is the number of return values of the component
    // function. The i-th return value of the component function goes to the
    // `ret_indices[i]`-th return value of the multi-device function.
    std::vector<int> ret_indices;
    // arg_alloc_attrs[i] are the allocator attributes of the i-th argument to
    // the component function.
    std::vector<AllocatorAttributes> arg_alloc_attrs;
    // ret_alloc_attrs[i] are the allocator attributes of the i-th return value
    // of the component function.
    std::vector<AllocatorAttributes> ret_alloc_attrs;
  };

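  // Illustration (hypothetical, not tied to any particular function): if a
  // multi-device function f(a on CPU:0, b on GPU:0) -> (x on CPU:0, y on
  // GPU:0) is split into one component function per device, the CPU:0
  // component's ComponentFunctionData could record
  // arg_indices = {FunctionArgIndex(0)} and ret_indices = {0}, while the
  // GPU:0 component records arg_indices = {FunctionArgIndex(1)} and
  // ret_indices = {1}.
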
  // Data structure holding information for a single instantiated multi-device
  // function.
  // The fields are filled in during instantiation. Once the object is
  // added to mdevice_data_, all fields are constant.
  struct MultiDeviceFunctionData {
    MultiDeviceFunctionData(const string& function_name,
                            const string& function_key, int num_outputs,
                            FunctionLibraryDefinition&& lib_def,
                            DataTypeVector ret_types)
        : function_name_(function_name),
          function_key_(function_key),
          instantiation_counter_(1),
          lib_def_(std::move(lib_def)),
          num_outputs_(num_outputs),
          ret_types_(std::move(ret_types)),
          is_cross_process_(false),
          has_remote_outputs(false) {}

    const string function_name_;
    const string function_key_;
    uint64 instantiation_counter_;
    // A library that contains definitions of component functions and their
    // transitive dependencies.
    FunctionLibraryDefinition lib_def_;
    // Stored here to resize the output tensor vector when the function is run.
    const int num_outputs_;
    DataTypeVector ret_types_;

    // Indicates whether this function needs to execute across processes.
    bool is_cross_process_;
    // Indicates whether this function has remote outputs.
    bool has_remote_outputs;

    // Maps the device name to the information about the component function
    // to be run on that device.
    std::unordered_map<string, ComponentFunctionData> glue_;
  };

  struct CleanUpItem {
    string device;
    uint64 step_id;
    FunctionLibraryRuntime::Handle local_handle;
  };

  // If `handle` represents a multi-device function, returns the multi-device
  // data associated with `handle`. Else, returns nullptr.
  MultiDeviceFunctionData* IsMultiDevice(
      FunctionLibraryRuntime::Handle handle) const;

  void RunMultiDevice(
      const FunctionLibraryRuntime::Options& opts,
      FunctionLibraryRuntime::Handle handle, std::vector<FunctionRet>* rets,
      std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
      FunctionLibraryRuntime::DoneCallback done,
      std::function<Status(const ComponentFunctionData& comp_data,
                           InternalArgs* args)>
          get_component_args) const;

  Status CreateRendezvous(const FunctionLibraryRuntime::Options& opts,
                          Rendezvous** created_rendezvous) const;

  FunctionLibraryRuntime::DoneCallback ApplyCleanUpToDoneCallback(
      std::vector<std::unique_ptr<CleanUpItem>>* items,
      FunctionLibraryRuntime::DoneCallback done, const int64_t step_id,
      const Rendezvous* rendezvous) const;

  DistributedFunctionLibraryRuntime* const parent_;

 private:
  FunctionLibraryRuntime::Handle AddHandleLocked(
      const string& function_key, const string& device_name,
      FunctionLibraryRuntime::LocalHandle local_handle)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // For a given device_name, returns a DeviceContext for copying
  // tensors to/from the device.
  Status GetDeviceContext(const string& device_name,
                          DeviceContext** device_context) const;

  // Looks up the information for the given `handle` and returns the name
  // of the device where the function is registered.
  string GetDeviceName(FunctionLibraryRuntime::Handle handle) const;

  // Removes handle from the state owned by this object.
  Status RemoveHandle(FunctionLibraryRuntime::Handle handle);

  // Clones ProcessFunctionLibraryRuntime and FunctionLibraryDefinition
  // (transferring ownership of both to the caller). Note that the
  // ProcessFunctionLibraryRuntime borrows a pointer to the
  // FunctionLibraryDefinition and so the FunctionLibraryDefinition should
  // outlive the ProcessFunctionLibraryRuntime.
  //
  // The `skip_flib_def` argument controls whether the method should clone the
  // FunctionLibraryDefinition (default behavior) or return an empty function
  // library. The latter is used by tf.data, which manages
  // FunctionLibraryDefinitions for its functions independently (and passes
  // these into the FunctionLibraryRuntime through an overlay), to avoid
  // runtime that is linear in the number of functions in the current function
  // library.
  Status Clone(Env* env, int graph_def_version,
               const OptimizerOptions& optimizer_options,
               std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
               bool skip_flib_def = false) const;

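  // Example use of Clone (a minimal sketch; `pflr` is an existing instance,
  // and since Clone is private this pattern is only available to friends such
  // as FunctionLibraryRuntimeImpl):
  //
  //   std::unique_ptr<FunctionLibraryDefinition> cloned_lib_def;
  //   std::unique_ptr<ProcessFunctionLibraryRuntime> cloned_pflr;
  //   TF_RETURN_IF_ERROR(pflr.Clone(Env::Default(), TF_GRAPH_DEF_VERSION,
  //                                 OptimizerOptions(), &cloned_lib_def,
  //                                 &cloned_pflr));
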
  Status ReleaseMultiDeviceHandle(FunctionLibraryRuntime::Handle handle);

  Status InstantiateMultiDevice(
      const string& function_name, AttrSlice attrs,
      const FunctionLibraryRuntime::InstantiateOptions& options,
      FunctionLibraryRuntime::Handle* handle);

  void InstantiateRemote(
      const string& function_name, AttrSlice attrs,
      const FunctionLibraryRuntime::InstantiateOptions& options,
      FunctionLibraryRuntime::Handle* handle,
      FunctionLibraryRuntime::DoneCallback done);

  FunctionLibraryRuntime::Handle AddMultiDeviceHandle(
      const std::unique_ptr<MultiDeviceFunctionData> data,
      const string& function_key);

  void RunInternal(const FunctionLibraryRuntime::Options& opts,
                   FunctionLibraryRuntime::Handle handle,
                   gtl::ArraySlice<FunctionArg> args,
                   std::vector<FunctionRet>* rets,
                   std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
                   FunctionLibraryRuntime::DoneCallback done) const;

  void CleanUp(std::vector<std::unique_ptr<CleanUpItem>>* items,
               FunctionLibraryRuntime::DoneCallback done) const;

  // Data structure holding information for a single instantiated remote
  // (to be executed on `target_device`) function.
  class FunctionData {
   public:
    FunctionData(const string& target_device,
                 FunctionLibraryRuntime::LocalHandle local_handle,
                 const string& function_key)
        : target_device_(target_device),
          local_handle_(local_handle),
          function_key_(function_key) {}

    const string& target_device() { return target_device_; }
    const string& function_key() { return function_key_; }

    FunctionLibraryRuntime::LocalHandle local_handle() {
      mutex_lock l(mu_);
      return local_handle_;
    }

    // Initializes the FunctionData object by potentially making an Initialize
    // call to the DistributedFunctionLibraryRuntime.
    void DistributedInit(
        DistributedFunctionLibraryRuntime* parent, const string& function_name,
        const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
        const FunctionLibraryRuntime::InstantiateOptions& options,
        FunctionLibraryRuntime::DoneCallback done);

    bool is_cross_process() {
      mutex_lock l(mu_);
      return is_cross_process_;
    }

   private:
    mutex mu_;

    const string target_device_;
    FunctionLibraryRuntime::LocalHandle local_handle_ TF_GUARDED_BY(mu_);
    const string function_key_;
    bool is_cross_process_ TF_GUARDED_BY(mu_) = false;
    bool init_started_ TF_GUARDED_BY(mu_) = false;
    Status init_result_ TF_GUARDED_BY(mu_);
    Notification init_done_;
  };

  mutable mutex mu_;

  Env* const env_;
  const absl::optional<const ConfigProto> config_;
  const DeviceMgr* const device_mgr_;
  const FunctionLibraryDefinition* lib_def_;
  thread::ThreadPool* default_thread_pool_;

  // A cluster update can reinitialize device_set_ due to remote device
  // changes. At the same time, InstantiateMultiDevice can use the cached
  // devices to instantiate multi-worker functions. Function instantiation
  // would fail if it spans the changed remote devices.
  std::shared_ptr<DeviceSet> device_set_ TF_GUARDED_BY(mu_);

  // Composite devices owned by an EagerContext.
  std::vector<CompositeDevice*> composite_devices_ TF_GUARDED_BY(mu_);

  // Holds all the function instantiations. Maps function_keys to handles.
  std::unordered_map<string, FunctionLibraryRuntime::Handle> table_
      TF_GUARDED_BY(mu_);

  // Function data for instantiated remote functions.
  std::unordered_map<FunctionLibraryRuntime::Handle,
                     std::unique_ptr<FunctionData>>
      function_data_ TF_GUARDED_BY(mu_);

  // Function data for instantiated multi-device functions.
  std::unordered_map<FunctionLibraryRuntime::Handle,
                     std::unique_ptr<MultiDeviceFunctionData>>
      mdevice_data_ TF_GUARDED_BY(mu_);

  std::unique_ptr<
      std::unordered_map<Device*, std::unique_ptr<FunctionLibraryRuntime>>>
      flr_map_;
  int next_handle_ TF_GUARDED_BY(mu_);
  const SessionMetadata* const session_metadata_;
  const Rendezvous::Factory rendezvous_factory_;

  const OptimizerOptions optimizer_options_;
  const int graph_def_version_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_