• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
16 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
17 
18 #include "absl/types/optional.h"
19 #include "tensorflow/core/common_runtime/device_mgr.h"
20 #include "tensorflow/core/common_runtime/eager/context.h"
21 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
22 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
23 #include "tensorflow/core/distributed_runtime/worker_session.h"
24 #include "tensorflow/core/framework/function.h"
25 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
26 
27 namespace tensorflow {
28 
29 class WorkerSession;
30 
31 namespace eager {
32 
33 // EagerClusterFunctionLibraryRuntime contains methods to Instantiate and Run
34 // functions across processes by making RPCs through eager service.
35 class EagerClusterFunctionLibraryRuntime
36     : public DistributedFunctionLibraryRuntime {
37  public:
EagerClusterFunctionLibraryRuntime(const uint64 context_id,EagerContext * ctx,DeviceMgr * remote_device_mgr)38   EagerClusterFunctionLibraryRuntime(const uint64 context_id, EagerContext* ctx,
39                                      DeviceMgr* remote_device_mgr)
40       : context_id_(context_id),
41         ctx_(ctx),
42         remote_device_mgr_(remote_device_mgr) {}
43 
~EagerClusterFunctionLibraryRuntime()44   ~EagerClusterFunctionLibraryRuntime() override{};
45 
46   // Register a partition (i.e., component function) of a multi-device function
47   // on the remote target specified in `options.target`. This should be
48   // triggered as part of instantiating a multi-device function in
49   // ProcessFunctionLibraryRuntime.
50   void Instantiate(const string& function_name,
51                    const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
52                    const FunctionLibraryRuntime::InstantiateOptions& options,
53                    FunctionLibraryRuntime::LocalHandle* handle,
54                    FunctionLibraryRuntime::DoneCallback done) override;
55 
56   // Execute the component function specified by `handle` on its instantiated
57   // remote target. This should be triggered as part of driving a multi-device
58   // function execution in ProcessFunctionLibraryRuntime. Running the component
59   // function remotely is purely asynchronous, and multiple component functions
60   // with the same remote target are not executed in any particular ordering.
61   // The main function side must wait for all component functions to finish
62   // (i.e., the done callbacks triggered) before finishing its execution.
63   void Run(const FunctionLibraryRuntime::Options& opts,
64            FunctionLibraryRuntime::LocalHandle handle,
65            gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
66            FunctionLibraryRuntime::DoneCallback done) override;
67 
68   // The component function inputs `args` and outputs `rets` may refer to remote
69   // tensors on a remote device, which will be lazily resolved remotely where
70   // the inputs/outputs are actually consumed.
71   void Run(const FunctionLibraryRuntime::Options& opts,
72            FunctionLibraryRuntime::LocalHandle handle,
73            gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets,
74            FunctionLibraryRuntime::DoneCallback done) override;
75 
76   void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
77                FunctionLibraryRuntime::DoneCallback done) override;
78 
remote_device_mgr()79   DeviceMgr* remote_device_mgr() const override { return remote_device_mgr_; }
80 
81  private:
82   const uint64 context_id_;
83   EagerContext* ctx_;
84   DeviceMgr* remote_device_mgr_;  // not owned.
85 
86   struct FunctionData {
87     const string target;
88     const absl::optional<std::vector<int>> ret_indices;
89     core::RefCountPtr<EagerClient> eager_client;
90     std::unique_ptr<EagerOperation> op;
91 
FunctionDataFunctionData92     FunctionData(const string& target,
93                  const absl::optional<std::vector<int>>& ret_indices,
94                  EagerClient* eager_client, std::unique_ptr<EagerOperation> op)
95         : target(target),
96           ret_indices(ret_indices),
97           eager_client(core::RefCountPtr<EagerClient>(eager_client)),
98           op(std::move(op)) {
99       eager_client->Ref();
100     }
101   };
102 
103   mutable mutex mu_;
104   std::vector<FunctionData> function_data_ TF_GUARDED_BY(mu_);
105 };
106 
107 DistributedFunctionLibraryRuntime* CreateClusterFLR(
108     const uint64 context_id, EagerContext* ctx, WorkerSession* worker_session);
109 
110 }  // namespace eager
111 }  // namespace tensorflow
112 
113 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
114