1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ 16 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ 17 18 #include "absl/types/optional.h" 19 #include "tensorflow/core/common_runtime/device_mgr.h" 20 #include "tensorflow/core/common_runtime/eager/context.h" 21 #include "tensorflow/core/common_runtime/eager/eager_operation.h" 22 #include "tensorflow/core/common_runtime/eager/tensor_handle.h" 23 #include "tensorflow/core/distributed_runtime/worker_session.h" 24 #include "tensorflow/core/framework/function.h" 25 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h" 26 27 namespace tensorflow { 28 29 class WorkerSession; 30 31 namespace eager { 32 33 // EagerClusterFunctionLibraryRuntime contains methods to Instantiate and Run 34 // functions across processes by making RPCs through eager service. 35 class EagerClusterFunctionLibraryRuntime 36 : public DistributedFunctionLibraryRuntime { 37 public: EagerClusterFunctionLibraryRuntime(const uint64 context_id,EagerContext * ctx,DeviceMgr * remote_device_mgr)38 EagerClusterFunctionLibraryRuntime(const uint64 context_id, EagerContext* ctx, 39 DeviceMgr* remote_device_mgr) 40 : context_id_(context_id), 41 ctx_(ctx), 42 remote_device_mgr_(remote_device_mgr) {} 43 ~EagerClusterFunctionLibraryRuntime()44 ~EagerClusterFunctionLibraryRuntime() override{}; 45 46 // Register a partition (i.e., component function) of a multi-device function 47 // on the remote target specified in `options.target`. This should be 48 // triggered as part of instantiating a multi-device function in 49 // ProcessFunctionLibraryRuntime. 50 void Instantiate(const string& function_name, 51 const FunctionLibraryDefinition& lib_def, AttrSlice attrs, 52 const FunctionLibraryRuntime::InstantiateOptions& options, 53 FunctionLibraryRuntime::LocalHandle* handle, 54 FunctionLibraryRuntime::DoneCallback done) override; 55 56 // Execute the component function specified by `handle` on its instantiated 57 // remote target. This should be triggered as part of driving a multi-device 58 // function execution in ProcessFunctionLibraryRuntime. Running the component 59 // function remotely is purely asynchronous, and multiple component functions 60 // with the same remote target are not executed in any particular ordering. 61 // The main function side must wait for all component functions to finish 62 // (i.e., the done callbacks triggered) before finishing its execution. 63 void Run(const FunctionLibraryRuntime::Options& opts, 64 FunctionLibraryRuntime::LocalHandle handle, 65 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets, 66 FunctionLibraryRuntime::DoneCallback done) override; 67 68 // The component function inputs `args` and outputs `rets` may refer to remote 69 // tensors on a remote device, which will be lazily resolved remotely where 70 // the inputs/outputs are actually consumed. 71 void Run(const FunctionLibraryRuntime::Options& opts, 72 FunctionLibraryRuntime::LocalHandle handle, 73 gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets, 74 FunctionLibraryRuntime::DoneCallback done) override; 75 76 void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle, 77 FunctionLibraryRuntime::DoneCallback done) override; 78 remote_device_mgr()79 DeviceMgr* remote_device_mgr() const override { return remote_device_mgr_; } 80 81 private: 82 const uint64 context_id_; 83 EagerContext* ctx_; 84 DeviceMgr* remote_device_mgr_; // not owned. 85 86 struct FunctionData { 87 const string target; 88 const absl::optional<std::vector<int>> ret_indices; 89 core::RefCountPtr<EagerClient> eager_client; 90 std::unique_ptr<EagerOperation> op; 91 FunctionDataFunctionData92 FunctionData(const string& target, 93 const absl::optional<std::vector<int>>& ret_indices, 94 EagerClient* eager_client, std::unique_ptr<EagerOperation> op) 95 : target(target), 96 ret_indices(ret_indices), 97 eager_client(core::RefCountPtr<EagerClient>(eager_client)), 98 op(std::move(op)) { 99 eager_client->Ref(); 100 } 101 }; 102 103 mutable mutex mu_; 104 std::vector<FunctionData> function_data_ TF_GUARDED_BY(mu_); 105 }; 106 107 DistributedFunctionLibraryRuntime* CreateClusterFLR( 108 const uint64 context_id, EagerContext* ctx, WorkerSession* worker_session); 109 110 } // namespace eager 111 } // namespace tensorflow 112 113 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ 114