1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_ 16 #define TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_ 17 18 #include <unordered_map> 19 20 #include "tensorflow/core/common_runtime/device_mgr.h" 21 #include "tensorflow/core/framework/function.h" 22 #include "tensorflow/core/protobuf/config.pb.h" 23 24 namespace tensorflow { 25 26 // A class that stores all the FunctionLibraryRuntime objects, one per device. 27 class ProcessFunctionLibraryRuntime { 28 public: 29 // Creates FunctionLibraryRuntime objects for each device in the provided 30 // DeviceMgr. Caller needs to make sure that device_mgr, lib_def and parent 31 // (if provided) outlive this object. 32 ProcessFunctionLibraryRuntime( 33 const DeviceMgr* device_mgr, Env* env, int graph_def_version, 34 const FunctionLibraryDefinition* lib_def, 35 const OptimizerOptions& optimizer_options, 36 DistributedFunctionLibraryRuntime* parent = nullptr); 37 38 // With `custom_kernel_creator`. 39 ProcessFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env, 40 int graph_def_version, 41 const FunctionLibraryDefinition* lib_def, 42 const OptimizerOptions& optimizer_options, 43 CustomKernelCreator custom_kernel_creator, 44 DistributedFunctionLibraryRuntime* parent); 45 46 // Sends `tensors_to_send` from `source_device` to `target_device` using 47 // `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the 48 // Rendezvous. `device_context` should be the DeviceContext of the device 49 // doing the sending. `alloc_attrs` should either be empty or be the size of 50 // `tensors_to_send` and indicates how the input tensors are allocated. Method 51 // takes references on each of the `tensors_to_send`. Method doesn't block. 52 static Status SendTensors(const string& source_device, 53 const string& target_device, 54 const string& key_prefix, int64 src_incarnation, 55 gtl::ArraySlice<Tensor> tensors_to_send, 56 DeviceContext* device_context, 57 const std::vector<AllocatorAttributes>& alloc_attrs, 58 Rendezvous* rendezvous); 59 60 typedef std::function<void(const Status&)> StatusCallback; 61 62 // Receives `received_tensors` from `target_device` (originally sent from 63 // `source_device`) using `rendezvous`. Uses `key_prefix` to construct the 64 // keys to be retrieved. `device_context` should be for the device receiving 65 // the tensors. `alloc_attrs` indicates how to allocate the received 66 // tensors and should either be empty or `num_tensors` in size. Method doesn't 67 // block and calls `done` when `num_tensors` are fetched. 68 static void ReceiveTensorsAsync( 69 const string& source_device, const string& target_device, 70 const string& key_prefix, int64 src_incarnation, int64 num_tensors, 71 DeviceContext* device_context, 72 const std::vector<AllocatorAttributes>& alloc_attrs, 73 Rendezvous* rendezvous, std::vector<Tensor>* received_tensors, 74 const StatusCallback& done); 75 76 static const char kDefaultFLRDevice[]; 77 // Returns the FunctionLibraryRuntime for the corresponding device_name. 78 FunctionLibraryRuntime* GetFLR(const string& device_name) const; 79 80 // Returns the device incarnation for the given device_name. 81 Status GetDeviceIncarnation(const string& device_name, int64* incarnation); 82 83 // For a given canonicalized key signature of the function instantiated 84 // on device `device_name` and a `local_handle`, creates a handle and returns 85 // that value. Uses core/common_runtime/framework/function.h::Canonicalize 86 // to canonicalize the function signature. 87 FunctionLibraryRuntime::Handle AddHandle( 88 const string& function_key, const string& device_name, 89 FunctionLibraryRuntime::LocalHandle local_handle); 90 91 // Returns a handle if found for the given key, else returns kInvalidHandle. 92 FunctionLibraryRuntime::Handle GetHandle(const string& function_key) const; 93 94 // For the given handle instantiated on device `device_name` returns the local 95 // index of instantiation of that function. If the function was not 96 // instantiated on `device_name` returns kInvalidLocalHandle. 97 FunctionLibraryRuntime::LocalHandle GetHandleOnDevice( 98 const string& device_name, FunctionLibraryRuntime::Handle handle); 99 100 // Returns true if function with handle `handle` was instantiated on device 101 // `device_name`. 102 bool IsInstantiatedOnDevice(const string& device_name, 103 FunctionLibraryRuntime::Handle handle); 104 105 // Instantiates the function. See framework/function.h for more details. 106 // Allows for function_name to be instantiated on different devices 107 // as specified in attrs. 108 Status Instantiate(const string& function_name, AttrSlice attrs, 109 const FunctionLibraryRuntime::InstantiateOptions& options, 110 FunctionLibraryRuntime::Handle* handle); 111 112 // Delegates to the local FLR that owns state corresponding to `handle` and 113 // tells it to release it. If the `handle` isnt' needed at all, the local FLR 114 // might call RemoveHandle on this to get rid of the state owned by the Proc 115 // FLR. 116 Status ReleaseHandle(FunctionLibraryRuntime::Handle handle); 117 118 // Runs the function with given `handle`. Function could have been 119 // instantiated on any device. More details in framework/function.h 120 void Run(const FunctionLibraryRuntime::Options& opts, 121 FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args, 122 std::vector<Tensor>* rets, 123 FunctionLibraryRuntime::DoneCallback done); 124 125 private: 126 // For a given device_name, returns a DeviceContext for copying 127 // tensors to/from the device. 128 Status GetDeviceContext(const string& device_name, 129 DeviceContext** device_context); 130 131 // Looks up the information for the given `handle` and returns the name 132 // of the device where the function is registered. 133 string GetDeviceName(FunctionLibraryRuntime::Handle handle); 134 135 // Removes handle from the state owned by this object. 136 Status RemoveHandle(FunctionLibraryRuntime::Handle handle); 137 138 Status Clone(Env* env, int graph_def_version, 139 const OptimizerOptions& optimizer_options, 140 CustomKernelCreator custom_kernel_creator, 141 std::unique_ptr<FunctionLibraryDefinition>* out_lib_def, 142 std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr); 143 144 friend class FunctionLibraryRuntimeImpl; 145 146 mutable mutex mu_; 147 148 struct FunctionData { 149 const string target_device; 150 const FunctionLibraryRuntime::LocalHandle local_handle; 151 FunctionDataFunctionData152 FunctionData(const string& target_device, 153 FunctionLibraryRuntime::LocalHandle local_handle) 154 : target_device(target_device), local_handle(local_handle) {} FunctionDataFunctionData155 FunctionData() : FunctionData("", -1) {} 156 }; 157 158 const DeviceMgr* const device_mgr_; 159 const FunctionLibraryDefinition* lib_def_; 160 // Holds all the function invocations here. 161 std::unordered_map<string, FunctionLibraryRuntime::Handle> table_ 162 GUARDED_BY(mu_); 163 std::unordered_map<FunctionLibraryRuntime::Handle, FunctionData> 164 function_data_ GUARDED_BY(mu_); 165 std::unordered_map<Device*, std::unique_ptr<FunctionLibraryRuntime>> flr_map_; 166 int next_handle_ GUARDED_BY(mu_); 167 DistributedFunctionLibraryRuntime* const parent_; 168 }; 169 170 } // namespace tensorflow 171 172 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_ 173