/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"

#include <map>
#include <memory>

#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"

namespace tensorflow {
namespace eager {
namespace {
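// Strips attributes that are still set to their registry default values from
// the FunctionDef being registered and from every function in its attached
// library, which keeps the RegisterFunctionOp payload small.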
void StripDefaultAttributesInRegisterFunctionOp(
    RegisterFunctionOp* register_function) {
  StripDefaultAttributes(
      *OpRegistry::Global(),
      register_function->mutable_function_def()->mutable_node_def());
  for (auto& function :
       *register_function->mutable_library()->mutable_function()) {
    StripDefaultAttributes(*OpRegistry::Global(), function.mutable_node_def());
  }
}
}  // namespace

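// Registers `function_name` (looked up in `options.lib_def` if provided,
// otherwise in `lib_def`) on the remote worker named by `options.target` by
// enqueueing a RegisterFunctionOp, together with the function's reachable
// library, on that worker's eager service. On completion, `*handle` is the
// index of the new entry in `function_data_`.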
void EagerClusterFunctionLibraryRuntime::Instantiate(
    const string& function_name, const FunctionLibraryDefinition& lib_def,
    AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options,
    FunctionLibraryRuntime::LocalHandle* handle,
    FunctionLibraryRuntime::DoneCallback done) {
  auto target = options.target;
  auto released_op = std::make_unique<EagerOperation>(ctx_);
  Status s =
      released_op->Reset(function_name.c_str(), target.c_str(), true, nullptr);
  if (!s.ok()) {
    done(s);
    return;
  }
  if (!released_op->is_function()) {
    done(errors::Internal(function_name, " is not a function."));
    return;
  }

  VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << target
          << " (this: " << this << ")";
  core::RefCountPtr<eager::EagerClient> eager_client;
  s = ctx_->GetClient(target, &eager_client);
  if (!s.ok()) {
    done(s);
    return;
  }

  if (eager_client == nullptr) {
    done(errors::InvalidArgument("Could not find eager client for target: ",
                                 target));
    return;
  }

  const FunctionLibraryDefinition& func_lib_def =
      options.lib_def ? *options.lib_def : lib_def;

  auto request = std::make_shared<EnqueueRequest>();
  auto response = std::make_shared<EnqueueResponse>();

  request->set_context_id(context_id_);

  RegisterFunctionOp* register_function =
      request->add_queue()->mutable_register_function();
  *register_function->mutable_function_def() =
      *func_lib_def.Find(function_name);
  register_function->set_is_component_function(true);
  *register_function->mutable_library() =
      func_lib_def.ReachableDefinitions(register_function->function_def())
          .ToProto();
  StripDefaultAttributesInRegisterFunctionOp(register_function);

  const absl::optional<std::vector<int>>& ret_indices = options.ret_indices;
  eager_client->EnqueueAsync(
      /*call_opts=*/nullptr, request.get(), response.get(),
      [this, request, response, handle, released_op = released_op.release(),
       target, ret_indices, eager_client = eager_client.get(),
       done](const Status& s) {
        {
          mutex_lock l(mu_);
          *handle = function_data_.size();
          function_data_.emplace_back(target, ret_indices, eager_client,
                                      absl::WrapUnique(released_op));
        }
        done(s);
      });
}

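// Runs an instantiated component function with plain Tensor arguments by
// forwarding to the FunctionArg/FunctionRet overload below and converting the
// results back; a TensorShape return is reported as an internal error.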
void EagerClusterFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice<Tensor> args,
    std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
  std::vector<FunctionArg> function_args;
  for (const auto& tensor : args) {
    function_args.push_back(tensor);
  }
  std::vector<FunctionRet>* function_rets = new std::vector<FunctionRet>;
  Run(opts, handle, function_args, function_rets,
      [rets, function_rets, done = std::move(done)](const Status& s) {
        Status status = s;
        if (status.ok()) {
          for (const auto& t : *function_rets) {
            if (t.index() == 0) {
              rets->push_back(absl::get<Tensor>(t));
            } else {
              status.Update(
                  errors::Internal("Expected a Tensor as a remote function "
                                   "output but got a TensorShape."));
              break;
            }
          }
        }
        delete function_rets;
        done(status);
      });
}

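// Runs the component function identified by `handle` on its remote target via
// the RunComponentFunction RPC. Tensor arguments are serialized inline and
// remote tensor handles are swapped into the request; if a cancellation
// manager is supplied in `opts`, cancelling it aborts the in-flight call.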
void EagerClusterFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::LocalHandle handle,
    gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets,
    FunctionLibraryRuntime::DoneCallback done) {
  FunctionData* function_data = nullptr;
  {
    mutex_lock l(mu_);
    DCHECK_LE(handle, function_data_.size());
    function_data = &function_data_[handle];
  }

  EagerClient* eager_client = function_data->eager_client.get();
  if (eager_client == nullptr) {
    done(errors::Internal("Could not find eager client"));
    return;
  }

  EagerOperation* op = function_data->op.get();
  if (!op->Inputs().empty()) {
    done(errors::Internal("Inputs should not be set during instantiation."));
    return;
  }

  auto request = std::make_shared<RunComponentFunctionRequest>();
  auto response = std::make_shared<RunComponentFunctionResponse>();
  request->set_context_id(context_id_);
  eager::Operation* remote_op = request->mutable_operation();

  if (function_data->ret_indices.has_value()) {
    for (const int ret_index : function_data->ret_indices.value()) {
      request->add_output_num(ret_index);
    }
  }

  for (const auto& arg : args) {
    if (arg.index() == 0) {
      absl::get<Tensor>(arg).AsProtoTensorContent(
          remote_op->add_op_inputs()->mutable_tensor());
    } else {
      remote_op->add_op_inputs()->mutable_remote_handle()->Swap(
          absl::get<RemoteTensorHandle*>(arg));
    }
  }

  // The remote component function should use the same op_id as its parent
  // multi-device function so that it picks up the globally unique op_id
  // generated by the master context.
  if (opts.op_id.has_value()) {
    remote_op->set_id(opts.op_id.value());
  } else {
    remote_op->set_id(kInvalidRemoteOpId);
  }
  remote_op->set_is_function(true);
  remote_op->set_is_component_function(true);
  remote_op->set_func_step_id(opts.step_id);
  remote_op->set_name(op->Name());
  op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
  remote_op->set_device(function_data->target);

  CancellationManager* cm = opts.cancellation_manager;
  CancellationToken token = 0;
  auto call_opts = std::make_shared<CallOptions>();
  if (cm != nullptr) {
    token = cm->get_cancellation_token();
    const bool already_cancelled = !cm->RegisterCallback(
        token,
        [call_opts, request, response, done]() { call_opts->StartCancel(); });
    if (already_cancelled) {
      done(errors::Cancelled("EagerClusterFunctionLibraryRuntime::Run"));
      return;
    }
  }

  // Execute the component function on the remote worker using the
  // RunComponentFunction RPC. Unlike executing remote functions with Enqueue,
  // this method runs the function on the remote worker without tying up a
  // thread (i.e., purely asynchronously).
  eager_client->RunComponentFunctionAsync(
      call_opts.get(), request.get(), response.get(),
      [request, response, rets, call_opts, cm, token,
       done = std::move(done)](const Status& s) {
        if (cm != nullptr) {
          cm->TryDeregisterCallback(token);
        }
        if (!s.ok()) {
          done(s);
          return;
        }
        if (!response->shape().empty() && !response->tensor().empty()) {
          done(errors::Internal(
              "Both shape and tensor are specified in the same response"));
          return;
        }
        for (const auto& shape : response->shape()) {
          rets->push_back(shape);
        }
        for (const auto& tensor_proto : response->tensor()) {
          Tensor t;
          if (t.FromProto(tensor_proto)) {
            rets->push_back(std::move(t));
          } else {
            done(errors::Internal("Could not convert tensor proto: ",
                                  tensor_proto.DebugString()));
            return;
          }
        }
        done(Status::OK());
      });
}

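// Asks the remote worker associated with `handle` to release the per-step
// state for `step_id` by enqueueing a CleanupFunctionOp on its eager service.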
void EagerClusterFunctionLibraryRuntime::CleanUp(
    uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
    FunctionLibraryRuntime::DoneCallback done) {
  FunctionData* function_data = nullptr;
  {
    mutex_lock l(mu_);
    DCHECK_LE(handle, function_data_.size());
    function_data = &function_data_[handle];
  }

  EagerClient* eager_client = function_data->eager_client.get();
  if (eager_client == nullptr) {
    done(errors::Internal("Could not find eager client"));
    return;
  }

  auto request = std::make_shared<EnqueueRequest>();
  auto response = std::make_shared<EnqueueResponse>();
  request->set_context_id(context_id_);
  CleanupFunctionOp* cleanup_function =
      request->add_queue()->mutable_cleanup_function();
  cleanup_function->set_step_id(step_id);
  // StreamingEnqueueAsync can block when streaming RPC is disabled, and
  // CleanUp() must be non-blocking because it is invoked inside the enqueue
  // done callback of Run(), so we use EnqueueAsync here instead.
  eager_client->EnqueueAsync(
      /*call_opts=*/nullptr, request.get(), response.get(),
      [request, response, done](const Status& status) { done(status); });
}

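// Creates a DistributedFunctionLibraryRuntime backed by
// EagerClusterFunctionLibraryRuntime for the given context id, eager context,
// and the worker session's remote device manager.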
DistributedFunctionLibraryRuntime* CreateClusterFLR(
    const uint64 context_id, EagerContext* ctx, WorkerSession* worker_session) {
  return new EagerClusterFunctionLibraryRuntime(
      context_id, ctx, worker_session->remote_device_mgr());
}

}  // namespace eager
}  // namespace tensorflow