• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
17 
18 #include <unordered_map>
19 
20 #include "tensorflow/core/common_runtime/device_mgr.h"
21 #include "tensorflow/core/framework/function.h"
22 #include "tensorflow/core/protobuf/config.pb.h"
23 
24 namespace tensorflow {
25 
26 // A class that stores all the FunctionLibraryRuntime objects, one per device.
27 class ProcessFunctionLibraryRuntime {
28  public:
29   // Creates FunctionLibraryRuntime objects for each device in the provided
30   // DeviceMgr. Caller needs to make sure that device_mgr, lib_def and parent
31   // (if provided) outlive this object.
32   ProcessFunctionLibraryRuntime(
33       const DeviceMgr* device_mgr, Env* env, int graph_def_version,
34       const FunctionLibraryDefinition* lib_def,
35       const OptimizerOptions& optimizer_options,
36       DistributedFunctionLibraryRuntime* parent = nullptr);
37 
38   // With `custom_kernel_creator`.
39   ProcessFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env,
40                                 int graph_def_version,
41                                 const FunctionLibraryDefinition* lib_def,
42                                 const OptimizerOptions& optimizer_options,
43                                 CustomKernelCreator custom_kernel_creator,
44                                 DistributedFunctionLibraryRuntime* parent);
45 
46   // Sends `tensors_to_send` from `source_device` to `target_device` using
47   // `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the
48   // Rendezvous. `device_context` should be the DeviceContext of the device
49   // doing the sending. `alloc_attrs` should either be empty or be the size of
50   // `tensors_to_send` and indicates how the input tensors are allocated. Method
51   // takes references on each of the `tensors_to_send`. Method doesn't block.
52   static Status SendTensors(const string& source_device,
53                             const string& target_device,
54                             const string& key_prefix, int64 src_incarnation,
55                             gtl::ArraySlice<Tensor> tensors_to_send,
56                             DeviceContext* device_context,
57                             const std::vector<AllocatorAttributes>& alloc_attrs,
58                             Rendezvous* rendezvous);
59 
60   typedef std::function<void(const Status&)> StatusCallback;
61 
62   // Receives `received_tensors` from `target_device` (originally sent from
63   // `source_device`) using `rendezvous`. Uses `key_prefix` to construct the
64   // keys to be retrieved. `device_context` should be for the device receiving
65   // the tensors. `alloc_attrs` indicates how to allocate the received
66   // tensors and should either be empty or `num_tensors` in size. Method doesn't
67   // block and calls `done` when `num_tensors` are fetched.
68   static void ReceiveTensorsAsync(
69       const string& source_device, const string& target_device,
70       const string& key_prefix, int64 src_incarnation, int64 num_tensors,
71       DeviceContext* device_context,
72       const std::vector<AllocatorAttributes>& alloc_attrs,
73       Rendezvous* rendezvous, std::vector<Tensor>* received_tensors,
74       const StatusCallback& done);
75 
76   static const char kDefaultFLRDevice[];
77   // Returns the FunctionLibraryRuntime for the corresponding device_name.
78   FunctionLibraryRuntime* GetFLR(const string& device_name) const;
79 
80   // Returns the device incarnation for the given device_name.
81   Status GetDeviceIncarnation(const string& device_name, int64* incarnation);
82 
83   // For a given canonicalized key signature of the function instantiated
84   // on device `device_name` and a `local_handle`, creates a handle and returns
85   // that value. Uses core/common_runtime/framework/function.h::Canonicalize
86   // to canonicalize the function signature.
87   FunctionLibraryRuntime::Handle AddHandle(
88       const string& function_key, const string& device_name,
89       FunctionLibraryRuntime::LocalHandle local_handle);
90 
91   // Returns a handle if found for the given key, else returns kInvalidHandle.
92   FunctionLibraryRuntime::Handle GetHandle(const string& function_key) const;
93 
94   // For the given handle instantiated on device `device_name` returns the local
95   // index of instantiation of that function. If the function was not
96   // instantiated on `device_name` returns kInvalidLocalHandle.
97   FunctionLibraryRuntime::LocalHandle GetHandleOnDevice(
98       const string& device_name, FunctionLibraryRuntime::Handle handle);
99 
100   // Returns true if function with handle `handle` was instantiated on device
101   // `device_name`.
102   bool IsInstantiatedOnDevice(const string& device_name,
103                               FunctionLibraryRuntime::Handle handle);
104 
105   // Instantiates the function. See framework/function.h for more details.
106   // Allows for function_name to be instantiated on different devices
107   // as specified in attrs.
108   Status Instantiate(const string& function_name, AttrSlice attrs,
109                      const FunctionLibraryRuntime::InstantiateOptions& options,
110                      FunctionLibraryRuntime::Handle* handle);
111 
112   // Delegates to the local FLR that owns state corresponding to `handle` and
113   // tells it to release it. If the `handle` isnt' needed at all, the local FLR
114   // might call RemoveHandle on this to get rid of the state owned by the Proc
115   // FLR.
116   Status ReleaseHandle(FunctionLibraryRuntime::Handle handle);
117 
118   // Runs the function with given `handle`. Function could have been
119   // instantiated on any device. More details in framework/function.h
120   void Run(const FunctionLibraryRuntime::Options& opts,
121            FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
122            std::vector<Tensor>* rets,
123            FunctionLibraryRuntime::DoneCallback done);
124 
125  private:
126   // For a given device_name, returns a DeviceContext for copying
127   // tensors to/from the device.
128   Status GetDeviceContext(const string& device_name,
129                           DeviceContext** device_context);
130 
131   // Looks up the information for the given `handle` and returns the name
132   // of the device where the function is registered.
133   string GetDeviceName(FunctionLibraryRuntime::Handle handle);
134 
135   // Removes handle from the state owned by this object.
136   Status RemoveHandle(FunctionLibraryRuntime::Handle handle);
137 
138   Status Clone(Env* env, int graph_def_version,
139                const OptimizerOptions& optimizer_options,
140                CustomKernelCreator custom_kernel_creator,
141                std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
142                std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr);
143 
144   friend class FunctionLibraryRuntimeImpl;
145 
146   mutable mutex mu_;
147 
148   struct FunctionData {
149     const string target_device;
150     const FunctionLibraryRuntime::LocalHandle local_handle;
151 
FunctionDataFunctionData152     FunctionData(const string& target_device,
153                  FunctionLibraryRuntime::LocalHandle local_handle)
154         : target_device(target_device), local_handle(local_handle) {}
FunctionDataFunctionData155     FunctionData() : FunctionData("", -1) {}
156   };
157 
158   const DeviceMgr* const device_mgr_;
159   const FunctionLibraryDefinition* lib_def_;
160   // Holds all the function invocations here.
161   std::unordered_map<string, FunctionLibraryRuntime::Handle> table_
162       GUARDED_BY(mu_);
163   std::unordered_map<FunctionLibraryRuntime::Handle, FunctionData>
164       function_data_ GUARDED_BY(mu_);
165   std::unordered_map<Device*, std::unique_ptr<FunctionLibraryRuntime>> flr_map_;
166   int next_handle_ GUARDED_BY(mu_);
167   DistributedFunctionLibraryRuntime* const parent_;
168 };
169 
170 }  // namespace tensorflow
171 
172 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
173