1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/function.h"
17
18 #include <deque>
19 #include <vector>
20
21 #include "absl/algorithm/container.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "tensorflow/core/common_runtime/device.h"
25 #include "tensorflow/core/common_runtime/executor.h"
26 #include "tensorflow/core/common_runtime/executor_factory.h"
27 #include "tensorflow/core/common_runtime/gradients.h"
28 #include "tensorflow/core/common_runtime/graph_constructor.h"
29 #include "tensorflow/core/common_runtime/graph_optimizer.h"
30 #include "tensorflow/core/common_runtime/inline_function_utils.h"
31 #include "tensorflow/core/common_runtime/memory_types.h"
32 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
33 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
34 #include "tensorflow/core/common_runtime/single_threaded_executor.h"
35 #include "tensorflow/core/framework/collective.h"
36 #include "tensorflow/core/framework/function.h"
37 #include "tensorflow/core/framework/function_handle_cache.h"
38 #include "tensorflow/core/framework/node_def.pb.h"
39 #include "tensorflow/core/framework/node_def_util.h"
40 #include "tensorflow/core/framework/op.h"
41 #include "tensorflow/core/framework/op_kernel.h"
42 #include "tensorflow/core/framework/versions.pb.h"
43 #include "tensorflow/core/graph/algorithm.h"
44 #include "tensorflow/core/graph/control_flow.h"
45 #include "tensorflow/core/graph/node_builder.h"
46 #include "tensorflow/core/graph/optimizer_cse.h"
47 #include "tensorflow/core/lib/core/threadpool.h"
48 #include "tensorflow/core/lib/gtl/map_util.h"
49 #include "tensorflow/core/platform/macros.h"
50 #include "tensorflow/core/platform/str_util.h"
51 #include "tensorflow/core/profiler/lib/connected_traceme.h"
52 #include "tensorflow/core/profiler/lib/traceme.h"
53 #include "tensorflow/core/protobuf/config.pb.h"
54
55 // See core/kernels/function_ops.cc for related kernels.
56
57 namespace tensorflow {
58
// A few string constants used throughout this module.
static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
static constexpr const char* const kDeviceArgOp =
    FunctionLibraryDefinition::kDeviceArgOp;
static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
static constexpr const char* const kDeviceRetOp =
    FunctionLibraryDefinition::kDeviceRetOp;
static constexpr const char* const kGradientOp =
    FunctionLibraryDefinition::kGradientOp;
// Label passed to Graph::NewName when synthesizing _Arg/_Retval nodes below.
static constexpr const char* const kNodeLabel = "Func";
static constexpr const char* const kFuncAttr =
    FunctionLibraryDefinition::kFuncAttr;
71
72 // Represents the index-th output of a node.
73 struct Endpoint {
74 Node* node;
75 int index;
76
77 // Returns the string name represents this endpoint.
nametensorflow::Endpoint78 string name() const {
79 if (index == 0) {
80 return node->name();
81 } else {
82 return strings::StrCat(node->name(), ":", index);
83 }
84 }
85
dtypetensorflow::Endpoint86 DataType dtype() const { return node->output_type(index); }
87 };
88
89 struct EndpointHash {
operator ()tensorflow::EndpointHash90 uint64 operator()(const Endpoint& x) const {
91 return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
92 x.index);
93 }
94 };
95
96 struct EndpointEq {
operator ()tensorflow::EndpointEq97 bool operator()(const Endpoint& x, const Endpoint& y) const {
98 return (x.node == y.node) && (x.index == y.index);
99 }
100 };
101
102 // The following Add* routines are used to add a few graph nodes while
103 // functions are transformed.
AddArg(Graph * g,DataType dtype,int index)104 static Node* AddArg(Graph* g, DataType dtype, int index) {
105 DCHECK_LT(0, dtype);
106 DCHECK_LT(dtype, DT_FLOAT_REF);
107 NodeDef ndef;
108 ndef.set_name(g->NewName(kNodeLabel));
109 ndef.set_op(kArgOp);
110 AddNodeAttr("T", dtype, &ndef);
111 AddNodeAttr("index", index, &ndef);
112 Status s;
113 Node* ret = g->AddNode(ndef, &s);
114 TF_CHECK_OK(s);
115 return ret;
116 }
117
AddRet(Graph * g,Endpoint input,int index)118 static Node* AddRet(Graph* g, Endpoint input, int index) {
119 DCHECK_LT(0, input.dtype());
120 DCHECK_LT(input.dtype(), DT_FLOAT_REF);
121 NodeDef ndef;
122 ndef.set_name(g->NewName(kNodeLabel));
123 ndef.set_op(kRetOp);
124 ndef.add_input(input.name());
125 AddNodeAttr("T", input.dtype(), &ndef);
126 AddNodeAttr("index", index, &ndef);
127 Status s;
128 Node* ret = g->AddNode(ndef, &s);
129 TF_CHECK_OK(s);
130 g->AddEdge(input.node, input.index, ret, 0);
131 return ret;
132 }
133
134 // FunctionLibraryRuntime implementation that forwards all the function calls to
135 // the base runtime implementation, and only overrides FunctionLibraryDefinition
136 // in calls to Instantiate (if caller doesn't provide the
137 // InstantiateOptions::lib_def option).
138 //
139 // When the function library runtime (FunctionLibraryRuntimeImpl specifically)
140 // instantiates a function into a Graph object, it also creates an Executor for
141 // it. That executor has a pointer to the function library runtime instance,
142 // that is used to instantiate all nested function calls.
143 //
144 // The function library definition used to instantiate the function must be
145 // preserved in the Executor's function library runtime.
146 //
147 // IMPORTANT: This runtime is intended for use only in executors created for
148 // functions instantiated into a graph in FunctionLibraryRuntimeImpl.
class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
 public:
  // Neither `base_flr` nor `lib_def` is owned; both must outlive this object.
  FunctionLibraryRuntimeOverlay(FunctionLibraryRuntime* base_flr,
                                const FunctionLibraryDefinition* lib_def)
      : base_flr_(base_flr), lib_def_(lib_def) {}
  ~FunctionLibraryRuntimeOverlay() override;

  // Forwards to the base runtime, injecting `lib_def_` into the
  // InstantiateOptions when the caller did not supply one.
  Status Instantiate(const string& function_name, AttrSlice attrs,
                     const InstantiateOptions& options,
                     Handle* handle) override;

  Status ReleaseHandle(Handle handle) override;

  const FunctionBody* GetFunctionBody(Handle h) override;

  Status GetRetTypes(Handle h, DataTypeVector* ret_types) override;

  // The Run/RunSync overloads forward unchanged to the base runtime.
  void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
           std::vector<Tensor>* rets, DoneCallback done) override;

  void Run(const Options& opts, Handle handle, CallFrameInterface* call_frame,
           DoneCallback done) override;

  Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
                 std::vector<Tensor>* rets) override;

  Status RunSync(Options opts, Handle handle,
                 CallFrameInterface* frame) override;

  // Always returns an error: kernel creation through an overlay is
  // intentionally disallowed (see the definition for rationale).
  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                      OpKernel** kernel) override;

  // Looks up statefulness in `lib_def_` only; does NOT consult the base FLR.
  bool IsStateful(const string& function_name) const override;

  const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const override;

  Env* env() override;
  const ConfigProto* const config_proto() override;
  Device* device() override;
  const Device* device() const override;
  std::function<void(std::function<void()>)>* runner() override;
  const DeviceMgr* device_mgr() const override;

  string DebugString(Handle handle) override;
  int graph_def_version() const override;

  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
               FunctionLibraryRuntime** out_flr,
               bool skip_flib_def = false) override;

 private:
  FunctionLibraryRuntime* base_flr_;          // not owned
  const FunctionLibraryDefinition* lib_def_;  // not owned
};
205
// The overlay owns neither base_flr_ nor lib_def_, so there is nothing to
// release here.
FunctionLibraryRuntimeOverlay::~FunctionLibraryRuntimeOverlay() = default;
207
Instantiate(const string & function_name,AttrSlice attrs,const InstantiateOptions & options,Handle * handle)208 Status FunctionLibraryRuntimeOverlay::Instantiate(
209 const string& function_name, AttrSlice attrs,
210 const InstantiateOptions& options, Handle* handle) {
211 // We automatically set the `lib_def` option for all instantiations, if the
212 // caller doesn't set this option explicitly.
213 if (!options.lib_def && lib_def_) {
214 InstantiateOptions options_copy = options;
215 options_copy.lib_def = lib_def_;
216 return base_flr_->Instantiate(function_name, attrs, options_copy, handle);
217 } else {
218 return base_flr_->Instantiate(function_name, attrs, options, handle);
219 }
220 }
221
// Handle bookkeeping lives entirely in the base runtime; these are pure
// forwarders.
Status FunctionLibraryRuntimeOverlay::ReleaseHandle(Handle handle) {
  return base_flr_->ReleaseHandle(handle);
}

