1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_ 16 #define TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_ 17 18 #include <unordered_map> 19 20 #include "tensorflow/core/common_runtime/device_mgr.h" 21 #include "tensorflow/core/common_runtime/device_set.h" 22 #include "tensorflow/core/framework/function.h" 23 #include "tensorflow/core/lib/core/status.h" 24 #include "tensorflow/core/protobuf/config.pb.h" 25 26 namespace tensorflow { 27 28 // A class that stores all the FunctionLibraryRuntime objects, one per device. 29 class ProcessFunctionLibraryRuntime { 30 public: 31 // Creates FunctionLibraryRuntime objects for each device in the provided 32 // DeviceMgr. Caller needs to make sure that device_mgr, lib_def and parent 33 // (if provided) outlive this object. 34 ProcessFunctionLibraryRuntime( 35 const DeviceMgr* device_mgr, Env* env, int graph_def_version, 36 const FunctionLibraryDefinition* lib_def, 37 const OptimizerOptions& optimizer_options, 38 thread::ThreadPool* thread_pool = nullptr, 39 DistributedFunctionLibraryRuntime* parent = nullptr); 40 41 // With `custom_kernel_creator`. 42 ProcessFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env, 43 int graph_def_version, 44 const FunctionLibraryDefinition* lib_def, 45 const OptimizerOptions& optimizer_options, 46 CustomKernelCreator custom_kernel_creator, 47 thread::ThreadPool* thread_pool, 48 DistributedFunctionLibraryRuntime* parent); 49 50 // Sends `tensors_to_send` from `source_device` to `target_device` using 51 // `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the 52 // Rendezvous. `device_context` should be the DeviceContext of the device 53 // doing the sending. `alloc_attrs` should either be empty or be the size of 54 // `tensors_to_send` and indicates how the input tensors are allocated. Method 55 // takes references on each of the `tensors_to_send`. Method doesn't block. 56 static Status SendTensors(const string& source_device, 57 const string& target_device, 58 const string& key_prefix, int64 src_incarnation, 59 gtl::ArraySlice<Tensor> tensors_to_send, 60 DeviceContext* device_context, 61 const std::vector<AllocatorAttributes>& alloc_attrs, 62 Rendezvous* rendezvous); 63 64 // Receives `received_tensors` from `target_device` (originally sent from 65 // `source_device`) using `rendezvous`. Uses `key_prefix` to construct the 66 // keys to be retrieved. `device_context` should be for the device receiving 67 // the tensors. `alloc_attrs` indicates how to allocate the received 68 // tensors and should either be empty or `num_tensors` in size. Method doesn't 69 // block and calls `done` when `num_tensors` are fetched. 70 static void ReceiveTensorsAsync( 71 const string& source_device, const string& target_device, 72 const string& key_prefix, int64 src_incarnation, int64 num_tensors, 73 DeviceContext* device_context, 74 const std::vector<AllocatorAttributes>& alloc_attrs, 75 Rendezvous* rendezvous, std::vector<Tensor>* received_tensors, 76 StatusCallback done); 77 78 static const char kDefaultFLRDevice[]; 79 // Returns the FunctionLibraryRuntime for the corresponding device_name. 80 FunctionLibraryRuntime* GetFLR(const string& device_name) const; 81 82 // Returns the device incarnation for the given device_name. 83 Status GetDeviceIncarnation(const string& device_name, 84 int64* incarnation) const; 85 86 // For a given canonicalized key signature of the function instantiated 87 // on device `device_name` and a `local_handle`, creates a handle and returns 88 // that value. Uses core/common_runtime/framework/function.h::Canonicalize 89 // to canonicalize the function signature. 90 FunctionLibraryRuntime::Handle AddHandle( 91 const string& function_key, const string& device_name, 92 FunctionLibraryRuntime::LocalHandle local_handle); 93 94 // Returns a handle if found for the given key, else returns kInvalidHandle. 95 FunctionLibraryRuntime::Handle GetHandle(const string& function_key) const; 96 97 // For the given handle instantiated on device `device_name` returns the local 98 // index of instantiation of that function. If the function was not 99 // instantiated on `device_name` or the function is multi-device, 100 // returns kInvalidLocalHandle. 101 FunctionLibraryRuntime::LocalHandle GetHandleOnDevice( 102 const string& device_name, FunctionLibraryRuntime::Handle handle) const; 103 104 // Fills `output_devices` with the devices on which the results will 105 // be produced. If some output is produced on CPU, the corresponding Device* 106 // is set to nullptr. If some output is DT_RESOURCE, the corresponding Device* 107 // is set to the device backing the resource. 108 // REQUIRES: `handle` identifies a multi-device function. 109 Status GetOutputDevices(FunctionLibraryRuntime::Handle handle, 110 std::vector<Device*>* output_devices) const; 111 112 // Returns true if function with handle `handle` was instantiated on device 113 // `device_name`. Returns false for multi-device functions. 114 bool IsInstantiatedOnDevice(const string& device_name, 115 FunctionLibraryRuntime::Handle handle) const; 116 117 // Instantiates the function. See framework/function.h for more details. 118 // Allows for function_name to be instantiated on different devices 119 // as specified in attrs. 120 Status Instantiate(const string& function_name, AttrSlice attrs, 121 const FunctionLibraryRuntime::InstantiateOptions& options, 122 FunctionLibraryRuntime::Handle* handle); 123 124 // Delegates to the local FLR that owns state corresponding to `handle` and 125 // tells it to release it. If the `handle` isnt' needed at all, the local FLR 126 // might call RemoveHandle on this to get rid of the state owned by the Proc 127 // FLR. 128 // For multi-device functions, calls ReleaseHandle on local FLRs for each 129 // component function that is part of this multi-device function. 130 // Each local FLR might call RemoveHandle on this. 131 Status ReleaseHandle(FunctionLibraryRuntime::Handle handle); 132 133 // Runs the function with given `handle`. Function could have been 134 // instantiated on any device. More details in framework/function.h 135 void Run(const FunctionLibraryRuntime::Options& opts, 136 FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args, 137 std::vector<Tensor>* rets, 138 FunctionLibraryRuntime::DoneCallback done) const; 139 device_mgr()140 const DeviceMgr* device_mgr() { return device_mgr_; } 141 142 private: 143 friend class FunctionLibraryRuntimeImpl; 144 145 using DeviceAndFHandle = std::pair<string, FunctionLibraryRuntime::Handle>; 146 using ArgAndRetIndices = std::pair<std::vector<int>, std::vector<int>>; 147 using ArgAndRetAllocAttrs = std::pair<std::vector<AllocatorAttributes>, 148 std::vector<AllocatorAttributes>>; 149 150 FunctionLibraryRuntime::Handle AddHandleLocked( 151 const string& function_key, const string& device_name, 152 FunctionLibraryRuntime::LocalHandle local_handle) 153 EXCLUSIVE_LOCKS_REQUIRED(mu_); 154 155 // Structure to keep track of how a component function (a single-device 156 // piece of a multi-device function) fits into the multi-device function. 157 struct ComponentFunctionData { 158 // The handle for the instantiated component function. 159 FunctionLibraryRuntime::Handle handle_; 160 // arg_indices_.size() is the number of arguments to the component function. 161 // The i'th argument of the component function comes from the 162 // `arg_indices_[i]`th argument of the multi-device function. 163 std::vector<int> arg_indices_; 164 // ret_indices_.size() is the number of return value of the component 165 // function. The i'th return value of the component function goes to the 166 // `ret_indices_[i]`th return value of the multi-device function. 167 std::vector<int> ret_indices_; 168 // arg_alloc_attrs_[i] are the allocator attributes of the i'th argument to 169 // the component function. 170 std::vector<AllocatorAttributes> arg_alloc_attrs_; 171 // ret_alloc_attrs_[i] are the allocator attributes of the i'th return value 172 // of the component function. 173 std::vector<AllocatorAttributes> ret_alloc_attrs_; 174 }; 175 176 // Data structure holding information for a single instantiated multi-device 177 // function. 178 // The fields are filled in during instantiation. Once the object is 179 // added to mdevice_data_, all fields are constant. 180 struct MultiDeviceFunctionData { MultiDeviceFunctionDataMultiDeviceFunctionData181 MultiDeviceFunctionData(const string& function_name, 182 const string& function_key, int num_outputs, 183 const FunctionLibraryDefinition& overlay_lib) 184 : num_outputs_(num_outputs), 185 instantiation_counter_(1), 186 function_name_(function_name), 187 function_key_(function_key), 188 overlay_lib_(overlay_lib) {} 189 190 // Stored here to resize the output tensor vector when function is run. 191 const int num_outputs_; 192 uint64 instantiation_counter_; 193 const string function_name_; 194 const string function_key_; 195 // The overlay library holding component function definitions as well as 196 // the definitions of functions they call. 197 FunctionLibraryDefinition overlay_lib_; 198 199 // Maps the device name to the information about the component function 200 // be run on this device. 201 std::unordered_map<string, ComponentFunctionData> glue_; 202 }; 203 204 // For a given device_name, returns a DeviceContext for copying 205 // tensors to/from the device. 206 Status GetDeviceContext(const string& device_name, 207 DeviceContext** device_context) const; 208 209 // Looks up the information for the given `handle` and returns the name 210 // of the device where the function is registered. 211 string GetDeviceName(FunctionLibraryRuntime::Handle handle) const; 212 213 // Removes handle from the state owned by this object. 214 Status RemoveHandle(FunctionLibraryRuntime::Handle handle); 215 216 Status Clone(Env* env, int graph_def_version, 217 const OptimizerOptions& optimizer_options, 218 CustomKernelCreator custom_kernel_creator, 219 std::unique_ptr<FunctionLibraryDefinition>* out_lib_def, 220 std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr) const; 221 222 Status ReleaseMultiDeviceHandle(FunctionLibraryRuntime::Handle handle); 223 224 // If handle represents a multi-device function, returns the multi-device 225 // data associated with handle. Else, nullptr. 226 MultiDeviceFunctionData* IsMultiDevice( 227 FunctionLibraryRuntime::Handle handle) const; 228 229 Status InstantiateMultiDevice( 230 const string& function_name, AttrSlice attrs, 231 const FunctionLibraryRuntime::InstantiateOptions& options, 232 FunctionLibraryRuntime::Handle* handle); 233 234 FunctionLibraryRuntime::Handle AddMultiDeviceHandle( 235 const std::unique_ptr<MultiDeviceFunctionData> data, 236 const string& function_key); 237 238 // TODO(iga): Reword 239 // Pins each arg that emits a `DT_RESOURCE` tensor to the device on which the 240 // corresponding resource lives. This ensures that the Placer assigns ops that 241 // access these resources to the appropriate devices. 242 Status PinArgsAndRets(const std::vector<string>& input_devices, 243 const std::vector<string>& output_devices, 244 const DeviceSet& device_set, Graph* graph) const; 245 246 void RunMultiDevice(const FunctionLibraryRuntime::Options& opts, 247 FunctionLibraryRuntime::Handle handle, 248 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets, 249 FunctionLibraryRuntime::DoneCallback done) const; 250 251 // Data structure holding information for a single instantiated remote 252 // (to be executed on `target_device`) function. 253 class FunctionData { 254 public: FunctionData(const string & target_device,FunctionLibraryRuntime::LocalHandle local_handle,const string & function_key)255 FunctionData(const string& target_device, 256 FunctionLibraryRuntime::LocalHandle local_handle, 257 const string& function_key) 258 : target_device_(target_device), 259 local_handle_(local_handle), 260 function_key_(function_key) {} 261 target_device()262 string target_device() { return target_device_; } function_key()263 const string& function_key() { return function_key_; } 264 local_handle()265 FunctionLibraryRuntime::LocalHandle local_handle() { 266 mutex_lock l(mu_); 267 return local_handle_; 268 } 269 270 // Initializes the FunctionData object by potentially making an Initialize 271 // call to the DistributedFunctionLibraryRuntime. 272 Status DistributedInit( 273 DistributedFunctionLibraryRuntime* parent, const string& function_name, 274 const FunctionLibraryDefinition& lib_def, AttrSlice attrs, 275 const FunctionLibraryRuntime::InstantiateOptions& options); 276 277 private: 278 mutex mu_; 279 280 const string target_device_; 281 FunctionLibraryRuntime::LocalHandle local_handle_ GUARDED_BY(mu_); 282 const string function_key_; 283 bool init_started_ GUARDED_BY(mu_) = false; 284 Status init_result_ GUARDED_BY(mu_); 285 Notification init_done_; 286 }; 287 288 mutable mutex mu_; 289 290 Env* const env_; 291 const DeviceMgr* const device_mgr_; 292 const FunctionLibraryDefinition* lib_def_; 293 thread::ThreadPool* default_thread_pool_; 294 295 // Holds all the function instantiations. Maps function_keys to handles. 296 std::unordered_map<string, FunctionLibraryRuntime::Handle> table_ 297 GUARDED_BY(mu_); 298 299 // Function data for instantitated remote functions. 300 std::unordered_map<FunctionLibraryRuntime::Handle, 301 std::unique_ptr<FunctionData>> 302 function_data_ GUARDED_BY(mu_); 303 304 // Function data for instantiated multi-device functions. 305 std::unordered_map<FunctionLibraryRuntime::Handle, 306 std::unique_ptr<MultiDeviceFunctionData>> 307 mdevice_data_ GUARDED_BY(mu_); 308 309 std::unordered_map<Device*, std::unique_ptr<FunctionLibraryRuntime>> flr_map_; 310 int next_handle_ GUARDED_BY(mu_); 311 DistributedFunctionLibraryRuntime* const parent_; 312 }; 313 314 } // namespace tensorflow 315 316 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_ 317