• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
17 
18 #include <unordered_map>
19 
20 // clang-format off
21 // Required for IS_MOBILE_PLATFORM
22 #include "tensorflow/core/platform/platform.h"
23 // clang-format on
24 
25 #include "absl/types/optional.h"
26 #include "absl/types/variant.h"
27 #include "tensorflow/core/common_runtime/composite_device.h"
28 #include "tensorflow/core/common_runtime/device_mgr.h"
29 #include "tensorflow/core/common_runtime/device_set.h"
30 #include "tensorflow/core/framework/function.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/protobuf/config.pb.h"
34 #if !defined(IS_MOBILE_PLATFORM)
35 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
36 #endif  // IS_MOBILE_PLATFORM
37 
38 namespace tensorflow {
39 
40 class FunctionArgsInterface {
41  public:
~FunctionArgsInterface()42   virtual ~FunctionArgsInterface() {}
43 
44   virtual bool HasRemoteOrPackedInputs() const = 0;
45 
46   virtual Status GetLocalArg(const FunctionArgIndex& index,
47                              Tensor* val) const = 0;
48 
49   virtual std::vector<Tensor> GetLocalTensors() const = 0;
50 
51 #if !defined(IS_MOBILE_PLATFORM)
GetRemoteArg(const FunctionArgIndex & index,eager::RemoteTensorHandle * val)52   virtual Status GetRemoteArg(const FunctionArgIndex& index,
53                               eager::RemoteTensorHandle* val) const {
54     return errors::Unimplemented(
55         "Serializing a remote argument is not implemented.");
56   }
57 #endif  // IS_MOBILE_PLATFORM
58 };
59 
60 // A class that stores all the FunctionLibraryRuntime objects, one per device.
61 class ProcessFunctionLibraryRuntime {
62  public:
63   // Creates FunctionLibraryRuntime objects for each device in the provided
64   // DeviceMgr. Caller needs to make sure that device_mgr, lib_def and parent
65   // (if provided) outlive this object.
66   ProcessFunctionLibraryRuntime(
67       const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
68       int graph_def_version, const FunctionLibraryDefinition* lib_def,
69       const OptimizerOptions& optimizer_options,
70       thread::ThreadPool* thread_pool = nullptr,
71       DistributedFunctionLibraryRuntime* parent = nullptr,
72       const SessionMetadata* session_metadata = nullptr,
73       Rendezvous::Factory rendezvous_factory = Rendezvous::Factory());
74 
~ProcessFunctionLibraryRuntime()75   ~ProcessFunctionLibraryRuntime() {
76     // Deleting the FunctionLibraryRuntime map will delete the function handles
77     // registered in it, which may call ReleaseHandle in this class again to
78     // release their sub-function. These circular calls may cause segfault
79     // since the flr_map_ may have already been deleted. Explicitly releasing
80     // flr_map_ here and checking flr_map_ in ReleaseHandle to avoid this.
81     flr_map_.reset();
82   }
83 
84   // Sends `tensors_to_send` from `source_device` to `target_device` using
85   // `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the
86   // Rendezvous. `device_context` should be the DeviceContext of the device
87   // doing the sending. `alloc_attrs` should either be empty or be the size of
88   // `tensors_to_send` and indicates how the input tensors are allocated. Method
89   // takes references on each of the `tensors_to_send`. Method doesn't block.
90   static Status SendTensors(const string& source_device,
91                             const string& target_device,
92                             const string& key_prefix, int64_t src_incarnation,
93                             gtl::ArraySlice<Tensor> tensors_to_send,
94                             DeviceContext* device_context,
95                             const std::vector<AllocatorAttributes>& alloc_attrs,
96                             RendezvousInterface* rendezvous);
97 
98   // Receives `received_tensors` from `target_device` (originally sent from
99   // `source_device`) using `rendezvous`. Uses `key_prefix` to construct the
100   // keys to be retrieved. `device_context` should be for the device receiving
101   // the tensors. `alloc_attrs` indicates how to allocate the received
102   // tensors and should either be empty or `num_tensors` in size. Method doesn't
103   // block and calls `done` when `num_tensors` are fetched.
104   static void ReceiveTensorsAsync(
105       const string& source_device, const string& target_device,
106       const string& key_prefix, int64_t src_incarnation, int64_t num_tensors,
107       DeviceContext* device_context,
108       const std::vector<AllocatorAttributes>& alloc_attrs,
109       RendezvousInterface* rendezvous, std::vector<Tensor>* received_tensors,
110       StatusCallback done);
111 
112   static const char kDefaultFLRDevice[];
113   // Returns the FunctionLibraryRuntime for the corresponding device_name.
114   FunctionLibraryRuntime* GetFLR(const string& device_name) const;
115 
116   // Returns the return types for the function identified by handle `h`.
117   Status GetRetTypes(FunctionLibraryRuntime::Handle h,
118                      DataTypeVector* ret_types);
119 
120   // Returns the device incarnation for the given device_name.
121   Status GetDeviceIncarnation(const string& device_name,
122                               int64* incarnation) const;
123 
124   // For a given canonicalized key signature of the function instantiated
125   // on device `device_name` and a `local_handle`, creates a handle and returns
126   // that value. Uses core/common_runtime/framework/function.h::Canonicalize
127   // to canonicalize the function signature.
128   FunctionLibraryRuntime::Handle AddHandle(
129       const string& function_key, const string& device_name,
130       FunctionLibraryRuntime::LocalHandle local_handle);
131 
132   // Returns a handle if found for the given key, else returns kInvalidHandle.
133   FunctionLibraryRuntime::Handle GetHandle(const string& function_key) const;
134 
135   // For the given handle instantiated on device `device_name` returns the local
136   // index of instantiation of that function. If the function was not
137   // instantiated on `device_name` or the function is multi-device,
138   // returns kInvalidLocalHandle.
139   //
140   // If `include_multi_device` is true and `handle` is a multi-device function
141   // with a single component that is placed on `device_name`, then this method
142   // will return the local handle for that component.
143   FunctionLibraryRuntime::LocalHandle GetHandleOnDevice(
144       const string& device_name, FunctionLibraryRuntime::Handle handle,
145       bool include_multi_device = false) const;
146 
147   // Fills `output_devices` with the devices on which the results will
148   // be produced. If some output is produced on CPU, the corresponding Device*
149   // is set to nullptr. If some output is DT_RESOURCE, the corresponding Device*
150   // is set to the device backing the resource.
151   // REQUIRES: `handle` identifies a multi-device function.
152   Status GetOutputDevices(FunctionLibraryRuntime::Handle handle,
153                           std::vector<Device*>* output_devices) const;
154 
155   // Returns true if function with handle `handle` was instantiated on device
156   // `device_name`. Returns false for multi-device functions.
157   bool IsInstantiatedOnDevice(const string& device_name,
158                               FunctionLibraryRuntime::Handle handle) const;
159 
160   // Instantiates the function. See framework/function.h for more details.
161   // Allows for function_name to be instantiated on different devices
162   // as specified in attrs.
163   Status Instantiate(const string& function_name, AttrSlice attrs,
164                      const FunctionLibraryRuntime::InstantiateOptions& options,
165                      FunctionLibraryRuntime::Handle* handle);
166 
167   // Returns whether the function represented by the given handle needs to
168   // execute cross process.
169   Status IsCrossProcess(FunctionLibraryRuntime::Handle handle,
170                         bool* is_cross_process) const;
171 
172   // TODO(iga): Reword
173   // Pins each arg that emits a `DT_RESOURCE` tensor to the device on which the
174   // corresponding resource lives. This ensures that the Placer assigns ops that
175   // access these resources to the appropriate devices.
176   static Status PinArgsAndRets(const std::vector<string>& input_devices,
177                                const std::vector<string>& output_devices,
178                                const DeviceSet& device_set,
179                                const std::vector<Node*>& arg_nodes,
180                                const std::vector<Node*>& ret_nodes,
181                                const FunctionLibraryDefinition* lib_def,
182                                Device* default_device);
183 
184   // Delegates to the local FLR that owns state corresponding to `handle` and
185   // tells it to release it. If the `handle` isn't needed at all, the local FLR
186   // might call RemoveHandle on this to get rid of the state owned by the Proc
187   // FLR.
188   // For multi-device functions, calls ReleaseHandle on local FLRs for each
189   // component function that is part of this multi-device function.
190   // Each local FLR might call RemoveHandle on this.
191   Status ReleaseHandle(FunctionLibraryRuntime::Handle handle);
192 
193   // Runs the function with given `handle`. Function could have been
194   // instantiated on any device. More details in framework/function.h
195   void Run(const FunctionLibraryRuntime::Options& opts,
196            FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
197            std::vector<Tensor>* rets,
198            FunctionLibraryRuntime::DoneCallback done) const;
199   void Run(const FunctionLibraryRuntime::Options& opts,
200            FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame,
201            FunctionLibraryRuntime::DoneCallback done) const;
202 
203   void Run(const FunctionLibraryRuntime::Options& opts,
204            FunctionLibraryRuntime::Handle handle,
205            const FunctionArgsInterface& args, std::vector<FunctionRet>* rets,
206            FunctionLibraryRuntime::DoneCallback done) const;
207 
208   Status RunSync(const FunctionLibraryRuntime::Options& opts,
209                  FunctionLibraryRuntime::Handle handle,
210                  gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets) const;
211   Status RunSync(const FunctionLibraryRuntime::Options& opts,
212                  FunctionLibraryRuntime::Handle handle,
213                  CallFrameInterface* frame) const;
214 
device_mgr()215   const DeviceMgr* device_mgr() { return device_mgr_; }
216 
device_set()217   const std::shared_ptr<DeviceSet> device_set() const {
218     tf_shared_lock l(mu_);
219     return device_set_;
220   }
221 
222   // Initialize the set of local and remote devices and corresponding flr for op
223   // device selection.
224   void InitializeDeviceAndFlr();
225 
config()226   const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; }
227 
GetFunctionLibraryDefinition()228   const FunctionLibraryDefinition* GetFunctionLibraryDefinition() const {
229     return lib_def_;
230   }
231 
232   // Add a CompositeDevice to `device_set_`
AddCompositeDevice(CompositeDevice * d)233   void AddCompositeDevice(CompositeDevice* d) TF_LOCKS_EXCLUDED(mu_) {
234     mutex_lock l(mu_);
235     device_set_->AddDevice(d);
236     composite_devices_.push_back(d);
237   }
238 
239  protected:
240   friend class FunctionLibraryRuntimeImpl;
241 
242   struct InternalArgs {
243     std::vector<FunctionArg> args;
244 #if !defined(IS_MOBILE_PLATFORM)
245     // Holds the RemoteTensorHandles referred by args.
246     std::vector<std::unique_ptr<eager::RemoteTensorHandle>> remote_args;
247 #endif  // IS_MOBILE_PLATFORM
248   };
249 
250   // Structure to keep track of how a component function (a single-device
251   // piece of a multi-device function) fits into the multi-device function.
252   struct ComponentFunctionData {
253     // The handle for the instantiated component function.
254     FunctionLibraryRuntime::Handle handle;
255     // arg_indices.size() is the number of arguments to the component function.
256     // The i-th argument of the component function comes from the
257     // `arg_indices[i]`-th argument of the multi-device function.
258     std::vector<FunctionArgIndex> arg_indices;
259     // ret_indices.size() is the number of return values of the component
260     // function.  The i-th return value of the component function goes to the
261     // `ret_indices[i]`-th return value of the multi-device function.
262     std::vector<int> ret_indices;
263     // arg_alloc_attrs[i] are the allocator attributes of the i-th argument to
264     // the component function.
265     std::vector<AllocatorAttributes> arg_alloc_attrs;
266     // ret_alloc_attrs[i] are the allocator attributes of the i-th return value
267     // of the component function.
268     std::vector<AllocatorAttributes> ret_alloc_attrs;
269   };
270 
271   // Data structure holding information for a single instantiated multi-device
272   // function.
273   // The fields are filled in during instantiation. Once the object is
274   // added to mdevice_data_, all fields are constant.
275   struct MultiDeviceFunctionData {
MultiDeviceFunctionDataMultiDeviceFunctionData276     MultiDeviceFunctionData(const string& function_name,
277                             const string& function_key, int num_outputs,
278                             FunctionLibraryDefinition&& lib_def,
279                             DataTypeVector ret_types)
280         : function_name_(function_name),
281           function_key_(function_key),
282           instantiation_counter_(1),
283           lib_def_(std::move(lib_def)),
284           num_outputs_(num_outputs),
285           ret_types_(std::move(ret_types)),
286           is_cross_process_(false),
287           has_remote_outputs(false) {}
288 
289     const string function_name_;
290     const string function_key_;
291     uint64 instantiation_counter_;
292     // A library that contains definitions of component functions and their
293     // transitive dependencies.
294     FunctionLibraryDefinition lib_def_;
295     // Stored here to resize the output tensor vector when function is run.
296     const int num_outputs_;
297     DataTypeVector ret_types_;
298 
299     // Indicates whether this function needs to execute cross process.
300     bool is_cross_process_;
301     // Indicates whether this function has remote outputs.
302     bool has_remote_outputs;
303 
304     // Maps the device name to the information about the component function
305     // be run on this device.
306     std::unordered_map<string, ComponentFunctionData> glue_;
307   };
308 
309   struct CleanUpItem {
310     string device;
311     uint64 step_id;
312     FunctionLibraryRuntime::Handle local_handle;
313   };
314 
315   // If `handle` represents a multi-device function, returns the multi-device
316   // data associated with `handle`. Else, nullptr.
317   MultiDeviceFunctionData* IsMultiDevice(
318       FunctionLibraryRuntime::Handle handle) const;
319 
320   void RunMultiDevice(
321       const FunctionLibraryRuntime::Options& opts,
322       FunctionLibraryRuntime::Handle handle, std::vector<FunctionRet>* rets,
323       std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
324       FunctionLibraryRuntime::DoneCallback done,
325       std::function<Status(const ComponentFunctionData& comp_data,
326                            InternalArgs* args)>
327           get_component_args) const;
328 
329   Status CreateRendezvous(const FunctionLibraryRuntime::Options& opts,
330                           Rendezvous** created_rendezvous) const;
331 
332   FunctionLibraryRuntime::DoneCallback ApplyCleanUpToDoneCallback(
333       std::vector<std::unique_ptr<CleanUpItem>>* items,
334       FunctionLibraryRuntime::DoneCallback done, const int64_t step_id,
335       const Rendezvous* rendezvous) const;
336 
337   DistributedFunctionLibraryRuntime* const parent_;
338 
339  private:
340   FunctionLibraryRuntime::Handle AddHandleLocked(
341       const string& function_key, const string& device_name,
342       FunctionLibraryRuntime::LocalHandle local_handle)
343       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
344 
345   // For a given device_name, returns a DeviceContext for copying
346   // tensors to/from the device.
347   Status GetDeviceContext(const string& device_name,
348                           DeviceContext** device_context) const;
349 
350   // Looks up the information for the given `handle` and returns the name
351   // of the device where the function is registered.
352   string GetDeviceName(FunctionLibraryRuntime::Handle handle) const;
353 
354   // Removes handle from the state owned by this object.
355   Status RemoveHandle(FunctionLibraryRuntime::Handle handle);
356 
357   // Clones ProcessFunctionLibraryRuntime and FunctionLibraryDefinition
358   // (transferring ownership of both to the caller). Note that the
359   // ProcessFunctionLibraryRuntime borrows a pointer to the
360   // FunctionLibraryDefinition and so the FunctionLibraryDefinition should
361   // outlive the ProcessFunctionLibraryRuntime.
362   //
363   // The `skip_flib_def` argument controls whether the method should clone the
364   // FunctionLibraryDefinition (default behavior) or return an empty function
365   // library. The latter is used by tf.data, which manages
366   // FunctionLibraryDefinitions for its functions independently (and passes
367   // these into the FunctionLibraryRuntime through an overlay), to avoid linear
368   // runtime w.r.t. to number of functions in the current function library.
369   Status Clone(Env* env, int graph_def_version,
370                const OptimizerOptions& optimizer_options,
371                std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
372                std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
373                bool skip_flib_def = false) const;
374 
375   Status ReleaseMultiDeviceHandle(FunctionLibraryRuntime::Handle handle);
376 
377   Status InstantiateMultiDevice(
378       const string& function_name, AttrSlice attrs,
379       const FunctionLibraryRuntime::InstantiateOptions& options,
380       FunctionLibraryRuntime::Handle* handle);
381 
382   void InstantiateRemote(
383       const string& function_name, AttrSlice attrs,
384       const FunctionLibraryRuntime::InstantiateOptions& options,
385       FunctionLibraryRuntime::Handle* handle,
386       FunctionLibraryRuntime::DoneCallback done);
387 
388   FunctionLibraryRuntime::Handle AddMultiDeviceHandle(
389       const std::unique_ptr<MultiDeviceFunctionData> data,
390       const string& function_key);
391 
392   void RunInternal(const FunctionLibraryRuntime::Options& opts,
393                    FunctionLibraryRuntime::Handle handle,
394                    gtl::ArraySlice<FunctionArg> args,
395                    std::vector<FunctionRet>* rets,
396                    std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
397                    FunctionLibraryRuntime::DoneCallback done) const;
398 
399   void CleanUp(std::vector<std::unique_ptr<CleanUpItem>>* items,
400                FunctionLibraryRuntime::DoneCallback done) const;
401 
402   // Data structure holding information for a single instantiated remote
403   // (to be executed on `target_device`) function.
404   class FunctionData {
405    public:
FunctionData(const string & target_device,FunctionLibraryRuntime::LocalHandle local_handle,const string & function_key)406     FunctionData(const string& target_device,
407                  FunctionLibraryRuntime::LocalHandle local_handle,
408                  const string& function_key)
409         : target_device_(target_device),
410           local_handle_(local_handle),
411           function_key_(function_key) {}
412 
target_device()413     const string& target_device() { return target_device_; }
function_key()414     const string& function_key() { return function_key_; }
415 
local_handle()416     FunctionLibraryRuntime::LocalHandle local_handle() {
417       mutex_lock l(mu_);
418       return local_handle_;
419     }
420 
421     // Initializes the FunctionData object by potentially making an Initialize
422     // call to the DistributedFunctionLibraryRuntime.
423     void DistributedInit(
424         DistributedFunctionLibraryRuntime* parent, const string& function_name,
425         const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
426         const FunctionLibraryRuntime::InstantiateOptions& options,
427         FunctionLibraryRuntime::DoneCallback done);
428 
is_cross_process()429     bool is_cross_process() {
430       mutex_lock l(mu_);
431       return is_cross_process_;
432     }
433 
434    private:
435     mutex mu_;
436 
437     const string target_device_;
438     FunctionLibraryRuntime::LocalHandle local_handle_ TF_GUARDED_BY(mu_);
439     const string function_key_;
440     bool is_cross_process_ TF_GUARDED_BY(mu_) = false;
441     bool init_started_ TF_GUARDED_BY(mu_) = false;
442     Status init_result_ TF_GUARDED_BY(mu_);
443     Notification init_done_;
444   };
445 
446   mutable mutex mu_;
447 
448   Env* const env_;
449   const absl::optional<const ConfigProto> config_;
450   const DeviceMgr* const device_mgr_;
451   const FunctionLibraryDefinition* lib_def_;
452   thread::ThreadPool* default_thread_pool_;
453 
454   // Cluster update can reinitialize the device_set_ due to remote device
455   // changes. At the same time, InstantiateMultiDevice can use the cached
456   // devices to instantiate multi-worker functions. Function instantiation would
457   // fail if it spans the changed remote devices.
458   std::shared_ptr<DeviceSet> device_set_ TF_GUARDED_BY(mu_);
459 
460   // Composite devices owned by a EagerContext.
461   std::vector<CompositeDevice*> composite_devices_ TF_GUARDED_BY(mu_);
462 
463   // Holds all the function instantiations. Maps function_keys to handles.
464   std::unordered_map<string, FunctionLibraryRuntime::Handle> table_
465       TF_GUARDED_BY(mu_);
466 
467   // Function data for instantiated remote functions.
468   std::unordered_map<FunctionLibraryRuntime::Handle,
469                      std::unique_ptr<FunctionData>>
470       function_data_ TF_GUARDED_BY(mu_);
471 
472   // Function data for instantiated multi-device functions.
473   std::unordered_map<FunctionLibraryRuntime::Handle,
474                      std::unique_ptr<MultiDeviceFunctionData>>
475       mdevice_data_ TF_GUARDED_BY(mu_);
476 
477   std::unique_ptr<
478       std::unordered_map<Device*, std::unique_ptr<FunctionLibraryRuntime>>>
479       flr_map_;
480   int next_handle_ TF_GUARDED_BY(mu_);
481   const SessionMetadata* const session_metadata_;
482   const Rendezvous::Factory rendezvous_factory_;
483 
484   const OptimizerOptions optimizer_options_;
485   const int graph_def_version_;
486 };
487 
488 }  // namespace tensorflow
489 
490 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
491