1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_ 16 #define TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_ 17 18 #include <unordered_map> 19 20 // clang-format off 21 // Required for IS_MOBILE_PLATFORM 22 #include "tensorflow/core/platform/platform.h" 23 // clang-format on 24 25 #include "absl/types/optional.h" 26 #include "absl/types/variant.h" 27 #include "tensorflow/core/common_runtime/composite_device.h" 28 #include "tensorflow/core/common_runtime/device_mgr.h" 29 #include "tensorflow/core/common_runtime/device_set.h" 30 #include "tensorflow/core/framework/function.h" 31 #include "tensorflow/core/framework/types.h" 32 #include "tensorflow/core/lib/core/status.h" 33 #include "tensorflow/core/protobuf/config.pb.h" 34 #if !defined(IS_MOBILE_PLATFORM) 35 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h" 36 #endif // IS_MOBILE_PLATFORM 37 38 namespace tensorflow { 39 40 class FunctionArgsInterface { 41 public: ~FunctionArgsInterface()42 virtual ~FunctionArgsInterface() {} 43 44 virtual bool HasRemoteOrPackedInputs() const = 0; 45 46 virtual Status GetLocalArg(const FunctionArgIndex& index, 47 Tensor* val) const = 0; 48 49 virtual std::vector<Tensor> GetLocalTensors() const = 0; 50 51 #if !defined(IS_MOBILE_PLATFORM) GetRemoteArg(const FunctionArgIndex & index,eager::RemoteTensorHandle * val)52 virtual Status GetRemoteArg(const FunctionArgIndex& index, 53 eager::RemoteTensorHandle* val) const { 54 return errors::Unimplemented( 55 "Serializing a remote argument is not implemented."); 56 } 57 #endif // IS_MOBILE_PLATFORM 58 }; 59 60 // A class that stores all the FunctionLibraryRuntime objects, one per device. 61 class ProcessFunctionLibraryRuntime { 62 public: 63 // Creates FunctionLibraryRuntime objects for each device in the provided 64 // DeviceMgr. Caller needs to make sure that device_mgr, lib_def and parent 65 // (if provided) outlive this object. 66 ProcessFunctionLibraryRuntime( 67 const DeviceMgr* device_mgr, Env* env, const ConfigProto* config, 68 int graph_def_version, const FunctionLibraryDefinition* lib_def, 69 const OptimizerOptions& optimizer_options, 70 thread::ThreadPool* thread_pool = nullptr, 71 DistributedFunctionLibraryRuntime* parent = nullptr, 72 const SessionMetadata* session_metadata = nullptr, 73 Rendezvous::Factory rendezvous_factory = Rendezvous::Factory()); 74 ~ProcessFunctionLibraryRuntime()75 ~ProcessFunctionLibraryRuntime() { 76 // Deleting the FunctionLibraryRuntime map will delete the function handles 77 // registered in it, which may call ReleaseHandle in this class again to 78 // release their sub-function. These circular calls may cause segfault 79 // since the flr_map_ may have already been deleted. Explicitly releasing 80 // flr_map_ here and checking flr_map_ in ReleaseHandle to avoid this. 81 flr_map_.reset(); 82 } 83 84 // Sends `tensors_to_send` from `source_device` to `target_device` using 85 // `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the 86 // Rendezvous. `device_context` should be the DeviceContext of the device 87 // doing the sending. `alloc_attrs` should either be empty or be the size of 88 // `tensors_to_send` and indicates how the input tensors are allocated. Method 89 // takes references on each of the `tensors_to_send`. Method doesn't block. 90 static Status SendTensors(const string& source_device, 91 const string& target_device, 92 const string& key_prefix, int64 src_incarnation, 93 gtl::ArraySlice<Tensor> tensors_to_send, 94 DeviceContext* device_context, 95 const std::vector<AllocatorAttributes>& alloc_attrs, 96 RendezvousInterface* rendezvous); 97 98 // Receives `received_tensors` from `target_device` (originally sent from 99 // `source_device`) using `rendezvous`. Uses `key_prefix` to construct the 100 // keys to be retrieved. `device_context` should be for the device receiving 101 // the tensors. `alloc_attrs` indicates how to allocate the received 102 // tensors and should either be empty or `num_tensors` in size. Method doesn't 103 // block and calls `done` when `num_tensors` are fetched. 104 static void ReceiveTensorsAsync( 105 const string& source_device, const string& target_device, 106 const string& key_prefix, int64 src_incarnation, int64 num_tensors, 107 DeviceContext* device_context, 108 const std::vector<AllocatorAttributes>& alloc_attrs, 109 RendezvousInterface* rendezvous, std::vector<Tensor>* received_tensors, 110 StatusCallback done); 111 112 static const char kDefaultFLRDevice[]; 113 // Returns the FunctionLibraryRuntime for the corresponding device_name. 114 FunctionLibraryRuntime* GetFLR(const string& device_name) const; 115 116 // Returns the return types for the function identified by handle `h`. 117 Status GetRetTypes(FunctionLibraryRuntime::Handle h, 118 DataTypeVector* ret_types); 119 120 // Returns the device incarnation for the given device_name. 121 Status GetDeviceIncarnation(const string& device_name, 122 int64* incarnation) const; 123 124 // For a given canonicalized key signature of the function instantiated 125 // on device `device_name` and a `local_handle`, creates a handle and returns 126 // that value. Uses core/common_runtime/framework/function.h::Canonicalize 127 // to canonicalize the function signature. 128 FunctionLibraryRuntime::Handle AddHandle( 129 const string& function_key, const string& device_name, 130 FunctionLibraryRuntime::LocalHandle local_handle); 131 132 // Returns a handle if found for the given key, else returns kInvalidHandle. 133 FunctionLibraryRuntime::Handle GetHandle(const string& function_key) const; 134 135 // For the given handle instantiated on device `device_name` returns the local 136 // index of instantiation of that function. If the function was not 137 // instantiated on `device_name` or the function is multi-device, 138 // returns kInvalidLocalHandle. 139 // 140 // If `include_multi_device` is true and `handle` is a multi-device function 141 // with a single component that is placed on `device_name`, then this method 142 // will return the local handle for that component. 143 FunctionLibraryRuntime::LocalHandle GetHandleOnDevice( 144 const string& device_name, FunctionLibraryRuntime::Handle handle, 145 bool include_multi_device = false) const; 146 147 // Fills `output_devices` with the devices on which the results will 148 // be produced. If some output is produced on CPU, the corresponding Device* 149 // is set to nullptr. If some output is DT_RESOURCE, the corresponding Device* 150 // is set to the device backing the resource. 151 // REQUIRES: `handle` identifies a multi-device function. 152 Status GetOutputDevices(FunctionLibraryRuntime::Handle handle, 153 std::vector<Device*>* output_devices) const; 154 155 // Returns true if function with handle `handle` was instantiated on device 156 // `device_name`. Returns false for multi-device functions. 157 bool IsInstantiatedOnDevice(const string& device_name, 158 FunctionLibraryRuntime::Handle handle) const; 159 160 // Instantiates the function. See framework/function.h for more details. 161 // Allows for function_name to be instantiated on different devices 162 // as specified in attrs. 163 Status Instantiate(const string& function_name, AttrSlice attrs, 164 const FunctionLibraryRuntime::InstantiateOptions& options, 165 FunctionLibraryRuntime::Handle* handle); 166 167 // Returns whether the function represented by the given handle needs to 168 // execute cross process. 169 Status IsCrossProcess(FunctionLibraryRuntime::Handle handle, 170 bool* is_cross_process) const; 171 172 // Delegates to the local FLR that owns state corresponding to `handle` and 173 // tells it to release it. If the `handle` isn't needed at all, the local FLR 174 // might call RemoveHandle on this to get rid of the state owned by the Proc 175 // FLR. 176 // For multi-device functions, calls ReleaseHandle on local FLRs for each 177 // component function that is part of this multi-device function. 178 // Each local FLR might call RemoveHandle on this. 179 Status ReleaseHandle(FunctionLibraryRuntime::Handle handle); 180 181 // Runs the function with given `handle`. Function could have been 182 // instantiated on any device. More details in framework/function.h 183 void Run(const FunctionLibraryRuntime::Options& opts, 184 FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args, 185 std::vector<Tensor>* rets, 186 FunctionLibraryRuntime::DoneCallback done) const; 187 void Run(const FunctionLibraryRuntime::Options& opts, 188 FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame, 189 FunctionLibraryRuntime::DoneCallback done) const; 190 191 void Run(const FunctionLibraryRuntime::Options& opts, 192 FunctionLibraryRuntime::Handle handle, 193 const FunctionArgsInterface& args, std::vector<FunctionRet>* rets, 194 FunctionLibraryRuntime::DoneCallback done) const; 195 196 Status RunSync(const FunctionLibraryRuntime::Options& opts, 197 FunctionLibraryRuntime::Handle handle, 198 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets) const; 199 Status RunSync(const FunctionLibraryRuntime::Options& opts, 200 FunctionLibraryRuntime::Handle handle, 201 CallFrameInterface* frame) const; 202 device_mgr()203 const DeviceMgr* device_mgr() { return device_mgr_; } 204 device_set()205 const std::shared_ptr<DeviceSet> device_set() const { 206 tf_shared_lock l(mu_); 207 return device_set_; 208 } 209 210 // Initialize the set of local and remote devices and corresponding flr for op 211 // device selection. 212 void InitializeDeviceAndFlr(); 213 config()214 const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; } 215 GetFunctionLibraryDefinition()216 const FunctionLibraryDefinition* GetFunctionLibraryDefinition() const { 217 return lib_def_; 218 } 219 220 // Add a CompositeDevice to `device_set_` AddCompositeDevice(CompositeDevice * d)221 void AddCompositeDevice(CompositeDevice* d) TF_LOCKS_EXCLUDED(mu_) { 222 mutex_lock l(mu_); 223 device_set_->AddDevice(d); 224 composite_devices_.push_back(d); 225 } 226 227 protected: 228 friend class FunctionLibraryRuntimeImpl; 229 230 struct InternalArgs { 231 std::vector<FunctionArg> args; 232 #if !defined(IS_MOBILE_PLATFORM) 233 // Holds the RemoteTensorHandles referred by args. 234 std::vector<std::unique_ptr<eager::RemoteTensorHandle>> remote_args; 235 #endif // IS_MOBILE_PLATFORM 236 }; 237 238 // Structure to keep track of how a component function (a single-device 239 // piece of a multi-device function) fits into the multi-device function. 240 struct ComponentFunctionData { 241 // The handle for the instantiated component function. 242 FunctionLibraryRuntime::Handle handle; 243 // arg_indices.size() is the number of arguments to the component function. 244 // The i-th argument of the component function comes from the 245 // `arg_indices[i]`-th argument of the multi-device function. 246 std::vector<FunctionArgIndex> arg_indices; 247 // ret_indices.size() is the number of return values of the component 248 // function. The i-th return value of the component function goes to the 249 // `ret_indices[i]`-th return value of the multi-device function. 250 std::vector<int> ret_indices; 251 // arg_alloc_attrs[i] are the allocator attributes of the i-th argument to 252 // the component function. 253 std::vector<AllocatorAttributes> arg_alloc_attrs; 254 // ret_alloc_attrs[i] are the allocator attributes of the i-th return value 255 // of the component function. 256 std::vector<AllocatorAttributes> ret_alloc_attrs; 257 }; 258 259 // Data structure holding information for a single instantiated multi-device 260 // function. 261 // The fields are filled in during instantiation. Once the object is 262 // added to mdevice_data_, all fields are constant. 263 struct MultiDeviceFunctionData { MultiDeviceFunctionDataMultiDeviceFunctionData264 MultiDeviceFunctionData(const string& function_name, 265 const string& function_key, int num_outputs, 266 FunctionLibraryDefinition&& lib_def, 267 DataTypeVector ret_types) 268 : function_name_(function_name), 269 function_key_(function_key), 270 instantiation_counter_(1), 271 lib_def_(std::move(lib_def)), 272 num_outputs_(num_outputs), 273 ret_types_(std::move(ret_types)), 274 is_cross_process_(false), 275 has_remote_outputs(false) {} 276 277 const string function_name_; 278 const string function_key_; 279 uint64 instantiation_counter_; 280 // A library that contains definitions of component functions and their 281 // transitive dependencies. 282 FunctionLibraryDefinition lib_def_; 283 // Stored here to resize the output tensor vector when function is run. 284 const int num_outputs_; 285 DataTypeVector ret_types_; 286 287 // Indicates whether this function needs to execute cross process. 288 bool is_cross_process_; 289 // Indicates whether this function has remote outputs. 290 bool has_remote_outputs; 291 292 // Maps the device name to the information about the component function 293 // be run on this device. 294 std::unordered_map<string, ComponentFunctionData> glue_; 295 }; 296 297 struct CleanUpItem { 298 string device; 299 uint64 step_id; 300 FunctionLibraryRuntime::Handle local_handle; 301 }; 302 303 // If `handle` represents a multi-device function, returns the multi-device 304 // data associated with `handle`. Else, nullptr. 305 MultiDeviceFunctionData* IsMultiDevice( 306 FunctionLibraryRuntime::Handle handle) const; 307 308 void RunMultiDevice( 309 const FunctionLibraryRuntime::Options& opts, 310 FunctionLibraryRuntime::Handle handle, std::vector<FunctionRet>* rets, 311 std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items, 312 FunctionLibraryRuntime::DoneCallback done, 313 std::function<Status(const ComponentFunctionData& comp_data, 314 InternalArgs* args)> 315 get_component_args) const; 316 317 Status CreateRendezvous(const FunctionLibraryRuntime::Options& opts, 318 Rendezvous** created_rendezvous) const; 319 320 FunctionLibraryRuntime::DoneCallback ApplyCleanUpToDoneCallback( 321 std::vector<std::unique_ptr<CleanUpItem>>* items, 322 FunctionLibraryRuntime::DoneCallback done, const int64 step_id, 323 const Rendezvous* rendezvous) const; 324 325 DistributedFunctionLibraryRuntime* const parent_; 326 327 private: 328 FunctionLibraryRuntime::Handle AddHandleLocked( 329 const string& function_key, const string& device_name, 330 FunctionLibraryRuntime::LocalHandle local_handle) 331 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 332 333 // For a given device_name, returns a DeviceContext for copying 334 // tensors to/from the device. 335 Status GetDeviceContext(const string& device_name, 336 DeviceContext** device_context) const; 337 338 // Looks up the information for the given `handle` and returns the name 339 // of the device where the function is registered. 340 string GetDeviceName(FunctionLibraryRuntime::Handle handle) const; 341 342 // Removes handle from the state owned by this object. 343 Status RemoveHandle(FunctionLibraryRuntime::Handle handle); 344 345 // Clones ProcessFunctionLibraryRuntime and FunctionLibraryDefinition 346 // (transferring ownership of both to the caller). Note that the 347 // ProcessFunctionLibraryRuntime borrows a pointer to the 348 // FunctionLibraryDefinition and so the FunctionLibraryDefinition should 349 // outlive the ProcessFunctionLibraryRuntime. 350 // 351 // The `skip_flib_def` argument controls whether the method should clone the 352 // FunctionLibraryDefinition (default behavior) or return an empty function 353 // library. The latter is used by tf.data, which manages 354 // FunctionLibraryDefinitions for its functions independently (and passes 355 // these into the FunctionLibraryRuntime through an overlay), to avoid linear 356 // runtime w.r.t. to number of functions in the current function library. 357 Status Clone(Env* env, int graph_def_version, 358 const OptimizerOptions& optimizer_options, 359 std::unique_ptr<FunctionLibraryDefinition>* out_lib_def, 360 std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr, 361 bool skip_flib_def = false) const; 362 363 Status ReleaseMultiDeviceHandle(FunctionLibraryRuntime::Handle handle); 364 365 Status InstantiateMultiDevice( 366 const string& function_name, AttrSlice attrs, 367 const FunctionLibraryRuntime::InstantiateOptions& options, 368 FunctionLibraryRuntime::Handle* handle); 369 370 void InstantiateRemote( 371 const string& function_name, AttrSlice attrs, 372 const FunctionLibraryRuntime::InstantiateOptions& options, 373 FunctionLibraryRuntime::Handle* handle, 374 FunctionLibraryRuntime::DoneCallback done); 375 376 FunctionLibraryRuntime::Handle AddMultiDeviceHandle( 377 const std::unique_ptr<MultiDeviceFunctionData> data, 378 const string& function_key); 379 380 // TODO(iga): Reword 381 // Pins each arg that emits a `DT_RESOURCE` tensor to the device on which the 382 // corresponding resource lives. This ensures that the Placer assigns ops that 383 // access these resources to the appropriate devices. 384 Status PinArgsAndRets(const std::vector<string>& input_devices, 385 const std::vector<string>& output_devices, 386 const DeviceSet& device_set, 387 const std::vector<Node*>& arg_nodes, 388 const std::vector<Node*>& ret_nodes, 389 Device* default_device) const; 390 391 void RunInternal(const FunctionLibraryRuntime::Options& opts, 392 FunctionLibraryRuntime::Handle handle, 393 gtl::ArraySlice<FunctionArg> args, 394 std::vector<FunctionRet>* rets, 395 std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items, 396 FunctionLibraryRuntime::DoneCallback done) const; 397 398 void CleanUp(std::vector<std::unique_ptr<CleanUpItem>>* items, 399 FunctionLibraryRuntime::DoneCallback done) const; 400 401 // Data structure holding information for a single instantiated remote 402 // (to be executed on `target_device`) function. 403 class FunctionData { 404 public: FunctionData(const string & target_device,FunctionLibraryRuntime::LocalHandle local_handle,const string & function_key)405 FunctionData(const string& target_device, 406 FunctionLibraryRuntime::LocalHandle local_handle, 407 const string& function_key) 408 : target_device_(target_device), 409 local_handle_(local_handle), 410 function_key_(function_key) {} 411 target_device()412 const string& target_device() { return target_device_; } function_key()413 const string& function_key() { return function_key_; } 414 local_handle()415 FunctionLibraryRuntime::LocalHandle local_handle() { 416 mutex_lock l(mu_); 417 return local_handle_; 418 } 419 420 // Initializes the FunctionData object by potentially making an Initialize 421 // call to the DistributedFunctionLibraryRuntime. 422 void DistributedInit( 423 DistributedFunctionLibraryRuntime* parent, const string& function_name, 424 const FunctionLibraryDefinition& lib_def, AttrSlice attrs, 425 const FunctionLibraryRuntime::InstantiateOptions& options, 426 FunctionLibraryRuntime::DoneCallback done); 427 is_cross_process()428 bool is_cross_process() { 429 mutex_lock l(mu_); 430 return is_cross_process_; 431 } 432 433 private: 434 mutex mu_; 435 436 const string target_device_; 437 FunctionLibraryRuntime::LocalHandle local_handle_ TF_GUARDED_BY(mu_); 438 const string function_key_; 439 bool is_cross_process_ TF_GUARDED_BY(mu_) = false; 440 bool init_started_ TF_GUARDED_BY(mu_) = false; 441 Status init_result_ TF_GUARDED_BY(mu_); 442 Notification init_done_; 443 }; 444 445 mutable mutex mu_; 446 447 Env* const env_; 448 const absl::optional<const ConfigProto> config_; 449 const DeviceMgr* const device_mgr_; 450 const FunctionLibraryDefinition* lib_def_; 451 thread::ThreadPool* default_thread_pool_; 452 453 // Cluster update can reinitialize the device_set_ due to remote device 454 // changes. At the same time, InstantiateMultiDevice can use the cached 455 // devices to instantiate multi-worker functions. Function instantiation would 456 // fail if it spans the changed remote devices. 457 std::shared_ptr<DeviceSet> device_set_ TF_GUARDED_BY(mu_); 458 459 // Composite devices owned by a EagerContext. 460 std::vector<CompositeDevice*> composite_devices_ TF_GUARDED_BY(mu_); 461 462 // Holds all the function instantiations. Maps function_keys to handles. 463 std::unordered_map<string, FunctionLibraryRuntime::Handle> table_ 464 TF_GUARDED_BY(mu_); 465 466 // Function data for instantiated remote functions. 467 std::unordered_map<FunctionLibraryRuntime::Handle, 468 std::unique_ptr<FunctionData>> 469 function_data_ TF_GUARDED_BY(mu_); 470 471 // Function data for instantiated multi-device functions. 472 std::unordered_map<FunctionLibraryRuntime::Handle, 473 std::unique_ptr<MultiDeviceFunctionData>> 474 mdevice_data_ TF_GUARDED_BY(mu_); 475 476 std::unique_ptr< 477 std::unordered_map<Device*, std::unique_ptr<FunctionLibraryRuntime>>> 478 flr_map_; 479 int next_handle_ TF_GUARDED_BY(mu_); 480 const SessionMetadata* const session_metadata_; 481 const Rendezvous::Factory rendezvous_factory_; 482 483 const OptimizerOptions optimizer_options_; 484 const int graph_def_version_; 485 }; 486 487 } // namespace tensorflow 488 489 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_ 490