const FunctionBody* FunctionLibraryRuntimeOverlay::GetFunctionBody(Handle h) {
  return base_flr_->GetFunctionBody(h);
}

Status FunctionLibraryRuntimeOverlay::GetRetTypes(Handle h,
                                                  DataTypeVector* ret_types) {
  return base_flr_->GetRetTypes(h, ret_types);
}
234
// Execution forwards directly to the base runtime: the lib_def override only
// affects instantiation, not running an already-instantiated handle.
void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
                                        gtl::ArraySlice<Tensor> args,
                                        std::vector<Tensor>* rets,
                                        DoneCallback done) {
  base_flr_->Run(opts, handle, args, rets, std::move(done));
}

void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
                                        CallFrameInterface* call_frame,
                                        DoneCallback done) {
  base_flr_->Run(opts, handle, call_frame, std::move(done));
}

Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
                                              gtl::ArraySlice<Tensor> args,
                                              std::vector<Tensor>* rets) {
  return base_flr_->RunSync(std::move(opts), handle, args, rets);
}

Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
                                              CallFrameInterface* call_frame) {
  return base_flr_->RunSync(std::move(opts), handle, call_frame);
}
258
// Always fails: overlays cannot create kernels.
Status FunctionLibraryRuntimeOverlay::CreateKernel(
    const std::shared_ptr<const NodeProperties>&, OpKernel**) {
  // We don't have access to base_lib_def_ in base function library runtime (aka
  // FunctionLibraryRuntimeImpl), so to make sure we do not create a kernel with
  // the wrong lib_def we just disable creation of new kernels through overlays.
  //
  // When we call Instantiate from the base runtime with the lib_def option,
  // the base runtime implementation is responsible for correctly passing it
  // through to all kernel constructions.
  return errors::Internal(
      "Overlay function library runtime doesn't support kernel creation.");
}
271
IsStateful(const string & function_name) const272 bool FunctionLibraryRuntimeOverlay::IsStateful(
273 const string& function_name) const {
274 // Important: we do not forward lookup to the base FLR.
275 const OpDef* op_def;
276 const Status s = lib_def_->LookUpOpDef(function_name, &op_def);
277 return s.ok() && op_def->is_stateful();
278 }
279
// Environment/device accessors simply mirror the base runtime.
Env* FunctionLibraryRuntimeOverlay::env() { return base_flr_->env(); }

const ConfigProto* const FunctionLibraryRuntimeOverlay::config_proto() {
  return base_flr_->config_proto();
}

Device* FunctionLibraryRuntimeOverlay::device() { return base_flr_->device(); }

const Device* FunctionLibraryRuntimeOverlay::device() const {
  return base_flr_->device();
}

std::function<void(std::function<void()>)>*
FunctionLibraryRuntimeOverlay::runner() {
  return base_flr_->runner();
}

const DeviceMgr* FunctionLibraryRuntimeOverlay::device_mgr() const {
  return base_flr_->device_mgr();
}
300
301 const FunctionLibraryDefinition*
GetFunctionLibraryDefinition() const302 FunctionLibraryRuntimeOverlay::GetFunctionLibraryDefinition() const {
303 return lib_def_ ? lib_def_ : base_flr_->GetFunctionLibraryDefinition();
304 }
305
// Debug/version accessors forward to the base runtime.
string FunctionLibraryRuntimeOverlay::DebugString(Handle handle) {
  return base_flr_->DebugString(handle);
}

int FunctionLibraryRuntimeOverlay::graph_def_version() const {
  return base_flr_->graph_def_version();
}
313
// Clones via the base runtime; the overlay's lib_def_ override is dropped.
Status FunctionLibraryRuntimeOverlay::Clone(
    std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
    FunctionLibraryRuntime** out_flr, bool skip_flib_def) {
  // NOTE(ezhulenev): The cloned FunctionLibraryRuntime will be missing the
  // FunctionLibraryDefinition override, but that's ok because we anyway do not
  // copy / clone instantiated items from the base FLR.
  return base_flr_->Clone(out_lib_def, out_pflr, out_flr, skip_flib_def);
}
323
// Per-device function library runtime. Instantiates functions into Graph
// objects, creates executors for them, and runs them. Handle bookkeeping is
// shared with `parent_` (the ProcessFunctionLibraryRuntime).
class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
 public:
  FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env,
                             const ConfigProto* config, Device* device,
                             int graph_def_version,
                             const FunctionLibraryDefinition* lib_def,
                             thread::ThreadPool* default_thread_pool,
                             const OptimizerOptions& optimizer_options,
                             const SessionMetadata* session_metadata,
                             ProcessFunctionLibraryRuntime* parent);

  ~FunctionLibraryRuntimeImpl() override;

  Status Instantiate(const string& function_name, AttrSlice attrs,
                     const InstantiateOptions& options,
                     Handle* handle) override;

  Status ReleaseHandle(Handle handle) override;

  const FunctionBody* GetFunctionBody(Handle handle) override;

  Status GetRetTypes(Handle handle, DataTypeVector* ret_types) override;

  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                      OpKernel** kernel) override;

  void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
           std::vector<Tensor>* rets, DoneCallback done) override;
  void Run(const Options& opts, Handle handle, CallFrameInterface* frame,
           DoneCallback done) override;
  Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
                 std::vector<Tensor>* rets) override;
  Status RunSync(Options opts, Handle handle,
                 CallFrameInterface* call_frame) override;

  bool IsStateful(const string& function) const override;

  const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const override {
    return base_lib_def_;
  }

  Device* device() override { return device_; }
  const Device* device() const override { return device_; }

  std::function<void(std::function<void()>)>* runner() override {
    return &default_runner_;
  }

  const DeviceMgr* device_mgr() const override { return device_mgr_; }
  Env* env() override { return env_; }
  const ConfigProto* const config_proto() override { return config_; }
  int graph_def_version() const override { return graph_def_version_; }

  string DebugString(Handle h) override;

  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
               FunctionLibraryRuntime** out_flr,
               bool skip_flib_def = false) override;

 private:
  typedef FunctionLibraryRuntimeImpl ME;

  // All raw pointers below are not owned by this class.
  const DeviceMgr* const device_mgr_;
  Device* const device_;
  Env* const env_;
  const ConfigProto* const config_;
  const int graph_def_version_;
  const FunctionLibraryDefinition* const base_lib_def_;
  GraphOptimizer optimizer_;
  const SessionMetadata* const session_metadata_;
  Executor::Args::Runner default_runner_;
  const string device_name_;

  // Cached callbacks bound to this instance (set up in the constructor).
  std::function<Status(const string&, const OpDef**)> get_func_sig_;
  std::function<Status(const std::shared_ptr<const NodeProperties>&,
                       OpKernel**)>
      create_kernel_;

  mutable mutex mu_;

  // Next local handle to assign; monotonically increasing.
  int next_handle_ TF_GUARDED_BY(mu_);

  // The instantiated and transformed function is encoded as a Graph
  // object, and an executor is created for the graph.
  struct Item {
    uint64 instantiation_counter = 0;
    std::unique_ptr<const Graph> graph = nullptr;
    const FunctionLibraryDefinition* lib_def = nullptr;  // Not owned.
    FunctionBody* func_graph = nullptr;
    Executor* exec = nullptr;
    FunctionLibraryRuntimeOverlay* overlay_flr = nullptr;
    string executor_type;

    // Item owns func_graph, exec and overlay_flr (raw pointers by design).
    ~Item() {
      delete this->func_graph;
      delete this->exec;
      delete this->overlay_flr;
    }
  };
  // Held via unique_ptr so the destructor can release it explicitly before
  // member teardown; ReleaseHandle checks for null (see ~FunctionLibraryRuntimeImpl).
  std::unique_ptr<absl::flat_hash_map<Handle, std::unique_ptr<Item>>> items_
      TF_GUARDED_BY(mu_);
  std::unique_ptr<FunctionHandleCache> function_handle_cache_;
  ProcessFunctionLibraryRuntime* parent_ = nullptr;  // not owned.

  // Overloads the CreateKernel method, providing a FunctionLibraryRuntime
  // to use for kernel creation and execution. In particular, this method can
  // accept a FunctionLibraryRuntimeOverlay that overlays a different
  // FunctionLibraryDefinition.
  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                      FunctionLibraryRuntime* flr, OpKernel** kernel);
  Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
                           const FunctionLibraryDefinition* lib_def,
                           std::unique_ptr<FunctionBody>* fbody);
  Status CreateItem(Item** item);
  Status GetOrCreateItem(LocalHandle local_handle, Item** item);
  Status InstantiateSymbolicGradient(const NameAttrList& func,
                                     const FunctionLibraryDefinition* lib_def,
                                     std::unique_ptr<FunctionBody>* g_body);
  bool IsLocalTarget(const InstantiateOptions& options) const;
  AttrValueMap FixAttrs(const AttrSlice& attrs);
  void RunRemote(const Options& opts, Handle handle,
                 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                 Item* item, DoneCallback done);

  Status PrepareRunSync(
      Handle handle, Options* run_opts, Item** out_item,
      std::unique_ptr<PrivateIntraProcessRendezvous>* out_rendezvous);

  void ExecutorArgsFromOptions(const FunctionLibraryRuntime::Options& run_opts,
                               CallFrameInterface* frame,
                               Executor::Args* exec_args);

  TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
};
460
FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
    const DeviceMgr* dmgr, Env* env, const ConfigProto* config, Device* device,
    int graph_def_version, const FunctionLibraryDefinition* lib_def,
    thread::ThreadPool* default_thread_pool,
    const OptimizerOptions& optimizer_options,
    const SessionMetadata* session_metadata,
    ProcessFunctionLibraryRuntime* parent)
    : device_mgr_(dmgr),
      device_(device),
      env_(env),
      config_(config),
      graph_def_version_(graph_def_version),
      base_lib_def_(lib_def),
      optimizer_(optimizer_options),
      session_metadata_(session_metadata),
      default_runner_(nullptr),
      // A null device is registered under the process-level default name.
      device_name_(device_ == nullptr
                       ? ProcessFunctionLibraryRuntime::kDefaultFLRDevice
                       : device_->name()),
      next_handle_(0),
      items_(absl::make_unique<
             absl::flat_hash_map<Handle, std::unique_ptr<Item>>>()),
      function_handle_cache_(absl::make_unique<FunctionHandleCache>(this)),
      parent_(parent) {
  // Signature lookup resolves against the base function library definition.
  get_func_sig_ = [this](const string& op, const OpDef** sig) {
    return base_lib_def_->LookUpOpDef(op, sig);
  };
  create_kernel_ = [this](const std::shared_ptr<const NodeProperties>& props,
                          OpKernel** kernel) {
    return CreateKernel(props, kernel);
  };
  // Choose the runner's thread pool: prefer the device's own pool, then the
  // provided default pool. If neither exists, default_runner_ stays null.
  thread::ThreadPool* pool = nullptr;
  if (device_ != nullptr) {
    pool = device_->tensorflow_device_thread_pool();
  }
  if (pool == nullptr) {
    pool = default_thread_pool;
  }
  if (pool != nullptr) {
    default_runner_ = [pool](Executor::Args::Closure c) {
      pool->Schedule(std::move(c));
    };
  }
}
505
FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
  // Deleting the items_ list will delete all the function handles registered in
  // this object. A function may contain a few sub-functions which have also
  // been registered in this object. Deleting the parent function will call
  // ReleaseHandle in this class again for each of the sub-functions. These
  // circular calls may cause a segfault since items_ may have already been
  // partially deleted when releasing handles of sub-functions. Explicitly
  // release items_ here and check it in ReleaseHandle to avoid this.
  items_.reset();
}
516
517 // An asynchronous op kernel which executes an instantiated function
518 // defined in a library.
class CallOp : public AsyncOpKernel {
 public:
  // `handle` identifies the instantiated function to run; the kernel does not
  // release it (see TODO below).
  CallOp(FunctionLibraryRuntime::Handle handle, OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx), handle_(handle) {}

