/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/config.pb.h"

26 namespace tensorflow {
27 
28 // A class that stores all the FunctionLibraryRuntime objects, one per device.
29 class ProcessFunctionLibraryRuntime {
30  public:
31   // Creates FunctionLibraryRuntime objects for each device in the provided
32   // DeviceMgr. Caller needs to make sure that device_mgr, lib_def and parent
33   // (if provided) outlive this object.
34   ProcessFunctionLibraryRuntime(
35       const DeviceMgr* device_mgr, Env* env, int graph_def_version,
36       const FunctionLibraryDefinition* lib_def,
37       const OptimizerOptions& optimizer_options,
38       thread::ThreadPool* thread_pool = nullptr,
39       DistributedFunctionLibraryRuntime* parent = nullptr);
40 
41   // With `custom_kernel_creator`.
42   ProcessFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env,
43                                 int graph_def_version,
44                                 const FunctionLibraryDefinition* lib_def,
45                                 const OptimizerOptions& optimizer_options,
46                                 CustomKernelCreator custom_kernel_creator,
47                                 thread::ThreadPool* thread_pool,
48                                 DistributedFunctionLibraryRuntime* parent);
49 
50   // Sends `tensors_to_send` from `source_device` to `target_device` using
51   // `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the
52   // Rendezvous. `device_context` should be the DeviceContext of the device
53   // doing the sending. `alloc_attrs` should either be empty or be the size of
54   // `tensors_to_send` and indicates how the input tensors are allocated. Method
55   // takes references on each of the `tensors_to_send`. Method doesn't block.
56   static Status SendTensors(const string& source_device,
57                             const string& target_device,
58                             const string& key_prefix, int64 src_incarnation,
59                             gtl::ArraySlice<Tensor> tensors_to_send,
60                             DeviceContext* device_context,
61                             const std::vector<AllocatorAttributes>& alloc_attrs,
62                             Rendezvous* rendezvous);
63 
64   // Receives `received_tensors` from `target_device` (originally sent from
65   // `source_device`) using `rendezvous`. Uses `key_prefix` to construct the
66   // keys to be retrieved. `device_context` should be for the device receiving
67   // the tensors. `alloc_attrs` indicates how to allocate the received
68   // tensors and should either be empty or `num_tensors` in size. Method doesn't
69   // block and calls `done` when `num_tensors` are fetched.
70   static void ReceiveTensorsAsync(
71       const string& source_device, const string& target_device,
72       const string& key_prefix, int64 src_incarnation, int64 num_tensors,
73       DeviceContext* device_context,
74       const std::vector<AllocatorAttributes>& alloc_attrs,
75       Rendezvous* rendezvous, std::vector<Tensor>* received_tensors,
76       StatusCallback done);
77 
78   static const char kDefaultFLRDevice[];
79   // Returns the FunctionLibraryRuntime for the corresponding device_name.
80   FunctionLibraryRuntime* GetFLR(const string& device_name) const;
81 
82   // Returns the device incarnation for the given device_name.
83   Status GetDeviceIncarnation(const string& device_name,
84                               int64* incarnation) const;
85 
86   // For a given canonicalized key signature of the function instantiated
87   // on device `device_name` and a `local_handle`, creates a handle and returns
88   // that value. Uses core/common_runtime/framework/function.h::Canonicalize
89   // to canonicalize the function signature.
90   FunctionLibraryRuntime::Handle AddHandle(
91       const string& function_key, const string& device_name,
92       FunctionLibraryRuntime::LocalHandle local_handle);
93 
94   // Returns a handle if found for the given key, else returns kInvalidHandle.
95   FunctionLibraryRuntime::Handle GetHandle(const string& function_key) const;
96 
97   // For the given handle instantiated on device `device_name` returns the local
98   // index of instantiation of that function. If the function was not
99   // instantiated on `device_name` or the function is multi-device,
100   // returns kInvalidLocalHandle.
101   FunctionLibraryRuntime::LocalHandle GetHandleOnDevice(
102       const string& device_name, FunctionLibraryRuntime::Handle handle) const;
103 
104   // Fills `output_devices` with the devices on which the results will
105   // be produced. If some output is produced on CPU, the corresponding Device*
106   // is set to nullptr. If some output is DT_RESOURCE, the corresponding Device*
107   // is set to the device backing the resource.
108   // REQUIRES: `handle` identifies a multi-device function.
109   Status GetOutputDevices(FunctionLibraryRuntime::Handle handle,
110                           std::vector<Device*>* output_devices) const;
111 
112   // Returns true if function with handle `handle` was instantiated on device
113   // `device_name`. Returns false for multi-device functions.
114   bool IsInstantiatedOnDevice(const string& device_name,
115                               FunctionLibraryRuntime::Handle handle) const;
116 
117   // Instantiates the function. See framework/function.h for more details.
118   // Allows for function_name to be instantiated on different devices
119   // as specified in attrs.
120   Status Instantiate(const string& function_name, AttrSlice attrs,
121                      const FunctionLibraryRuntime::InstantiateOptions& options,
122                      FunctionLibraryRuntime::Handle* handle);
123 
124   // Delegates to the local FLR that owns state corresponding to `handle` and
125   // tells it to release it. If the `handle` isnt' needed at all, the local FLR
126   // might call RemoveHandle on this to get rid of the state owned by the Proc
127   // FLR.
128   // For multi-device functions, calls ReleaseHandle on local FLRs for each
129   // component function that is part of this multi-device function.
130   // Each local FLR might call RemoveHandle on this.
131   Status ReleaseHandle(FunctionLibraryRuntime::Handle handle);
132 
133   // Runs the function with given `handle`. Function could have been
134   // instantiated on any device. More details in framework/function.h
135   void Run(const FunctionLibraryRuntime::Options& opts,
136            FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
137            std::vector<Tensor>* rets,
138            FunctionLibraryRuntime::DoneCallback done) const;
139 
device_mgr()140   const DeviceMgr* device_mgr() { return device_mgr_; }
141 
142  private:
143   friend class FunctionLibraryRuntimeImpl;
144 
145   using DeviceAndFHandle = std::pair<string, FunctionLibraryRuntime::Handle>;
146   using ArgAndRetIndices = std::pair<std::vector<int>, std::vector<int>>;
147   using ArgAndRetAllocAttrs = std::pair<std::vector<AllocatorAttributes>,
148                                         std::vector<AllocatorAttributes>>;
149 
150   FunctionLibraryRuntime::Handle AddHandleLocked(
151       const string& function_key, const string& device_name,
152       FunctionLibraryRuntime::LocalHandle local_handle)
153       EXCLUSIVE_LOCKS_REQUIRED(mu_);
154 
155   // Structure to keep track of how a component function (a single-device
156   // piece of a multi-device function) fits into the multi-device function.
157   struct ComponentFunctionData {
158     // The handle for the instantiated component function.
159     FunctionLibraryRuntime::Handle handle_;
160     // arg_indices_.size() is the number of arguments to the component function.
161     // The i'th argument of the component function comes from the
162     // `arg_indices_[i]`th argument of the multi-device function.
163     std::vector<int> arg_indices_;
164     // ret_indices_.size() is the number of return value of the component
165     // function.  The i'th return value of the component function goes to the
166     // `ret_indices_[i]`th return value of the multi-device function.
167     std::vector<int> ret_indices_;
168     // arg_alloc_attrs_[i] are the allocator attributes of the i'th argument to
169     // the component function.
170     std::vector<AllocatorAttributes> arg_alloc_attrs_;
171     // ret_alloc_attrs_[i] are the allocator attributes of the i'th return value
172     // of the component function.
173     std::vector<AllocatorAttributes> ret_alloc_attrs_;
174   };
175 
176   // Data structure holding information for a single instantiated multi-device
177   // function.
178   // The fields are filled in during instantiation. Once the object is
179   // added to mdevice_data_, all fields are constant.
180   struct MultiDeviceFunctionData {
MultiDeviceFunctionDataMultiDeviceFunctionData181     MultiDeviceFunctionData(const string& function_name,
182                             const string& function_key, int num_outputs,
183                             const FunctionLibraryDefinition& overlay_lib)
184         : num_outputs_(num_outputs),
185           instantiation_counter_(1),
186           function_name_(function_name),
187           function_key_(function_key),
188           overlay_lib_(overlay_lib) {}
189 
190     // Stored here to resize the output tensor vector when function is run.
191     const int num_outputs_;
192     uint64 instantiation_counter_;
193     const string function_name_;
194     const string function_key_;
195     // The overlay library holding component function definitions as well as
196     // the definitions of functions they call.
197     FunctionLibraryDefinition overlay_lib_;
198 
199     // Maps the device name to the information about the component function
200     // be run on this device.
201     std::unordered_map<string, ComponentFunctionData> glue_;
202   };
203 
204   // For a given device_name, returns a DeviceContext for copying
205   // tensors to/from the device.
206   Status GetDeviceContext(const string& device_name,
207                           DeviceContext** device_context) const;
208 
209   // Looks up the information for the given `handle` and returns the name
210   // of the device where the function is registered.
211   string GetDeviceName(FunctionLibraryRuntime::Handle handle) const;
212 
213   // Removes handle from the state owned by this object.
214   Status RemoveHandle(FunctionLibraryRuntime::Handle handle);
215 
216   Status Clone(Env* env, int graph_def_version,
217                const OptimizerOptions& optimizer_options,
218                CustomKernelCreator custom_kernel_creator,
219                std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
220                std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr) const;
221 
222   Status ReleaseMultiDeviceHandle(FunctionLibraryRuntime::Handle handle);
223 
224   // If handle represents a multi-device function, returns the multi-device
225   // data associated with handle. Else, nullptr.
226   MultiDeviceFunctionData* IsMultiDevice(
227       FunctionLibraryRuntime::Handle handle) const;
228 
229   Status InstantiateMultiDevice(
230       const string& function_name, AttrSlice attrs,
231       const FunctionLibraryRuntime::InstantiateOptions& options,
232       FunctionLibraryRuntime::Handle* handle);
233 
234   FunctionLibraryRuntime::Handle AddMultiDeviceHandle(
235       const std::unique_ptr<MultiDeviceFunctionData> data,
236       const string& function_key);
237 
238   // TODO(iga): Reword
239   // Pins each arg that emits a `DT_RESOURCE` tensor to the device on which the
240   // corresponding resource lives. This ensures that the Placer assigns ops that
241   // access these resources to the appropriate devices.
242   Status PinArgsAndRets(const std::vector<string>& input_devices,
243                         const std::vector<string>& output_devices,
244                         const DeviceSet& device_set, Graph* graph) const;
245 
246   void RunMultiDevice(const FunctionLibraryRuntime::Options& opts,
247                       FunctionLibraryRuntime::Handle handle,
248                       gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
249                       FunctionLibraryRuntime::DoneCallback done) const;
250 
251   // Data structure holding information for a single instantiated remote
252   // (to be executed on `target_device`) function.
253   class FunctionData {
254    public:
FunctionData(const string & target_device,FunctionLibraryRuntime::LocalHandle local_handle,const string & function_key)255     FunctionData(const string& target_device,
256                  FunctionLibraryRuntime::LocalHandle local_handle,
257                  const string& function_key)
258         : target_device_(target_device),
259           local_handle_(local_handle),
260           function_key_(function_key) {}
261 
target_device()262     string target_device() { return target_device_; }
function_key()263     const string& function_key() { return function_key_; }
264 
local_handle()265     FunctionLibraryRuntime::LocalHandle local_handle() {
266       mutex_lock l(mu_);
267       return local_handle_;
268     }
269 
270     // Initializes the FunctionData object by potentially making an Initialize
271     // call to the DistributedFunctionLibraryRuntime.
272     Status DistributedInit(
273         DistributedFunctionLibraryRuntime* parent, const string& function_name,
274         const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
275         const FunctionLibraryRuntime::InstantiateOptions& options);
276 
277    private:
278     mutex mu_;
279 
280     const string target_device_;
281     FunctionLibraryRuntime::LocalHandle local_handle_ GUARDED_BY(mu_);
282     const string function_key_;
283     bool init_started_ GUARDED_BY(mu_) = false;
284     Status init_result_ GUARDED_BY(mu_);
285     Notification init_done_;
286   };
287 
288   mutable mutex mu_;
289 
290   Env* const env_;
291   const DeviceMgr* const device_mgr_;
292   const FunctionLibraryDefinition* lib_def_;
293   thread::ThreadPool* default_thread_pool_;
294 
295   // Holds all the function instantiations. Maps function_keys to handles.
296   std::unordered_map<string, FunctionLibraryRuntime::Handle> table_
297       GUARDED_BY(mu_);
298 
299   // Function data for instantitated remote functions.
300   std::unordered_map<FunctionLibraryRuntime::Handle,
301                      std::unique_ptr<FunctionData>>
302       function_data_ GUARDED_BY(mu_);
303 
304   // Function data for instantiated multi-device functions.
305   std::unordered_map<FunctionLibraryRuntime::Handle,
306                      std::unique_ptr<MultiDeviceFunctionData>>
307       mdevice_data_ GUARDED_BY(mu_);
308 
309   std::unordered_map<Device*, std::unique_ptr<FunctionLibraryRuntime>> flr_map_;
310   int next_handle_ GUARDED_BY(mu_);
311   DistributedFunctionLibraryRuntime* const parent_;
312 };
313 
314 }  // namespace tensorflow
315 
316 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
317