  ~CallOp() override {
    // TODO(iga): Release the cached handle_
  }

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    FunctionLibraryRuntime* lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library is provided."),
                      done);
    // Propagate the relevant pieces of the kernel context into the function
    // run options, so the callee shares rendezvous/cancellation/etc.
    FunctionLibraryRuntime::Options opts;
    opts.rendezvous = ctx->rendezvous();
    opts.cancellation_manager = ctx->cancellation_manager();
    opts.step_container = ctx->step_container();
    opts.stats_collector = ctx->stats_collector();
    opts.runner = ctx->runner();
    opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
    opts.collective_executor = ctx->collective_executor();
    // Copy kernel inputs into the function-call argument vector.
    std::vector<Tensor> args;
    args.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      args.push_back(ctx->input(i));
    }
    // `rets` is heap-allocated because it must outlive this method; it is
    // deleted in the completion callback below.
    std::vector<Tensor>* rets = new std::vector<Tensor>;
    profiler::TraceMe trace_me(
        [&] {
          return absl::StrCat("CallOp #parent_step_id=", ctx->step_id(),
                              ",function_step_id=", opts.step_id, "#");
        },
        /*level=*/2);
    lib->Run(opts, handle_, args, rets,
             [ctx, done, rets](const Status& status) {
               if (!status.ok()) {
                 ctx->SetStatus(status);
               } else {
                 // On success, the function's returns become this kernel's
                 // outputs, one-for-one.
                 const int ret_size = static_cast<int>(rets->size());
                 CHECK_EQ(ret_size, ctx->num_outputs());
                 for (int i = 0; i < ret_size; ++i) {
                   ctx->set_output(i, (*rets)[i]);
                 }
               }
               delete rets;
               done();
             });
  }

 private:
  FunctionLibraryRuntime::Handle handle_;

  TF_DISALLOW_COPY_AND_ASSIGN(CallOp);
};
574
GetFunctionBody(Handle h)575 const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
576 LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h);
577 if (local_handle == kInvalidLocalHandle) {
578 LOG(ERROR) << "Could not find Handle: " << h
579 << " on device: " << device_name_;
580 return nullptr;
581 }
582
583 tf_shared_lock l(mu_);
584 auto iter = items_->find(local_handle);
585 CHECK(iter != items_->end());
586 return iter->second->func_graph;
587 }
588
GetRetTypes(Handle h,DataTypeVector * ret_types)589 Status FunctionLibraryRuntimeImpl::GetRetTypes(Handle h,
590 DataTypeVector* ret_types) {
591 if (parent_->IsMultiDevice(h)) {
592 return parent_->GetRetTypes(h, ret_types);
593 }
594 LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h);
595 if (local_handle == kInvalidLocalHandle) {
596 return errors::InvalidArgument("Handle ", h, " not found.");
597 }
598 const FunctionBody* fbody = GetFunctionBody(h);
599 *ret_types = fbody->ret_types;
600 return Status::OK();
601 }
602
// Public CreateKernel entry point: delegates to the overload below using this
// runtime itself as the FLR for kernel creation.
Status FunctionLibraryRuntimeImpl::CreateKernel(
    const std::shared_ptr<const NodeProperties>& props, OpKernel** kernel) {
  return CreateKernel(props, this, kernel);
}
607
CreateKernel(const std::shared_ptr<const NodeProperties> & props,FunctionLibraryRuntime * flr,OpKernel ** kernel)608 Status FunctionLibraryRuntimeImpl::CreateKernel(
609 const std::shared_ptr<const NodeProperties>& props,
610 FunctionLibraryRuntime* flr, OpKernel** kernel) {
611 // If a custom kernel creator is given, try that.
612 Status s;
613 const CustomKernelCreator* custom_kernel_creator =
614 GetDefaultCustomKernelCreator();
615 if (custom_kernel_creator &&
616 custom_kernel_creator->CanCreateKernel(*flr, props)) {
617 std::unique_ptr<OpKernel> ret;
618 s = custom_kernel_creator->CreateKernel(flr, props, &ret);
619 if (s.ok()) {
620 *kernel = ret.release();
621 } else {
622 VLOG(2) << "Custom creator error: " << s;
623 }
624 return s;
625 }
626
627 const FunctionLibraryDefinition* lib_def =
628 flr->GetFunctionLibraryDefinition();
629 if (lib_def->Find(props->node_def.op()) == nullptr) {
630 // A primitive operation. Creates the registered kernel.
631 return CreateNonCachedKernel(device_, flr, props, graph_def_version_,
632 kernel);
633 }
634
635 // Try to instantiate this function for the func/attr. Maybe it's
636 // cached already.
637 InstantiateOptions options;
638 if (lib_def != base_lib_def_) {
639 options.lib_def = lib_def;
640 }
641 Handle handle;
642 TF_RETURN_IF_ERROR(Instantiate(props->node_def.op(),
643 AttrSlice(&props->node_def.attr()), options,
644 &handle));
645
646 const FunctionBody* fbody = GetFunctionBody(handle);
647 CHECK_NOTNULL(fbody);
648
649 // TODO(zhifengc): For now, we assume int32 and resources are always on host
650 // memory and other types are always on device memory. We should do type
651 // inference over function body to derive the correct input/output memory
652 // types.
653 MemoryTypeVector input_memory_types;
654 for (const auto& t : fbody->arg_types) {
655 input_memory_types.push_back(MTypeFromDType(t));
656 }
657 MemoryTypeVector output_memory_types;
658 for (const auto& t : fbody->ret_types) {
659 output_memory_types.push_back(MTypeFromDType(t));
660 }
661
662 // Constructs a CallOp kernel for running the instantiated function.
663 auto device_type = DeviceType(device_->attributes().device_type());
664 auto new_props = std::make_shared<NodeProperties>(
665 &fbody->fdef.signature(), props->node_def, fbody->arg_types,
666 fbody->ret_types);
667 OpKernelConstruction construction(
668 device_type, device_, device_->GetAllocator(AllocatorAttributes()), flr,
669 device_->resource_manager(), props, input_memory_types,
670 output_memory_types, graph_def_version_, &s);
671 if (s.ok()) {
672 *kernel = new CallOp(handle, &construction);
673 }
674 return s;
675 }
676
FunctionDefToBody(const FunctionDef & fdef,AttrSlice attrs,const FunctionLibraryDefinition * lib_def,std::unique_ptr<FunctionBody> * fbody)677 Status FunctionLibraryRuntimeImpl::FunctionDefToBody(
678 const FunctionDef& fdef, AttrSlice attrs,
679 const FunctionLibraryDefinition* lib_def,
680 std::unique_ptr<FunctionBody>* fbody) {
681 if (lib_def == base_lib_def_) {
682 return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig_, fbody);
683 } else {
684 auto get_func_sig = [lib_def](const string& op, const OpDef** sig) {
685 return lib_def->LookUpOpDef(op, sig);
686 };
687 return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody);
688 }
689 }
690
// Builds the FunctionBody of the symbolic gradient of `func`. For primitive
// ops this uses the registered gradient creator; for user-defined functions
// it instantiates the function and differentiates its body.
Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
    const NameAttrList& func, const FunctionLibraryDefinition* lib_def,
    std::unique_ptr<FunctionBody>* g_body) {
  const FunctionDef* fdef = lib_def->Find(func.name());
  if (fdef == nullptr) {
    // f is a primitive op.
    gradient::Creator creator;
    TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
    if (creator == nullptr) {
      return errors::InvalidArgument("No gradient is defined for ",
                                     func.name());
    }
    FunctionDef grad_fdef;
    // TODO(josh11b): Should filter out the attrs from func that aren't used
    // by the gradient function.
    TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
    TF_RETURN_IF_ERROR(
        FunctionDefToBody(grad_fdef, AttrSlice(&func.attr()), lib_def, g_body));
  } else {
    // f is a user-defined function.
    InstantiateOptions options;
    // Only override lib_def when it differs from the base library.
    if (lib_def != base_lib_def_) {
      options.lib_def = lib_def;
    }
    Handle f_handle;
    TF_RETURN_IF_ERROR(
        Instantiate(func.name(), AttrSlice(&func.attr()), options, &f_handle));
    const FunctionBody* f_body = GetFunctionBody(f_handle);
    CHECK_NOTNULL(f_body);
    // Differentiate the instantiated body to obtain the gradient body.
    *g_body = SymbolicGradient(*f_body);
  }
  return Status::OK();
}
724
IsLocalTarget(const InstantiateOptions & options) const725 bool FunctionLibraryRuntimeImpl::IsLocalTarget(
726 const InstantiateOptions& options) const {
727 if (device_ == nullptr) return true;
728 if (options.target.empty()) return true;
729 if (options.is_multi_device_function) return false;
730 Device* target_device;
731 if (!device_mgr_->LookupDevice(options.target, &target_device).ok()) {
732 VLOG(1) << "Not instantiating function in FLR because failed to "
733 << "find device " << options.target << " in device manager";
734 return false;
735 }
736 if (target_device != device_) {
737 VLOG(1) << "Not instantiating function in FLR because target device "
738 << options.target
739 << " is different from FLR's device: " << device_->DebugString();
740 return false;
741 }
742 return true;
743 }
744
// Instantiates `function_name` (with `attrs`/`options`) on this FLR's device,
// delegating to the parent ProcessFunctionLibraryRuntime when the target is
// not local. On success `*handle` identifies the instantiation; repeated
// instantiations under the same canonical key share one Item and bump its
// instantiation_counter.
Status FunctionLibraryRuntimeImpl::Instantiate(
    const string& function_name, AttrSlice attrs,
    const InstantiateOptions& options, Handle* handle) {
  // Non-local targets (other devices, multi-device functions) are handled by
  // the process-wide runtime.
  if (!IsLocalTarget(options)) {
    return parent_->Instantiate(function_name, attrs, options, handle);
  }

  // Route through the per-FLR handle cache, clearing the flag first so the
  // recursive call below takes the uncached path.
  if (options.use_function_cache) {
    InstantiateOptions options_copy(options);
    options_copy.use_function_cache = false;
    return function_handle_cache_->Instantiate(function_name, attrs,
                                               options_copy, handle);
  }

  // Since this is a local target, ensure that the local `device_name_` appears
  // in the canonical key.
  InstantiateOptions options_copy(options);
  options_copy.target = device_name_;
  const string key = Canonicalize(function_name, attrs, options_copy);

  // Fast path: an instantiation already exists under this key; just increment
  // its refcount and return the existing handle.
  {
    mutex_lock l(mu_);
    *handle = parent_->GetHandle(key);
    if (*handle != kInvalidHandle) {
      FunctionLibraryRuntime::LocalHandle handle_on_device =
          parent_->GetHandleOnDevice(device_name_, *handle);
      if (handle_on_device == kInvalidLocalHandle) {
        return errors::Internal("LocalHandle not found for handle ", *handle,
                                ".");
      }
      auto item_handle = items_->find(handle_on_device);
      if (item_handle == items_->end()) {
        return errors::Internal("LocalHandle ", handle_on_device,
                                " for handle ", *handle,
                                " not found in items.");
      }
      ++item_handle->second->instantiation_counter;
      return Status::OK();
    }
  }

  // Build the function body outside of mu_ (FunctionDefToBody /
  // InstantiateSymbolicGradient may be expensive).
  const FunctionLibraryDefinition* lib_def =
      options.lib_def ? options.lib_def : base_lib_def_;
  std::unique_ptr<FunctionBody> fbody;
  if (function_name == kGradientOp) {
    // SymbolicGradient: the function to differentiate is carried in attr `f`.
    const AttrValue* f = attrs.Find(kFuncAttr);
    if (f == nullptr) {
      return errors::InvalidArgument("SymbolicGradient is missing attr: f");
    }
    const auto& func = f->func();
    if (func.name() == kGradientOp) {
      return errors::InvalidArgument("Can't take gradient of SymbolicGradient");
    }
    // Prefer a registered custom gradient function if one exists.
    const string grad = lib_def->FindGradient(func.name());
    if (!grad.empty()) {
      return Instantiate(grad, AttrSlice(&func.attr()), options, handle);
    }
    TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, lib_def, &fbody));
  } else {
    const FunctionDef* fdef = lib_def->Find(function_name);
    if (fdef == nullptr) {
      return errors::NotFound("Function ", function_name, " is not defined.");
    }
    TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, lib_def, &fbody));
  }

  LocalHandle local_handle;
  {
    mutex_lock l(mu_);
    // Re-check under the lock: another thread may have instantiated the same
    // key while we were building the body above.
    *handle = parent_->GetHandle(key);
    if (*handle != kInvalidHandle) {
      local_handle = parent_->GetHandleOnDevice(device_name_, *handle);
      ++(*items_)[local_handle]->instantiation_counter;
    } else {
      *handle = parent_->AddHandle(key, device_name_, next_handle_);
      Item* item = new Item;
      item->func_graph = fbody.release();
      item->instantiation_counter = 1;
      item->executor_type = ExecutorType(options, attrs);
      // An options-provided lib_def needs an overlay runtime so kernels
      // resolve functions against it instead of base_lib_def_.
      if (options.lib_def) {
        item->overlay_flr =
            new FunctionLibraryRuntimeOverlay(this, options.lib_def);
      }
      local_handle = next_handle_++;
      items_->emplace(local_handle, std::unique_ptr<Item>(item));
    }
  }

  // Optionally force executor/kernel creation now rather than on first Run.
  if (options.create_kernels_eagerly) {
    Item* item;
    TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, &item));
  }

  return Status::OK();
}
840
// Decrements the instantiation refcount for `handle`; when it reaches zero,
// removes the Item and the parent's handle mapping. Non-local handles are
// forwarded to the parent runtime.
Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
  LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
  if (h == kInvalidLocalHandle) {
    return parent_->ReleaseHandle(handle);
  }
  // Declared before the lock scope so the Item is destroyed only after mu_ is
  // released (see the deadlock note below).
  std::unique_ptr<Item> item_to_delete;
  Status parent_status;
  {
    mutex_lock l(mu_);
    // Return directly if all items has already been released.
    if (items_ == nullptr) return Status::OK();

    auto it = items_->find(h);
    if (it == items_->end()) {
      return errors::Internal(
          "Inconsistent FunctionLibraryRuntime. Expected to find an item for "
          "handle ",
          h, " but found none");
    }
    std::unique_ptr<Item>& item = it->second;
    --item->instantiation_counter;
    if (item->instantiation_counter == 0) {
      // We don't simply erase h's item because that would trigger
      // item destruction while holding mu_. Item destruction can
      // trigger graph destruction. If the graph contains kernels like
      // CallOp or PartitionCallOp, their destructors will release cached
      // function handles, resulting in deadlock here.
      item_to_delete = std::move(item);
      items_->erase(h);
      parent_status = parent_->RemoveHandle(handle);
    }
  }
  // item_to_delete (if set) is destroyed here, outside the lock.
  return parent_status;
}
875
876 namespace {
877
878 // Removes all stateless nodes that do not contribute to a return
879 // value from the function body. Unlike `RemoveDeadNodes()`, which is
880 // triggered by `OptimizerOptions.do_function_inlining`, this pass
881 // ignores the SINK node, from which (by definition) all nodes are
882 // reverse reachable, and preserves all nodes that are reachable from
883 // control output nodes.
884 //
885 // TODO(ezhulenev, skyewm): Function body should not have special treatment of
886 // stateful ops, graph should encode nodes that must execute with `control_ret`
887 // and `control_output`.
PruneFunctionBody(const FunctionDef & fdef,Graph * g)888 void PruneFunctionBody(const FunctionDef& fdef, Graph* g) {
889 VLOG(2) << "Pruning function body: function_name=" << fdef.signature().name();
890
891 // `control_ret` nodes must be always executed.
892 std::unordered_set<StringPiece, StringPieceHasher> control_ret_nodes;
893 for (const auto& control_ret : fdef.control_ret()) {
894 control_ret_nodes.insert(control_ret.second);
895 }
896
897 std::unordered_set<const Node*> nodes;
898 for (auto n : g->nodes()) {
899 // NOTE(mrry): "_Retval" nodes are stateful, and so will be added
900 // to the seed set of `nodes`. "_Arg" nodes are also stateful, but we
901 // specifically exclude them as seeds, to avoid unconditionally executing
902 // unused argument nodes (e.g. in a function like `lambda x, y: y`).
903 // TODO(mrry): Investigate whether the `n->IsControlFlow()` test is
904 // still needed. It would be preferable to prune entire loops and/or
905 // conditionals if they are not used in the graph.
906 if (n->IsControlFlow() ||
907 (n->op_def().is_stateful() && n->type_string() != kArgOp) ||
908 (control_ret_nodes.find(n->name()) != control_ret_nodes.end())) {
909 nodes.insert(n);
910 }
911 }
912 bool changed = PruneForReverseReachability(g, std::move(nodes));
913 if (changed) {
914 FixupSourceAndSinkEdges(g);
915 }
916 }
917
// Upper bound on graph size considered by the single-threaded-executor
// heuristic (IsSingleThreadedExecutorCompatible below); larger graphs are
// not worth analyzing.
constexpr int kMaxNodesForSingleThreadedExecutor = 32;
919
920 // Returns true if the given operation is suitable to execute via
921 // SingleThreadedExecutor. This is an intentional subset of the ops which
922 // technically can be run via single-threaded execution to avoid issues with
923 // recursion or function invocation.
924 //
925 // SingleThreadedExecutor runs asynchronous kernels synchronously: this can lead
926 // to deadlocks. This function attempts to exclude all async kernels in lieu of
927 // kernel instantiation.
IsOpSingleThreadedExecutorCompatible(const Node & n)928 bool IsOpSingleThreadedExecutorCompatible(const Node& n) {
929 if (n.IsFunctionCall() || n.IsPartitionedCall() || n.IsIfNode() ||
930 n.IsWhileNode() || n.IsCaseNode()) {
931 return false;
932 }
933 if (n.IsControlFlow()) {
934 return false;
935 }
936 if (n.IsSend() || n.IsHostSend() || n.IsRecv() || n.IsHostRecv()) {
937 return false;
938 }
939 if (n.IsCollective()) {
940 return false;
941 }
942 for (DataType dt : n.output_types()) {
943 if (IsRefType(dt)) {
944 return false;
945 }
946 }
947 std::string lower = str_util::Lowercase(n.op_def().name());
948 if (str_util::StrContains(lower, "pyfunc") ||
949 str_util::StrContains(lower, "queue") ||
950 str_util::StrContains(lower, "rpc")) {
951 return false;
952 }
953
954 return true;
955 }
956
// Returns true if the given Graph is safe & efficient to run via the single
// threaded executor. The single-threaded executor has lower dispatch overhead
// for simple functions.
//
// This currently specializes for the case of a single operation, as created
// via eager execution.
bool IsSingleThreadedExecutorCompatible(const Graph* g) {
  // TODO(b/187729969): Temporarily disabled due to b/187306798.
  // NOTE: everything after this `return` is intentionally unreachable until
  // the TODO above is resolved — presumably kept so the heuristic can be
  // re-enabled by deleting the early return.
  return false;

  // Not worth analyzing large graphs.
  if (g->num_nodes() > kMaxNodesForSingleThreadedExecutor) {
    return false;
  }

  // Count "real" ops, skipping structural _Arg/_Retval/NoOp nodes.
  int count = 0;
  for (Node* n : g->nodes()) {
    if (!IsOpSingleThreadedExecutorCompatible(*n)) {
      return false;
    }
    if (n->op_def().name() == "_Arg" || n->op_def().name() == "_Retval" ||
        n->op_def().name() == "NoOp") {
      continue;
    }

    count += 1;
  }

  // Only single-op graphs qualify.
  if (count == 1) {
    return true;
  }

  return false;
}
991
992 } // namespace
993
// Builds the executor for an already-registered Item: copies and optimizes
// the function graph, then instantiates an Executor for it. Called outside
// mu_ because kernel creation calls back into this library.
Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) {
  const FunctionBody* fbody;
  FunctionLibraryRuntime* flr;
  string executor_type;
  // Snapshot the fields we need under a shared lock; the heavy work below
  // must not hold mu_.
  {
    tf_shared_lock l(mu_);
    fbody = (*item)->func_graph;
    // Use the overlay runtime (options-provided lib_def) when present.
    flr = (*item)->overlay_flr
              ? static_cast<FunctionLibraryRuntime*>((*item)->overlay_flr)
              : static_cast<FunctionLibraryRuntime*>(this);
    executor_type = (*item)->executor_type;
  }
  const FunctionLibraryDefinition* lib_def =
      flr->GetFunctionLibraryDefinition();
  // Work on a copy so the Item's pristine func_graph is preserved.
  auto g = absl::make_unique<Graph>(lib_def);
  CopyGraph(*fbody->graph, g.get());

  PruneFunctionBody(fbody->fdef, g.get());
  optimizer_.Optimize(this, env(), device(), &g, /*shape_map=*/nullptr);
  TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device()->device_type()),
                                       device()->name(), g.get()));

  // Creates an executor based on the g. This must be done without
  // holding mu_ because create_kernel_ calls back into the library.
  LocalExecutorParams params;
  params.device = device_;
  params.function_library = flr;
  if (flr == this) {
    params.create_kernel = create_kernel_;
  } else {
    // For overlay runtimes, kernels must be created against the overlay so
    // nested function lookups resolve correctly.
    params.create_kernel =
        [this, flr](const std::shared_ptr<const NodeProperties>& props,
                    OpKernel** kernel) {
          return CreateKernel(props, flr, kernel);
        };
  }
  params.delete_kernel = [](OpKernel* kernel) {
    DeleteNonCachedKernel(kernel);
  };
  params.session_metadata = session_metadata_;
  std::unique_ptr<Executor> exec;

  // Prefer the low-overhead single-threaded executor when no explicit type
  // was requested and the graph qualifies.
  if (executor_type.empty() && IsSingleThreadedExecutorCompatible(g.get())) {
    executor_type = "SINGLE_THREADED_EXECUTOR";
  }
  TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, *g, &exec));
  {
    // Guard item since it is already inserted in items_.
    mutex_lock l(mu_);
    // A concurrent CreateItem may have won the race; only install our
    // executor if none exists yet (otherwise `g`/`exec` are discarded).
    if ((*item)->exec == nullptr) {
      (*item)->graph = std::move(g);
      (*item)->exec = exec.release();
    }
  }
  return Status::OK();
}
1050
GetOrCreateItem(LocalHandle local_handle,Item ** item)1051 Status FunctionLibraryRuntimeImpl::GetOrCreateItem(LocalHandle local_handle,
1052 Item** item) {
1053 {
1054 tf_shared_lock l(mu_);
1055 auto iter = items_->find(local_handle);
1056 if (iter == items_->end()) {
1057 return errors::Internal("Local function handle ", local_handle,
1058 " is not valid. Likely an internal error.");
1059 }
1060 *item = iter->second.get();
1061 if ((*item)->exec != nullptr) {
1062 return Status::OK();
1063 }
1064 }
1065 // NOTE: We need to call CreateItem out of mu_ because creating an
1066 // executor needs to call CreateKernel.
1067 return CreateItem(item);
1068 }
1069
ExecutorArgsFromOptions(const FunctionLibraryRuntime::Options & run_opts,CallFrameInterface * frame,Executor::Args * exec_args)1070 void FunctionLibraryRuntimeImpl::ExecutorArgsFromOptions(
1071 const FunctionLibraryRuntime::Options& run_opts, CallFrameInterface* frame,
1072 Executor::Args* exec_args) {
1073 // Inherit the step_id from the caller.
1074 exec_args->step_id = run_opts.step_id;
1075 exec_args->rendezvous = run_opts.rendezvous;
1076 exec_args->stats_collector = run_opts.stats_collector;
1077 exec_args->cancellation_manager = run_opts.cancellation_manager;
1078 exec_args->step_container = run_opts.step_container;
1079 if (run_opts.runner) {
1080 exec_args->runner = *run_opts.runner;
1081 } else {
1082 exec_args->runner = default_runner_;
1083 }
1084 exec_args->collective_executor = run_opts.collective_executor;
1085 exec_args->call_frame = frame;
1086 exec_args->run_all_kernels_inline = run_opts.run_all_kernels_inline;
1087 exec_args->user_intra_op_threadpool = run_opts.user_intra_op_threadpool;
1088 exec_args->coordination_service_agent = run_opts.coordination_service_agent;
1089 }
1090
// Runs an instantiated function whose caller lives on a different device:
// receives the arguments over the rendezvous, executes locally, then sends
// the return values back. `frame`, `exec_args` and `remote_args` are
// heap-allocated because they must outlive this call; they are deleted in
// the nested callbacks below on every path.
void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
                                           gtl::ArraySlice<Tensor> args,
                                           std::vector<Tensor>* rets,
                                           Item* item, DoneCallback done) {
  string target_device = parent_->GetDeviceName(handle);
  string source_device = opts.source_device;
  RendezvousInterface* rendezvous = opts.rendezvous;
  DeviceContext* device_context;
  Status s = parent_->GetDeviceContext(target_device, &device_context);
  if (!s.ok()) {
    done(s);
    return;
  }
  // Incarnations disambiguate rendezvous keys across device restarts.
  int64_t src_incarnation, target_incarnation;
  s = parent_->GetDeviceIncarnation(source_device, &src_incarnation);
  s.Update(parent_->GetDeviceIncarnation(target_device, &target_incarnation));
  if (!s.ok()) {
    done(s);
    return;
  }

  const FunctionBody* fbody = GetFunctionBody(handle);
  FunctionCallFrame* frame =
      new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
  Executor::Args* exec_args = new Executor::Args;
  ExecutorArgsFromOptions(opts, frame, exec_args);

  std::vector<AllocatorAttributes> args_alloc_attrs, rets_alloc_attrs;
  args_alloc_attrs.reserve(fbody->arg_types.size());
  rets_alloc_attrs.reserve(fbody->ret_types.size());
  // Note: Functions assume that int32's are always on host memory.
  for (const auto& arg_type : fbody->arg_types) {
    AllocatorAttributes arg_alloc_attrs;
    if (MTypeFromDType(arg_type) == HOST_MEMORY) {
      arg_alloc_attrs.set_on_host(true);
    }
    args_alloc_attrs.push_back(arg_alloc_attrs);
  }
  for (const auto& ret_type : fbody->ret_types) {
    AllocatorAttributes ret_alloc_attrs;
    if (MTypeFromDType(ret_type) == HOST_MEMORY) {
      ret_alloc_attrs.set_on_host(true);
    }
    rets_alloc_attrs.push_back(ret_alloc_attrs);
  }

  bool allow_dead_tensors = opts.allow_dead_tensors;

  // The ProcFLR sends the arguments to the function from the source_device to
  // the target_device. So here we receive those arguments. Similarly, when the
  // computation is done and stored in *rets, we send the return values back
  // to the source_device (caller) so that the ProcFLR can receive them later.
  std::vector<Tensor>* remote_args = new std::vector<Tensor>;
  ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
      source_device, target_device, "arg_", src_incarnation, args.size(),
      device_context, args_alloc_attrs, rendezvous, remote_args,
      [frame, remote_args, item, source_device, target_device,
       target_incarnation, rendezvous, device_context, rets, done, exec_args,
       rets_alloc_attrs, allow_dead_tensors](const Status& status) {
        Status s = status;
        if (s.ok()) {
          s = frame->SetArgs(*remote_args);
        }
        // Receive or SetArgs failed: clean up all heap state and bail out.
        if (!s.ok()) {
          delete frame;
          delete remote_args;
          delete exec_args;
          done(s);
          return;
        }
        item->exec->RunAsync(
            *exec_args,
            [frame, rets, done, source_device, target_device,
             target_incarnation, rendezvous, device_context, remote_args,
             rets_alloc_attrs, allow_dead_tensors](const Status& status) {
              Status s = status;
              if (s.ok()) {
                s = frame->ConsumeRetvals(rets, allow_dead_tensors);
              }
              delete frame;
              if (!s.ok()) {
                delete remote_args;
                done(s);
                return;
              }
              // Ship the results back to the caller's device.
              s = ProcessFunctionLibraryRuntime::SendTensors(
                  target_device, source_device, "ret_", target_incarnation,
                  *rets, device_context, rets_alloc_attrs, rendezvous);
              delete remote_args;
              done(s);
            });
        // NOTE(review): exec_args is deleted as soon as RunAsync has been
        // invoked — assumes the executor does not retain a reference to it.
        delete exec_args;
      });
}
1185
// Asynchronously runs an instantiated function with tensor args/rets,
// delegating to the parent runtime for non-local handles and to RunRemote()
// for cross-device calls.
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
                                     gtl::ArraySlice<Tensor> args,
                                     std::vector<Tensor>* rets,
                                     DoneCallback done) {
  if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
    done(errors::Cancelled("Function was cancelled before it was started"));
    return;
  }
  Options run_opts = opts;
  // If asked to create the rendezvous, own its lifetime here: it is deleted
  // when the wrapped `done` callback fires. The flag is cleared so downstream
  // calls do not create a second one.
  if (opts.create_rendezvous) {
    auto* rendezvous = new PrivateIntraProcessRendezvous(device_mgr_);
    run_opts.rendezvous = rendezvous;
    run_opts.create_rendezvous = false;
    done = [done = std::move(done), rendezvous](const Status& status) mutable {
      delete rendezvous;
      done(status);
    };
  }

  // Handles not local to this device are executed by the parent runtime.
  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
  if (local_handle == kInvalidLocalHandle) {
    parent_->Run(run_opts, handle, args, rets, done);
    return;
  }

  if (run_opts.runner == nullptr) {
    run_opts.runner = &default_runner_;
  }
  DCHECK(run_opts.runner != nullptr);

  Item* item = nullptr;
  Status s = GetOrCreateItem(local_handle, &item);
  if (!s.ok()) {
    done(s);
    return;
  }

  if (run_opts.remote_execution) {
    // NOTE(mrry): `RunRemote()` will set `exec_args->call_frame` for us.
    RunRemote(run_opts, handle, args, rets, item, std::move(done));
    return;
  }

  // Heap-allocate the call frame because it must outlive this call; it is
  // deleted in the RunAsync done callback below.
  const FunctionBody* fbody = GetFunctionBody(handle);
  FunctionCallFrame* frame =
      new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
  s = frame->SetArgs(args);
  if (!s.ok()) {
    delete frame;
    done(s);
    return;
  }

  profiler::TraceMeProducer activity(
      // To TraceMeConsumers in ExecutorState::Process/Finish.
      [&opts] {
        return profiler::TraceMeEncode("FunctionRun",
                                       {{"id", opts.step_id}, {"_r", 1}});
      },
      profiler::ContextType::kTfExecutor, opts.step_id,
      profiler::TraceMeLevel::kInfo);

  Executor::Args exec_args;
  ExecutorArgsFromOptions(run_opts, frame, &exec_args);

  bool allow_dead_tensors = run_opts.allow_dead_tensors;
  item->exec->RunAsync(
      // Executor args
      exec_args,
      // Done callback.
      [frame, rets, done, allow_dead_tensors](const Status& status) {
        Status s = status;
        if (s.ok()) {
          s = frame->ConsumeRetvals(rets, allow_dead_tensors);
        }
        delete frame;
        done(s);
      });
}
1265
// Asynchronously runs an instantiated function using a caller-provided call
// frame (instead of tensor vectors). Remote execution is not supported with
// this interface.
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
                                     CallFrameInterface* frame,
                                     DoneCallback done) {
  if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
    done(errors::Cancelled(""));
    return;
  }

  Options run_opts = opts;
  // If asked to create the rendezvous, own its lifetime here: it is deleted
  // when the wrapped `done` callback fires.
  if (opts.create_rendezvous) {
    auto* rendezvous = new PrivateIntraProcessRendezvous(device_mgr_);
    run_opts.rendezvous = rendezvous;
    run_opts.create_rendezvous = false;
    done = [done = std::move(done), rendezvous](const Status& status) mutable {
      delete rendezvous;
      done(status);
    };
  }

  LocalHandle local_handle = parent_->GetHandleOnDevice(
      device_name_, handle, /*include_multi_device=*/true);
  if (local_handle == kInvalidLocalHandle) {
    parent_->Run(run_opts, handle, frame, done);
    return;
  }

  if (opts.remote_execution) {
    // NOTE(mrry): This bit is only set for a local function when `parent_`
    // calls back into this class, and the current implementation of
    // `ProcessFunctionLibraryRuntime` currently always uses the vector-based
    // `args`/`rets` interface.
    done(errors::Unimplemented("Remote calling with CallFrameInterface"));
    return;
  }

  Item* item = nullptr;
  Status s = GetOrCreateItem(local_handle, &item);
  if (!s.ok()) {
    done(s);
    return;
  }
  if (run_opts.runner == nullptr) {
    run_opts.runner = &default_runner_;
  }
  DCHECK(run_opts.runner != nullptr);

  profiler::TraceMeProducer activity(
      // To TraceMeConsumers in ExecutorState::Process/Finish.
      [&opts] {
        return profiler::TraceMeEncode("FunctionRun",
                                       {{"id", opts.step_id}, {"_r", 1}});
      },
      profiler::ContextType::kTfExecutor, opts.step_id,
      profiler::TraceMeLevel::kInfo);

  Executor::Args exec_args;
  ExecutorArgsFromOptions(run_opts, frame, &exec_args);
  item->exec->RunAsync(exec_args, std::move(done));
}
1325
PrepareRunSync(Handle handle,Options * run_opts,Item ** out_item,std::unique_ptr<PrivateIntraProcessRendezvous> * out_rendezvous)1326 Status FunctionLibraryRuntimeImpl::PrepareRunSync(
1327 Handle handle, Options* run_opts, Item** out_item,
1328 std::unique_ptr<PrivateIntraProcessRendezvous>* out_rendezvous) {
1329 if (run_opts->cancellation_manager &&
1330 run_opts->cancellation_manager->IsCancelled()) {
1331 return errors::Cancelled("");
1332 }
1333
1334 if (run_opts->remote_execution) {
1335 // NOTE(mrry): This bit is only set for a local function when `parent_`
1336 // calls back into this class, and the current implementation of
1337 // `ProcessFunctionLibraryRuntime` currently always uses the asynchronous
1338 // Run() method.
1339 return errors::Unimplemented("Remote calling with RunSync()");
1340 }
1341
1342 if (run_opts->create_rendezvous) {
1343 *out_rendezvous =
1344 absl::make_unique<PrivateIntraProcessRendezvous>(device_mgr_);
1345 run_opts->rendezvous = out_rendezvous->get();
1346 run_opts->create_rendezvous = false;
1347 }
1348
1349 LocalHandle local_handle = parent_->GetHandleOnDevice(
1350 device_name_, handle, /*include_multi_device=*/true);
1351 if (local_handle == kInvalidLocalHandle) {
1352 *out_item = nullptr;
1353 return Status::OK();
1354 }
1355
1356 TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, out_item));
1357
1358 if (run_opts->runner == nullptr) {
1359 run_opts->runner = &default_runner_;
1360 }
1361 DCHECK(run_opts->runner != nullptr);
1362
1363 return Status::OK();
1364 }
1365
RunSync(Options opts,Handle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets)1366 Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
1367 gtl::ArraySlice<Tensor> args,
1368 std::vector<Tensor>* rets) {
1369 Item* item = nullptr;
1370 std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous;
1371 TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous));
1372 if (item == nullptr) {
1373 return parent_->RunSync(opts, handle, args, rets);
1374 }
1375
1376 Executor::Args exec_args;
1377 const FunctionBody* fbody = GetFunctionBody(handle);
1378 FunctionCallFrame frame(fbody->arg_types, fbody->ret_types);
1379 TF_RETURN_IF_ERROR(frame.SetArgs(args));
1380 ExecutorArgsFromOptions(opts, &frame, &exec_args);
1381
1382 TF_RETURN_IF_ERROR(item->exec->Run(exec_args));
1383 return frame.ConsumeRetvals(rets, opts.allow_dead_tensors);
1384 }
1385
RunSync(Options opts,Handle handle,CallFrameInterface * call_frame)1386 Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
1387 CallFrameInterface* call_frame) {
1388 Item* item = nullptr;
1389 std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous;
1390 TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous));
1391 if (item == nullptr) {
1392 return parent_->RunSync(opts, handle, call_frame);
1393 }
1394
1395 Executor::Args exec_args;
1396 ExecutorArgsFromOptions(opts, call_frame, &exec_args);
1397 return item->exec->Run(exec_args);
1398 }
1399
IsStateful(const string & func) const1400 bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) const {
1401 const OpDef* op_def;
1402 const Status s = base_lib_def_->LookUpOpDef(func, &op_def);
1403 return s.ok() && op_def->is_stateful();
1404 }
1405
DebugString(Handle handle)1406 string FunctionLibraryRuntimeImpl::DebugString(Handle handle) {
1407 Item* item = nullptr;
1408 LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
1409 Status s = GetOrCreateItem(local_handle, &item);
1410 if (s.ok()) {
1411 if (item->graph) {
1412 return tensorflow::DebugString(item->graph.get());
1413 } else {
1414 return tensorflow::DebugString(item->func_graph->graph);
1415 }
1416 } else {
1417 return s.ToString();
1418 }
1419 }
1420
Clone(std::unique_ptr<FunctionLibraryDefinition> * out_lib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> * out_pflr,FunctionLibraryRuntime ** out_flr,bool skip_flib_def)1421 Status FunctionLibraryRuntimeImpl::Clone(
1422 std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
1423 std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
1424 FunctionLibraryRuntime** out_flr, bool skip_flib_def) {
1425 TF_RETURN_IF_ERROR(parent_->Clone(env_, graph_def_version_,
1426 optimizer_.options(), out_lib_def, out_pflr,
1427 skip_flib_def));
1428 *out_flr = (*out_pflr)->GetFLR(device_->name());
1429 if (*out_flr != nullptr) {
1430 return Status::OK();
1431 } else {
1432 return errors::Internal("Cloning FunctionLibraryRuntime failed.");
1433 }
1434 }
1435
1436 namespace {
1437
// Process-wide storage for the custom kernel creator registered via
// RegisterDefaultCustomKernelCreator(). Access is serialized by `mu`.
struct CustomCreatorSingleton {
  mutex mu;
  // Registered creator; never deleted here (caller retains ownership).
  CustomKernelCreator* custom_creator = nullptr;

  // Replaces the registered creator (may be set to nullptr).
  void Set(CustomKernelCreator* cb) {
    mutex_lock l(mu);
    custom_creator = cb;
  }

  // Returns the currently registered creator, or nullptr if none.
  CustomKernelCreator* Get() {
    mutex_lock l(mu);
    return custom_creator;
  }
};
1452
GetCustomCreatorSingleton()1453 CustomCreatorSingleton* GetCustomCreatorSingleton() {
1454 static CustomCreatorSingleton* ccs = new CustomCreatorSingleton;
1455 return ccs;
1456 }
1457
1458 } // namespace
1459
GetDefaultCustomKernelCreator()1460 const CustomKernelCreator* GetDefaultCustomKernelCreator() {
1461 return GetCustomCreatorSingleton()->Get();
1462 }
1463
RegisterDefaultCustomKernelCreator(CustomKernelCreator * c)1464 void RegisterDefaultCustomKernelCreator(CustomKernelCreator* c) {
1465 GetCustomCreatorSingleton()->Set(c);
1466 }
1467
NewFunctionLibraryRuntime(const DeviceMgr * device_mgr,Env * env,const ConfigProto * config,Device * device,int graph_def_version,const FunctionLibraryDefinition * lib_def,thread::ThreadPool * thread_pool,const OptimizerOptions & optimizer_options,const SessionMetadata * session_metadata,ProcessFunctionLibraryRuntime * parent)1468 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
1469 const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
1470 Device* device, int graph_def_version,
1471 const FunctionLibraryDefinition* lib_def, thread::ThreadPool* thread_pool,
1472 const OptimizerOptions& optimizer_options,
1473 const SessionMetadata* session_metadata,
1474 ProcessFunctionLibraryRuntime* parent) {
1475 return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl(
1476 device_mgr, env, config, device, graph_def_version, lib_def, thread_pool,
1477 optimizer_options, session_metadata, parent));
1478 }
1479
// Builds the body of the gradient function for `f` (see SymbolicGradient()
// below): Compute() copies the forward body and augments the copy with
// gradient nodes.
class SymbolicGradientHelper {
 public:
  explicit SymbolicGradientHelper(const FunctionBody& f) : fbody_(&f) {}
  ~SymbolicGradientHelper() = default;

  // Builds and returns the gradient function body.
  std::unique_ptr<FunctionBody> Compute();

 private:
  // Not owned; must outlive this helper.
  const FunctionBody* fbody_;

  // Makes a copy of fbody_ in gbody.
  void Copy(FunctionBody* gbody);

  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientHelper);
};
1495
Copy(FunctionBody * gbody)1496 void SymbolicGradientHelper::Copy(FunctionBody* gbody) {
1497 const Graph& src = *(fbody_->graph);
1498 gbody->graph = new Graph(src.op_registry());
1499 Graph* dst = gbody->graph;
1500
1501 std::vector<Node*> node_map(src.num_node_ids());
1502
1503 // Copy just the fdef attributes (copy '_noinline' and other similar flags to
1504 // the gradient function body).
1505 *(gbody->fdef.mutable_attr()) = fbody_->fdef.attr();
1506
1507 // Copy the nodes.
1508 node_map[src.source_node()->id()] = dst->source_node();
1509 node_map[src.sink_node()->id()] = dst->sink_node();
1510 for (Node* n : src.op_nodes()) {
1511 node_map[n->id()] = dst->CopyNode(n);
1512 }
1513
1514 // Copy the edges.
1515 for (const Edge* e : src.edges()) {
1516 Node* src_copy = node_map[e->src()->id()];
1517 Node* dst_copy = node_map[e->dst()->id()];
1518 dst->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
1519 }
1520
1521 // Save inputs in copied graph.
1522 CHECK_EQ(fbody_->arg_types.size(), fbody_->arg_nodes.size());
1523 gbody->arg_types = fbody_->arg_types;
1524 for (std::size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
1525 gbody->arg_nodes.push_back(node_map[fbody_->arg_nodes[i]->id()]);
1526 }
1527
1528 // Save outputs in copied graph.
1529 CHECK_EQ(fbody_->ret_types.size(), fbody_->ret_nodes.size());
1530 gbody->ret_types = fbody_->ret_types;
1531 for (std::size_t i = 0; i < fbody_->ret_nodes.size(); ++i) {
1532 gbody->ret_nodes.push_back(node_map[fbody_->ret_nodes[i]->id()]);
1533 }
1534 }
1535
Compute()1536 std::unique_ptr<FunctionBody> SymbolicGradientHelper::Compute() {
1537 FunctionBody* gbody = new FunctionBody;
1538 Copy(gbody); // copy fbody_ into gbody.
1539
1540 Graph* g = gbody->graph;
1541
1542 const int num_y = static_cast<int>(gbody->ret_nodes.size());
1543
1544 // Populate 'y_node_outputs_' with node function body outputs.
1545 // Populate 'y_grad_nodes' with initial gradient nodes for each return node
1546 // of the original function body (these will be 'arg' nodes in the function
1547 // gradient body).
1548 std::vector<NodeOut> y_node_outputs;
1549 y_node_outputs.reserve(num_y);
1550 std::vector<NodeOut> y_grad_node_outputs;
1551 y_grad_node_outputs.reserve(num_y);
1552 for (int i = 0; i < num_y; ++i) {
1553 Node* y = gbody->ret_nodes[i];
1554 y_node_outputs.push_back({y, 0});
1555 DCHECK_EQ(y->type_string(), kRetOp);
1556 const DataType dtype = y->input_type(0);
1557 const int index = static_cast<int>(gbody->arg_nodes.size());
1558 Node* dy = AddArg(g, dtype, index);
1559 gbody->arg_types.push_back(dtype);
1560 gbody->arg_nodes.push_back(dy);
1561 y_grad_node_outputs.push_back({dy, 0});
1562 }
1563
1564 // Populate 'x_nodes' with function args (excluding 'y_grad_node_outputs').
1565 const size_t num_x = fbody_->arg_nodes.size();
1566 std::vector<NodeOut> x_node_outputs;
1567 x_node_outputs.reserve(num_x);
1568 for (size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
1569 x_node_outputs.push_back({gbody->arg_nodes[i], 0});
1570 }
1571
1572 // Call AddSymbolicGradients which will add nodes to graph 'g' that
1573 // compute the function gradient (adding an entry in 'x_grad_node_outputs'
1574 // for each node in 'x_node_outputs').
1575 std::vector<NodeOut> x_grad_node_outputs;
1576 TF_CHECK_OK(AddSymbolicGradients(y_node_outputs, x_node_outputs,
1577 y_grad_node_outputs, &x_grad_node_outputs,
1578 g));
1579
1580 // Remove the old return nodes from the function body.
1581 for (Node* n : gbody->ret_nodes) {
1582 g->RemoveNode(n);
1583 }
1584 gbody->ret_types = fbody_->arg_types;
1585 // TODO(apassos): use the right dtype for gradients of resource variables
1586 for (int i = 0; i < gbody->ret_types.size(); ++i) {
1587 if (gbody->ret_types[i] == DT_RESOURCE) {
1588 gbody->ret_types[i] = DT_FLOAT;
1589 }
1590 }
1591 gbody->ret_nodes.clear();
1592 // Add new return nodes to the function gradient body for each node
1593 // in 'x_grad_nodes'.
1594 const int arg_types_size = static_cast<int>(fbody_->arg_types.size());
1595 for (int i = 0; i < arg_types_size; ++i) {
1596 Endpoint grad = {x_grad_node_outputs[i].node, x_grad_node_outputs[i].index};
1597 Node* ret = AddRet(g, grad, i);
1598 gbody->ret_nodes.push_back(ret);
1599 }
1600
1601 return std::unique_ptr<FunctionBody>(gbody);
1602 }
1603
SymbolicGradient(const FunctionBody & f)1604 std::unique_ptr<FunctionBody> SymbolicGradient(const FunctionBody& f) {
1605 return SymbolicGradientHelper(f).Compute();
1606 }
1607
1608 } // end namespace tensorflow
1609