1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/function.h"
17
18 #include <deque>
19 #include <vector>
20
21 #include "absl/algorithm/container.h"
22 #include "absl/strings/str_cat.h"
23 #include "tensorflow/core/common_runtime/device.h"
24 #include "tensorflow/core/common_runtime/executor.h"
25 #include "tensorflow/core/common_runtime/executor_factory.h"
26 #include "tensorflow/core/common_runtime/graph_optimizer.h"
27 #include "tensorflow/core/common_runtime/memory_types.h"
28 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
29 #include "tensorflow/core/framework/collective.h"
30 #include "tensorflow/core/framework/function.h"
31 #include "tensorflow/core/framework/node_def.pb.h"
32 #include "tensorflow/core/framework/node_def_util.h"
33 #include "tensorflow/core/framework/op.h"
34 #include "tensorflow/core/framework/op_kernel.h"
35 #include "tensorflow/core/framework/versions.pb.h"
36 #include "tensorflow/core/graph/algorithm.h"
37 #include "tensorflow/core/graph/control_flow.h"
38 #include "tensorflow/core/graph/gradients.h"
39 #include "tensorflow/core/graph/graph_constructor.h"
40 #include "tensorflow/core/graph/optimizer_cse.h"
41 #include "tensorflow/core/lib/core/threadpool.h"
42 #include "tensorflow/core/lib/gtl/map_util.h"
43 #include "tensorflow/core/platform/macros.h"
44
45 // See core/kernels/function_ops.cc for related kernels.
46
47 namespace tensorflow {
48
49 // A few string constant used throughout this module.
50 static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
51 static constexpr const char* const kDeviceArgOp =
52 FunctionLibraryDefinition::kDeviceArgOp;
53 static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
54 static constexpr const char* const kDeviceRetOp =
55 FunctionLibraryDefinition::kDeviceRetOp;
56 static constexpr const char* const kGradientOp =
57 FunctionLibraryDefinition::kGradientOp;
58 static constexpr const char* const kNodeLabel = "Func";
59 static constexpr const char* const kFuncAttr =
60 FunctionLibraryDefinition::kFuncAttr;
61
62 // Represents the index-th output of a node.
63 struct Endpoint {
64 Node* node;
65 int index;
66
67 // Returns the string name represents this endpoint.
nametensorflow::Endpoint68 string name() const {
69 if (index == 0) {
70 return node->name();
71 } else {
72 return strings::StrCat(node->name(), ":", index);
73 }
74 }
75
dtypetensorflow::Endpoint76 DataType dtype() const { return node->output_type(index); }
77 };
78
79 struct EndpointHash {
operator ()tensorflow::EndpointHash80 uint64 operator()(const Endpoint& x) const {
81 return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
82 x.index);
83 }
84 };
85
86 struct EndpointEq {
operator ()tensorflow::EndpointEq87 bool operator()(const Endpoint& x, const Endpoint& y) const {
88 return (x.node == y.node) && (x.index == y.index);
89 }
90 };
91
92 // The following Add* routines are used to add a few graph nodes while
93 // functions are transformed.
AddNoOp(StringPiece name,Graph * g)94 static Node* AddNoOp(StringPiece name, Graph* g) {
95 NodeDef ndef;
96 ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
97 ndef.set_op("NoOp");
98 Status s;
99 Node* ret = g->AddNode(ndef, &s);
100 TF_CHECK_OK(s);
101 return ret;
102 }
103
AddIdentity(StringPiece name,Graph * g,Endpoint input)104 static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) {
105 DCHECK_LT(0, input.dtype());
106 NodeDef ndef;
107 ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
108 ndef.set_op("Identity");
109 // NOTE(skyewm): we explicitly set the device here to address a multi-GPU
110 // performance issue where this Identity would be placed alone on a GPU,
111 // causing unnecessary device traffic. See b/122483225 for details.
112 ndef.set_device(input.node->def().device());
113 ndef.add_input(input.name());
114 AddNodeAttr("T", BaseType(input.dtype()), &ndef);
115 Status s;
116 Node* ret = g->AddNode(ndef, &s);
117 TF_CHECK_OK(s);
118 g->AddEdge(input.node, input.index, ret, 0);
119 return ret;
120 }
121
AddArg(Graph * g,DataType dtype,int index)122 static Node* AddArg(Graph* g, DataType dtype, int index) {
123 DCHECK_LT(0, dtype);
124 DCHECK_LT(dtype, DT_FLOAT_REF);
125 NodeDef ndef;
126 ndef.set_name(g->NewName(kNodeLabel));
127 ndef.set_op(kArgOp);
128 AddNodeAttr("T", dtype, &ndef);
129 AddNodeAttr("index", index, &ndef);
130 Status s;
131 Node* ret = g->AddNode(ndef, &s);
132 TF_CHECK_OK(s);
133 return ret;
134 }
135
AddRet(Graph * g,Endpoint input,int index)136 static Node* AddRet(Graph* g, Endpoint input, int index) {
137 DCHECK_LT(0, input.dtype());
138 DCHECK_LT(input.dtype(), DT_FLOAT_REF);
139 NodeDef ndef;
140 ndef.set_name(g->NewName(kNodeLabel));
141 ndef.set_op(kRetOp);
142 ndef.add_input(input.name());
143 AddNodeAttr("T", input.dtype(), &ndef);
144 AddNodeAttr("index", index, &ndef);
145 Status s;
146 Node* ret = g->AddNode(ndef, &s);
147 TF_CHECK_OK(s);
148 g->AddEdge(input.node, input.index, ret, 0);
149 return ret;
150 }
151
152 // FunctionLibraryRuntime implementation that forwards all the function calls to
153 // the base runtime implementation, and only overrides overlay lib in calls to
154 // Instantiate (if caller doesn't provide its own overlay lib).
155 //
156 // When function library runtime (FunctionLibraryRuntimeImpl specifically)
157 // instantiates function into a Graph object, it also creates an Executor for
158 // it. That executor has a pointer to the function library runtime instance,
159 // that is used to instantiate all nested function calls.
160 //
161 // If the original function was instantiated using overlay lib, we must preserve
162 // that overlay lib in the executor's function library runtime.
163 //
164 // IMPORTANT: This runtime is intended for use only in executors created for
165 // functions instantiated into a graph in FunctionLibraryRuntimeImpl.
166 class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
167 public:
FunctionLibraryRuntimeOverlay(FunctionLibraryRuntime * base_flr,const FunctionLibraryDefinition * overlay_lib_def)168 FunctionLibraryRuntimeOverlay(
169 FunctionLibraryRuntime* base_flr,
170 const FunctionLibraryDefinition* overlay_lib_def)
171 : base_flr_(base_flr), overlay_lib_def_(overlay_lib_def) {}
172 ~FunctionLibraryRuntimeOverlay() override;
173
174 Status Instantiate(const string& function_name, AttrSlice attrs,
175 const InstantiateOptions& options,
176 Handle* handle) override;
177
178 Status ReleaseHandle(Handle handle) override;
179
180 const FunctionBody* GetFunctionBody(Handle h) override;
181
182 void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
183 std::vector<Tensor>* rets, DoneCallback done) override;
184
185 void Run(const Options& opts, Handle handle, CallFrameInterface* call_frame,
186 DoneCallback done) override;
187
188 Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override;
189
190 bool IsStateful(const string& function_name) override;
191
192 const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
193 const override;
194
195 Env* env() override;
196 Device* device() override;
197 std::function<void(std::function<void()>)>* runner() override;
198 const DeviceMgr* device_mgr() const override;
199
200 string DebugString(Handle handle) override;
201 int graph_def_version() override;
202
203 Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
204 std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
205 FunctionLibraryRuntime** out_flr) override;
206
207 private:
208 FunctionLibraryRuntime* base_flr_; // not owned
209 const FunctionLibraryDefinition* overlay_lib_def_; // not owned
210 };
211
212 FunctionLibraryRuntimeOverlay::~FunctionLibraryRuntimeOverlay() = default;
213
Instantiate(const string & function_name,AttrSlice attrs,const InstantiateOptions & options,Handle * handle)214 Status FunctionLibraryRuntimeOverlay::Instantiate(
215 const string& function_name, AttrSlice attrs,
216 const InstantiateOptions& options, Handle* handle) {
217 // We automatically add overlay lib to all instantiations, if the caller
218 // doesn't provide its own override.
219 if (!options.overlay_lib && overlay_lib_def_) {
220 InstantiateOptions options_copy = options;
221 options_copy.overlay_lib = overlay_lib_def_;
222 return base_flr_->Instantiate(function_name, attrs, options_copy, handle);
223 } else {
224 return base_flr_->Instantiate(function_name, attrs, options, handle);
225 }
226 }
227
ReleaseHandle(Handle handle)228 Status FunctionLibraryRuntimeOverlay::ReleaseHandle(Handle handle) {
229 return base_flr_->ReleaseHandle(handle);
230 }
231
GetFunctionBody(Handle h)232 const FunctionBody* FunctionLibraryRuntimeOverlay::GetFunctionBody(Handle h) {
233 return base_flr_->GetFunctionBody(h);
234 }
235
Run(const Options & opts,Handle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets,DoneCallback done)236 void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
237 gtl::ArraySlice<Tensor> args,
238 std::vector<Tensor>* rets,
239 DoneCallback done) {
240 base_flr_->Run(opts, handle, args, rets, std::move(done));
241 }
242
Run(const Options & opts,Handle handle,CallFrameInterface * call_frame,DoneCallback done)243 void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
244 CallFrameInterface* call_frame,
245 DoneCallback done) {
246 base_flr_->Run(opts, handle, call_frame, std::move(done));
247 }
248
CreateKernel(const NodeDef &,OpKernel **)249 Status FunctionLibraryRuntimeOverlay::CreateKernel(const NodeDef&, OpKernel**) {
250 // We don't have access base_lib_def_ in base function library runtime (aka
251 // FunctionLibraryRuntimeImpl), so to make sure we do not create kernel with
252 // wrong lib_def we just disable creation of new kernels through overlays.
253 //
254 // When we call Instantiate from the base runtime with overlay lib override,
255 // the base runtime implementation is responsible for correctly passing custom
256 // overlay lib to all kernel constructions.
257 return errors::Internal(
258 "Overlay function library runtime doesn't support kernel creation.");
259 }
260
IsStateful(const string & function_name)261 bool FunctionLibraryRuntimeOverlay::IsStateful(const string& function_name) {
262 // Important: we do not forward lookup to the base FLR.
263 const OpDef* op_def;
264 const Status s = overlay_lib_def_->LookUpOpDef(function_name, &op_def);
265 return s.ok() && op_def->is_stateful();
266 }
267
env()268 Env* FunctionLibraryRuntimeOverlay::env() { return base_flr_->env(); }
269
device()270 Device* FunctionLibraryRuntimeOverlay::device() { return base_flr_->device(); }
271
272 std::function<void(std::function<void()>)>*
runner()273 FunctionLibraryRuntimeOverlay::runner() {
274 return base_flr_->runner();
275 }
276
device_mgr() const277 const DeviceMgr* FunctionLibraryRuntimeOverlay::device_mgr() const {
278 return base_flr_->device_mgr();
279 }
280
281 const FunctionLibraryDefinition*
GetFunctionLibraryDefinition() const282 FunctionLibraryRuntimeOverlay::GetFunctionLibraryDefinition() const {
283 return overlay_lib_def_ ? overlay_lib_def_
284 : base_flr_->GetFunctionLibraryDefinition();
285 }
286
DebugString(Handle handle)287 string FunctionLibraryRuntimeOverlay::DebugString(Handle handle) {
288 return base_flr_->DebugString(handle);
289 }
290
graph_def_version()291 int FunctionLibraryRuntimeOverlay::graph_def_version() {
292 return base_flr_->graph_def_version();
293 }
294
Clone(std::unique_ptr<FunctionLibraryDefinition> * out_lib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> * out_pflr,FunctionLibraryRuntime ** out_flr)295 Status FunctionLibraryRuntimeOverlay::Clone(
296 std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
297 std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
298 FunctionLibraryRuntime** out_flr) {
299 // NOTE(ezhulenev): Cloned FunctionLibraryRuntime will be missing overlay lib,
300 // but that's ok because we anyway do not copy/clone instantiated items from
301 // the base FLR.
302 return base_flr_->Clone(out_lib_def, out_pflr, out_flr);
303 }
304
305 class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
306 public:
307 FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env, Device* device,
308 int graph_def_version,
309 const FunctionLibraryDefinition* lib_def,
310 thread::ThreadPool* default_thread_pool,
311 const OptimizerOptions& optimizer_options,
312 CustomKernelCreator custom_kernel_creator,
313 ProcessFunctionLibraryRuntime* parent);
314
315 ~FunctionLibraryRuntimeImpl() override;
316
317 Status Instantiate(const string& function_name, AttrSlice attrs,
318 const InstantiateOptions& options,
319 Handle* handle) override;
320
321 Status ReleaseHandle(Handle handle) override;
322
323 const FunctionBody* GetFunctionBody(Handle handle) override;
324
325 Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override;
326
327 void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
328 std::vector<Tensor>* rets, DoneCallback done) override;
329 // NOTE(mrry): This overload is currently only implemented for local function
330 // execution.
331 // TODO(b/70346412): Implement support for remote function execution when
332 // passing a call frame.
333 void Run(const Options& opts, Handle handle, CallFrameInterface* frame,
334 DoneCallback done) override;
335
336 bool IsStateful(const string& function) override;
337
GetFunctionLibraryDefinition() const338 const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
339 const override {
340 return base_lib_def_;
341 }
342
device()343 Device* device() override { return device_; }
344
runner()345 std::function<void(std::function<void()>)>* runner() override {
346 return &default_runner_;
347 }
348
device_mgr() const349 const DeviceMgr* device_mgr() const override { return device_mgr_; }
env()350 Env* env() override { return env_; }
graph_def_version()351 int graph_def_version() override { return graph_def_version_; }
352
353 string DebugString(Handle h) override;
354
355 Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
356 std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
357 FunctionLibraryRuntime** out_flr) override;
358
359 private:
360 typedef FunctionLibraryRuntimeImpl ME;
361
362 const DeviceMgr* const device_mgr_;
363 Device* const device_;
364 Env* const env_;
365 const int graph_def_version_;
366 const FunctionLibraryDefinition* const base_lib_def_;
367 GraphOptimizer optimizer_;
368 const CustomKernelCreator custom_kernel_creator_;
369 Executor::Args::Runner default_runner_;
370 const string device_name_;
371
372 std::function<Status(const string&, const OpDef**)> get_func_sig_;
373 std::function<Status(const NodeDef&, OpKernel**)> create_kernel_;
374
375 mutable mutex mu_;
376
377 int next_handle_ GUARDED_BY(mu_);
378
379 // The instantiated and transformed function is encoded as a Graph
380 // object, and an executor is created for the graph.
381 struct Item {
382 uint64 instantiation_counter = 0;
383 const Graph* graph = nullptr; // Owned by exec.
384 const FunctionLibraryDefinition* overlay_lib = nullptr; // Not owned.
385 FunctionBody* func_graph = nullptr;
386 Executor* exec = nullptr;
387 FunctionLibraryRuntimeOverlay* overlay_flr = nullptr;
388 string executor_type;
389
~Itemtensorflow::FunctionLibraryRuntimeImpl::Item390 ~Item() {
391 delete this->func_graph;
392 delete this->exec;
393 delete this->overlay_flr;
394 }
395 };
396 std::unordered_map<Handle, std::unique_ptr<Item>> items_ GUARDED_BY(mu_);
397
398 ProcessFunctionLibraryRuntime* parent_ = nullptr; // not owned.
399
400 Status CreateKernel(const NodeDef& ndef,
401 const FunctionLibraryDefinition* lib_def,
402 OpKernel** kernel);
403 Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
404 const FunctionLibraryDefinition* lib_def,
405 FunctionBody** fbody);
406 Status CreateItem(Item** item);
407 Status GetOrCreateItem(LocalHandle local_handle, Item** item);
408 Status InstantiateSymbolicGradient(const NameAttrList& func,
409 const FunctionLibraryDefinition* lib_def,
410 FunctionBody** g_body);
411 bool IsLocalTarget(const InstantiateOptions& options);
412 AttrValueMap FixAttrs(const AttrSlice& attrs);
413 void RunRemote(const Options& opts, Handle handle,
414 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
415 Item* item, DoneCallback done);
416
417 void ExecutorArgsFromOptions(const FunctionLibraryRuntime::Options& run_opts,
418 CallFrameInterface* frame,
419 Executor::Args* exec_args);
420
421 TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
422 };
423
FunctionLibraryRuntimeImpl(const DeviceMgr * dmgr,Env * env,Device * device,int graph_def_version,const FunctionLibraryDefinition * lib_def,thread::ThreadPool * default_thread_pool,const OptimizerOptions & optimizer_options,CustomKernelCreator custom_kernel_creator,ProcessFunctionLibraryRuntime * parent)424 FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
425 const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version,
426 const FunctionLibraryDefinition* lib_def,
427 thread::ThreadPool* default_thread_pool,
428 const OptimizerOptions& optimizer_options,
429 CustomKernelCreator custom_kernel_creator,
430 ProcessFunctionLibraryRuntime* parent)
431 : device_mgr_(dmgr),
432 device_(device),
433 env_(env),
434 graph_def_version_(graph_def_version),
435 base_lib_def_(lib_def),
436 optimizer_(optimizer_options),
437 custom_kernel_creator_(std::move(custom_kernel_creator)),
438 default_runner_(nullptr),
439 device_name_(device_ == nullptr
440 ? ProcessFunctionLibraryRuntime::kDefaultFLRDevice
441 : device_->name()),
442 next_handle_(0),
443 parent_(parent) {
444 get_func_sig_ = [this](const string& op, const OpDef** sig) {
445 return base_lib_def_->LookUpOpDef(op, sig);
446 };
447 create_kernel_ = [this](const NodeDef& ndef, OpKernel** kernel) {
448 return CreateKernel(ndef, kernel);
449 };
450 thread::ThreadPool* pool = nullptr;
451 if (device_ != nullptr) {
452 pool = device_->tensorflow_device_thread_pool();
453 }
454 if (pool == nullptr) {
455 pool = default_thread_pool;
456 }
457 if (pool != nullptr) {
458 default_runner_ = [pool](Executor::Args::Closure c) {
459 pool->Schedule(std::move(c));
460 };
461 }
462 }
463
~FunctionLibraryRuntimeImpl()464 FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {}
465
466 // An asynchronous op kernel which executes an instantiated function
467 // defined in a library.
468 class CallOp : public AsyncOpKernel {
469 public:
CallOp(FunctionLibraryRuntime::Handle handle,OpKernelConstruction * ctx)470 CallOp(FunctionLibraryRuntime::Handle handle, OpKernelConstruction* ctx)
471 : AsyncOpKernel(ctx), handle_(handle) {}
472
~CallOp()473 ~CallOp() override {
474 // TODO(iga): Release the cached handle_
475 }
476
ComputeAsync(OpKernelContext * ctx,DoneCallback done)477 void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
478 FunctionLibraryRuntime* lib = ctx->function_library();
479 OP_REQUIRES_ASYNC(ctx, lib != nullptr,
480 errors::Internal("No function library is provided."),
481 done);
482 FunctionLibraryRuntime::Options opts;
483 opts.step_id = ctx->step_id();
484 opts.rendezvous = ctx->rendezvous();
485 opts.cancellation_manager = ctx->cancellation_manager();
486 opts.step_container = ctx->step_container();
487 opts.stats_collector = ctx->stats_collector();
488 opts.runner = ctx->runner();
489 opts.collective_executor = ctx->collective_executor();
490 std::vector<Tensor> args;
491 args.reserve(ctx->num_inputs());
492 for (int i = 0; i < ctx->num_inputs(); ++i) {
493 args.push_back(ctx->input(i));
494 }
495 std::vector<Tensor>* rets = new std::vector<Tensor>;
496 lib->Run(opts, handle_, args, rets,
497 [ctx, done, rets](const Status& status) {
498 if (!status.ok()) {
499 ctx->SetStatus(status);
500 } else {
501 const int ret_size = static_cast<int>(rets->size());
502 CHECK_EQ(ret_size, ctx->num_outputs());
503 for (int i = 0; i < ret_size; ++i) {
504 ctx->set_output(i, (*rets)[i]);
505 }
506 }
507 delete rets;
508 done();
509 });
510 }
511
512 private:
513 FunctionLibraryRuntime::Handle handle_;
514
515 TF_DISALLOW_COPY_AND_ASSIGN(CallOp);
516 };
517
GetFunctionBody(Handle h)518 const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
519 LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h);
520 if (local_handle == kInvalidLocalHandle) {
521 LOG(ERROR) << "Could not find Handle: " << h
522 << " on device: " << device_name_;
523 return nullptr;
524 }
525
526 tf_shared_lock l(mu_);
527 auto iter = items_.find(local_handle);
528 CHECK(iter != items_.end());
529 return iter->second->func_graph;
530 }
531
CreateKernel(const NodeDef & ndef,OpKernel ** kernel)532 Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
533 OpKernel** kernel) {
534 return CreateKernel(ndef, base_lib_def_, kernel);
535 }
536
CreateKernel(const NodeDef & ndef,const FunctionLibraryDefinition * lib_def,OpKernel ** kernel)537 Status FunctionLibraryRuntimeImpl::CreateKernel(
538 const NodeDef& ndef, const FunctionLibraryDefinition* lib_def,
539 OpKernel** kernel) {
540 // If a custom kernel creator is given, try that.
541 Status s;
542 if (custom_kernel_creator_) {
543 std::unique_ptr<OpKernel> ret;
544 s = custom_kernel_creator_(this, ndef, &ret);
545 if (s.ok()) {
546 *kernel = ret.release();
547 return s;
548 } else {
549 VLOG(2) << "Custom creator error: " << s;
550 // Falls through.
551 s = Status::OK();
552 }
553 }
554
555 if (lib_def->Find(ndef.op()) == nullptr) {
556 // A primitive operation. Creates the registered kernel.
557 return CreateNonCachedKernel(device_, this, ndef, graph_def_version_,
558 kernel);
559 }
560
561 // Try to instantiate this function for the func/attr. Maybe it's
562 // cached already.
563 InstantiateOptions options;
564 if (lib_def != base_lib_def_) {
565 options.overlay_lib = lib_def;
566 }
567 Handle handle;
568 TF_RETURN_IF_ERROR(
569 Instantiate(ndef.op(), AttrSlice(&ndef.attr()), options, &handle));
570
571 const FunctionBody* fbody = GetFunctionBody(handle);
572 CHECK_NOTNULL(fbody);
573
574 // TODO(zhifengc): For now, we assume int32 and resources are always on host
575 // memory and other types are always on device memory. We should do type
576 // inference over function body to derive the correct input/output memory
577 // types.
578 MemoryTypeVector input_memory_types;
579 for (const auto& t : fbody->arg_types) {
580 input_memory_types.push_back(MTypeFromDType(t));
581 }
582 MemoryTypeVector output_memory_types;
583 for (const auto& t : fbody->ret_types) {
584 output_memory_types.push_back(MTypeFromDType(t));
585 }
586
587 // Constructs a CallOp kernel for running the instantiated function.
588 auto device_type = DeviceType(device_->attributes().device_type());
589 OpKernelConstruction construction(
590 device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef,
591 &fbody->fdef.signature(), this, fbody->arg_types, input_memory_types,
592 fbody->ret_types, output_memory_types, graph_def_version_, &s);
593 if (s.ok()) {
594 *kernel = new CallOp(handle, &construction);
595 }
596 return s;
597 }
598
FunctionDefToBody(const FunctionDef & fdef,AttrSlice attrs,const FunctionLibraryDefinition * lib_def,FunctionBody ** fbody)599 Status FunctionLibraryRuntimeImpl::FunctionDefToBody(
600 const FunctionDef& fdef, AttrSlice attrs,
601 const FunctionLibraryDefinition* lib_def, FunctionBody** fbody) {
602 if (lib_def == base_lib_def_) {
603 return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig_, fbody);
604 } else {
605 auto get_func_sig = [lib_def](const string& op, const OpDef** sig) {
606 return lib_def->LookUpOpDef(op, sig);
607 };
608 return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody);
609 }
610 }
611
InstantiateSymbolicGradient(const NameAttrList & func,const FunctionLibraryDefinition * lib_def,FunctionBody ** g_body)612 Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
613 const NameAttrList& func, const FunctionLibraryDefinition* lib_def,
614 FunctionBody** g_body) {
615 const FunctionDef* fdef = lib_def->Find(func.name());
616 if (fdef == nullptr) {
617 // f is a primitive op.
618 gradient::Creator creator;
619 TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
620 if (creator == nullptr) {
621 return errors::InvalidArgument("No gradient is defined for ",
622 func.name());
623 }
624 FunctionDef grad_fdef;
625 // TODO(josh11b): Should filter out the attrs from func that aren't used
626 // by the gradient function.
627 TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
628 TF_RETURN_IF_ERROR(
629 FunctionDefToBody(grad_fdef, AttrSlice(&func.attr()), lib_def, g_body));
630 } else {
631 // f is a user-defined function.
632 InstantiateOptions options;
633 if (lib_def != base_lib_def_) {
634 options.overlay_lib = lib_def;
635 }
636 Handle f_handle;
637 TF_RETURN_IF_ERROR(
638 Instantiate(func.name(), AttrSlice(&func.attr()), options, &f_handle));
639 const FunctionBody* f_body = GetFunctionBody(f_handle);
640 CHECK_NOTNULL(f_body);
641 *g_body = SymbolicGradient(*f_body);
642 }
643 return Status::OK();
644 }
645
IsLocalTarget(const InstantiateOptions & options)646 bool FunctionLibraryRuntimeImpl::IsLocalTarget(
647 const InstantiateOptions& options) {
648 if (device_ == nullptr) return true;
649 if (options.target.empty()) return true;
650 if (options.is_multi_device_function) return false;
651 Device* target_device;
652 if (!device_mgr_->LookupDevice(options.target, &target_device).ok()) {
653 VLOG(1) << "Not instantiating function in FLR because failed to "
654 << "find device " << options.target << " in device manager";
655 return false;
656 }
657 if (target_device != device_) {
658 VLOG(1) << "Not instantiating function in FLR because target device "
659 << options.target
660 << " is different from FLR's device: " << device_->DebugString();
661 return false;
662 }
663 return true;
664 }
665
Instantiate(const string & function_name,AttrSlice attrs,const InstantiateOptions & options,Handle * handle)666 Status FunctionLibraryRuntimeImpl::Instantiate(
667 const string& function_name, AttrSlice attrs,
668 const InstantiateOptions& options, Handle* handle) {
669 if (!IsLocalTarget(options)) {
670 return parent_->Instantiate(function_name, attrs, options, handle);
671 }
672
673 // Since this is a local target, ensure that the local `device_name_` appears
674 // in the canonical key.
675 InstantiateOptions options_copy(options);
676 options_copy.target = device_name_;
677 const string key = Canonicalize(function_name, attrs, options_copy);
678
679 {
680 mutex_lock l(mu_);
681 *handle = parent_->GetHandle(key);
682 if (*handle != kInvalidHandle) {
683 FunctionLibraryRuntime::LocalHandle handle_on_device =
684 parent_->GetHandleOnDevice(device_name_, *handle);
685 if (handle_on_device == kInvalidLocalHandle) {
686 return errors::Internal("LocalHandle not found for handle ", *handle,
687 ".");
688 }
689 auto item_handle = items_.find(handle_on_device);
690 if (item_handle == items_.end()) {
691 return errors::Internal("LocalHandle ", handle_on_device,
692 " for handle ", *handle,
693 " not found in items.");
694 }
695 ++item_handle->second->instantiation_counter;
696 return Status::OK();
697 }
698 }
699
700 const FunctionLibraryDefinition* lib_def =
701 options.overlay_lib ? options.overlay_lib : base_lib_def_;
702 FunctionBody* fbody = nullptr;
703 if (function_name == kGradientOp) {
704 const AttrValue* f = attrs.Find(kFuncAttr);
705 if (f == nullptr) {
706 return errors::InvalidArgument("SymbolicGradient is missing attr: f");
707 }
708 const auto& func = f->func();
709 if (func.name() == kGradientOp) {
710 return errors::InvalidArgument("Can't take gradient of SymbolicGradient");
711 }
712 const string grad = lib_def->FindGradient(func.name());
713 if (!grad.empty()) {
714 return Instantiate(grad, AttrSlice(&func.attr()), options, handle);
715 }
716 TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, lib_def, &fbody));
717 } else {
718 const FunctionDef* fdef = lib_def->Find(function_name);
719 if (fdef == nullptr) {
720 return errors::NotFound("Function ", function_name, " is not defined.");
721 }
722 TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, lib_def, &fbody));
723 }
724
725 LocalHandle local_handle;
726 {
727 mutex_lock l(mu_);
728 *handle = parent_->GetHandle(key);
729 if (*handle != kInvalidHandle) {
730 delete fbody;
731 local_handle = parent_->GetHandleOnDevice(device_name_, *handle);
732 ++items_[local_handle]->instantiation_counter;
733 } else {
734 *handle = parent_->AddHandle(key, device_name_, next_handle_);
735 Item* item = new Item;
736 item->func_graph = fbody;
737 item->overlay_lib = options.overlay_lib;
738 item->instantiation_counter = 1;
739 item->executor_type = ExecutorType(options, attrs);
740 if (options.overlay_lib) {
741 item->overlay_flr =
742 new FunctionLibraryRuntimeOverlay(this, options.overlay_lib);
743 }
744 local_handle = next_handle_++;
745 items_.emplace(local_handle, std::unique_ptr<Item>(item));
746 }
747 }
748
749 if (options.create_kernels_eagerly) {
750 Item* item;
751 TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, &item));
752 }
753
754 return Status::OK();
755 }
756
ReleaseHandle(Handle handle)757 Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
758 LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
759 if (h == kInvalidLocalHandle) {
760 return parent_->ReleaseHandle(handle);
761 }
762
763 std::unique_ptr<Item> item_to_delete;
764 Status parent_status;
765 {
766 mutex_lock l(mu_);
767 auto it = items_.find(h);
768 if (it == items_.end()) {
769 return errors::Internal(
770 "Inconsistent FunctionLibraryRuntime. Expected to find an item for "
771 "handle ",
772 h, " but found none");
773 }
774 std::unique_ptr<Item>& item = it->second;
775 --item->instantiation_counter;
776 if (item->instantiation_counter == 0) {
777 // We don't simply erase h's item because that would trigger
778 // item destruction while holding mu_. Item destruction can
779 // trigger graph destruction. If the graph contains kernels like
780 // CallOp or PartitionCallOp, their destructors will release cached
781 // function handles, resulting in deadlock here.
782 item_to_delete = std::move(item);
783 items_.erase(h);
784 parent_status = parent_->RemoveHandle(handle);
785 }
786 }
787 return parent_status;
788 }
789
DumpGraph(StringPiece label,const Graph * g)790 void DumpGraph(StringPiece label, const Graph* g) {
791 // TODO(zhifengc): Change Graph to record #nodes.
792 VLOG(1) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges "
793 << g->num_edges();
794 if (VLOG_IS_ON(2)) {
795 for (const auto& line : str_util::Split(DebugString(g), '\n')) {
796 VLOG(2) << "|| " << line;
797 }
798 }
799 }
800
OptimizeGraph(FunctionLibraryRuntime * lib,std::unique_ptr<Graph> * g,const GraphOptimizer::Options & graph_optimizer_options)801 void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g,
802 const GraphOptimizer::Options& graph_optimizer_options) {
803 OptimizerOptions opts;
804 opts.set_do_common_subexpression_elimination(true);
805 opts.set_do_function_inlining(true);
806 opts.set_do_constant_folding(true);
807 GraphOptimizer optimizer(opts);
808 optimizer.Optimize(lib, lib->env(), lib->device(), g,
809 graph_optimizer_options);
810 }
811
OptimizeGraph(FunctionLibraryRuntime * lib,std::unique_ptr<Graph> * g)812 void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g) {
813 OptimizeGraph(lib, g, GraphOptimizer::Options());
814 }
815
816 namespace {
817 // Removes all stateless nodes that do not contribute to a return
818 // value from the function body. Unlike `RemoveDeadNodes()`, which is
819 // triggered by `OptimizerOptions.do_function_inlining`, this pass
820 // ignores the SINK node, from which (by definition) all nodes are
821 // reverse reachable, and preserves all nodes that are reachable from
822 // control output nodes.
823 //
824 // TODO(ezhulenev, skyewm): Function body should not have special treatment of
825 // stateful ops, graph should encode nodes that must execute with `control_ret`
826 // and `control_output`.
PruneFunctionBody(const FunctionDef & fdef,Graph * g)827 void PruneFunctionBody(const FunctionDef& fdef, Graph* g) {
828 VLOG(2) << "Pruning function body: function_name=" << fdef.signature().name();
829
830 // `control_ret` nodes must be always executed.
831 std::unordered_set<StringPiece, StringPieceHasher> control_ret_nodes;
832 for (const auto& control_ret : fdef.control_ret()) {
833 control_ret_nodes.insert(control_ret.second);
834 }
835
836 std::unordered_set<const Node*> nodes;
837 for (auto n : g->nodes()) {
838 // NOTE(mrry): "_Retval" nodes are stateful, and so will be added
839 // to the seed set of `nodes`. "_Arg" nodes are also stateful, but we
840 // specifically exclude them as seeds, to avoid unconditionally executing
841 // unused argument nodes (e.g. in a function like `lambda x, y: y`).
842 // TODO(mrry): Investigate whether the `n->IsControlFlow()` test is
843 // still needed. It would be preferable to prune entire loops and/or
844 // conditionals if they are not used in the graph.
845 if (n->IsControlFlow() ||
846 (n->op_def().is_stateful() && n->type_string() != kArgOp) ||
847 (control_ret_nodes.find(n->name()) != control_ret_nodes.end())) {
848 nodes.insert(n);
849 }
850 }
851 bool changed = PruneForReverseReachability(g, std::move(nodes));
852 if (changed) {
853 FixupSourceAndSinkEdges(g);
854 }
855 }
856 } // namespace
857
CreateItem(Item ** item)858 Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) {
859 const FunctionBody* fbody;
860 const FunctionLibraryDefinition* lib_def;
861 string executor_type;
862 {
863 tf_shared_lock l(mu_);
864 fbody = (*item)->func_graph;
865 lib_def = (*item)->overlay_lib;
866 executor_type = (*item)->executor_type;
867 }
868 if (!lib_def) {
869 lib_def = base_lib_def_;
870 }
871 std::unique_ptr<Graph> g(new Graph(lib_def));
872 CopyGraph(*fbody->graph, g.get());
873
874 PruneFunctionBody(fbody->fdef, g.get());
875 optimizer_.Optimize(this, env(), device(), &g, /*shape_map=*/nullptr);
876 TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device()->device_type()),
877 device()->name(), g.get()));
878
879 // Creates an executor based on the g. This must be done without
880 // holding mu_ because create_kernel_ calls back into the library.
881 LocalExecutorParams params;
882 params.device = device_;
883 params.function_library =
884 (*item)->overlay_flr
885 ? static_cast<FunctionLibraryRuntime*>((*item)->overlay_flr)
886 : static_cast<FunctionLibraryRuntime*>(this);
887 if (lib_def == base_lib_def_) {
888 params.create_kernel = create_kernel_;
889 } else {
890 params.create_kernel = [this, lib_def](const NodeDef& ndef,
891 OpKernel** kernel) {
892 return CreateKernel(ndef, lib_def, kernel);
893 };
894 }
895 params.delete_kernel = [](OpKernel* kernel) {
896 DeleteNonCachedKernel(kernel);
897 };
898 Graph* graph = g.get();
899 std::unique_ptr<Executor> exec;
900 TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, std::move(g), &exec));
901 {
902 // Guard item since it is already inserted in items_.
903 mutex_lock l(mu_);
904 if ((*item)->exec == nullptr) {
905 (*item)->graph = graph;
906 (*item)->exec = exec.release();
907 }
908 }
909 return Status::OK();
910 }
911
GetOrCreateItem(LocalHandle local_handle,Item ** item)912 Status FunctionLibraryRuntimeImpl::GetOrCreateItem(LocalHandle local_handle,
913 Item** item) {
914 {
915 tf_shared_lock l(mu_);
916 auto iter = items_.find(local_handle);
917 if (iter == items_.end()) {
918 return errors::Internal("Local function handle ", local_handle,
919 " is not valid. Likely an internal error.");
920 }
921 *item = iter->second.get();
922 if ((*item)->exec != nullptr) {
923 return Status::OK();
924 }
925 }
926 // NOTE: We need to call CreateItem out of mu_ because creating an
927 // executor needs to call CreateKernel.
928 return CreateItem(item);
929 }
930
ExecutorArgsFromOptions(const FunctionLibraryRuntime::Options & run_opts,CallFrameInterface * frame,Executor::Args * exec_args)931 void FunctionLibraryRuntimeImpl::ExecutorArgsFromOptions(
932 const FunctionLibraryRuntime::Options& run_opts, CallFrameInterface* frame,
933 Executor::Args* exec_args) {
934 // Inherit the step_id from the caller.
935 exec_args->step_id = run_opts.step_id;
936 exec_args->rendezvous = run_opts.rendezvous;
937 exec_args->stats_collector = run_opts.stats_collector;
938 exec_args->cancellation_manager = run_opts.cancellation_manager;
939 exec_args->step_container = run_opts.step_container;
940 if (run_opts.runner) {
941 exec_args->runner = *run_opts.runner;
942 } else {
943 exec_args->runner = default_runner_;
944 }
945 exec_args->collective_executor = run_opts.collective_executor;
946 exec_args->call_frame = frame;
947 }
948
RunRemote(const Options & opts,Handle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets,Item * item,DoneCallback done)949 void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
950 gtl::ArraySlice<Tensor> args,
951 std::vector<Tensor>* rets,
952 Item* item, DoneCallback done) {
953 string target_device = parent_->GetDeviceName(handle);
954 string source_device = opts.source_device;
955 Rendezvous* rendezvous = opts.rendezvous;
956 DeviceContext* device_context;
957 Status s = parent_->GetDeviceContext(target_device, &device_context);
958 if (!s.ok()) {
959 done(s);
960 return;
961 }
962 int64 src_incarnation, target_incarnation;
963 s = parent_->GetDeviceIncarnation(source_device, &src_incarnation);
964 s.Update(parent_->GetDeviceIncarnation(target_device, &target_incarnation));
965 if (!s.ok()) {
966 done(s);
967 return;
968 }
969
970 const FunctionBody* fbody = GetFunctionBody(handle);
971 FunctionCallFrame* frame =
972 new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
973 Executor::Args* exec_args = new Executor::Args;
974 ExecutorArgsFromOptions(opts, frame, exec_args);
975
976 std::vector<AllocatorAttributes> args_alloc_attrs, rets_alloc_attrs;
977 args_alloc_attrs.reserve(fbody->arg_types.size());
978 rets_alloc_attrs.reserve(fbody->ret_types.size());
979 // Note: Functions assume that int32's are always on host memory.
980 for (const auto& arg_type : fbody->arg_types) {
981 AllocatorAttributes arg_alloc_attrs;
982 if (MTypeFromDType(arg_type) == HOST_MEMORY) {
983 arg_alloc_attrs.set_on_host(true);
984 }
985 args_alloc_attrs.push_back(arg_alloc_attrs);
986 }
987 for (const auto& ret_type : fbody->ret_types) {
988 AllocatorAttributes ret_alloc_attrs;
989 if (MTypeFromDType(ret_type) == HOST_MEMORY) {
990 ret_alloc_attrs.set_on_host(true);
991 }
992 rets_alloc_attrs.push_back(ret_alloc_attrs);
993 }
994
995 bool allow_dead_tensors = opts.allow_dead_tensors;
996
997 // The ProcFLR sends the arguments to the function from the source_device to
998 // the target_device. So here we receive those arguments. Similarly, when the
999 // computation is done and stored in *rets, we send the return values back
1000 // to the source_device (caller) so that the ProcFLR can receive them later.
1001 std::vector<Tensor>* remote_args = new std::vector<Tensor>;
1002 ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
1003 source_device, target_device, "arg_", src_incarnation, args.size(),
1004 device_context, args_alloc_attrs, rendezvous, remote_args,
1005 [frame, remote_args, item, source_device, target_device,
1006 target_incarnation, rendezvous, device_context, rets, done, exec_args,
1007 rets_alloc_attrs, allow_dead_tensors](const Status& status) {
1008 Status s = status;
1009 if (s.ok()) {
1010 s = frame->SetArgs(*remote_args);
1011 }
1012 if (!s.ok()) {
1013 delete frame;
1014 delete remote_args;
1015 delete exec_args;
1016 done(s);
1017 return;
1018 }
1019 item->exec->RunAsync(
1020 *exec_args,
1021 [frame, rets, done, source_device, target_device,
1022 target_incarnation, rendezvous, device_context, remote_args,
1023 rets_alloc_attrs, allow_dead_tensors](const Status& status) {
1024 Status s = status;
1025 if (s.ok()) {
1026 s = frame->ConsumeRetvals(rets, allow_dead_tensors);
1027 }
1028 delete frame;
1029 if (!s.ok()) {
1030 delete remote_args;
1031 done(s);
1032 return;
1033 }
1034 s = ProcessFunctionLibraryRuntime::SendTensors(
1035 target_device, source_device, "ret_", target_incarnation,
1036 *rets, device_context, rets_alloc_attrs, rendezvous);
1037 delete remote_args;
1038 done(s);
1039 });
1040 delete exec_args;
1041 });
1042 }
1043
Run(const Options & opts,Handle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets,DoneCallback done)1044 void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
1045 gtl::ArraySlice<Tensor> args,
1046 std::vector<Tensor>* rets,
1047 DoneCallback done) {
1048 if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
1049 done(errors::Cancelled(""));
1050 return;
1051 }
1052 Options run_opts = opts;
1053 if (opts.create_rendezvous) {
1054 Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
1055 run_opts.rendezvous = rendezvous;
1056 run_opts.create_rendezvous = false;
1057 done = [done, rendezvous](const Status& status) {
1058 rendezvous->Unref();
1059 done(status);
1060 };
1061 }
1062
1063 LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
1064 if (local_handle == kInvalidLocalHandle) {
1065 parent_->Run(run_opts, handle, args, rets, done);
1066 return;
1067 }
1068
1069 if (run_opts.runner == nullptr) {
1070 run_opts.runner = &default_runner_;
1071 }
1072 DCHECK(run_opts.runner != nullptr);
1073
1074 Item* item = nullptr;
1075 Status s = GetOrCreateItem(local_handle, &item);
1076 if (!s.ok()) {
1077 done(s);
1078 return;
1079 }
1080
1081 if (run_opts.remote_execution) {
1082 // NOTE(mrry): `RunRemote()` will set `exec_args->call_frame` for us.
1083 RunRemote(run_opts, handle, args, rets, item, done);
1084 return;
1085 }
1086
1087 const FunctionBody* fbody = GetFunctionBody(handle);
1088 FunctionCallFrame* frame =
1089 new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
1090 s = frame->SetArgs(args);
1091 if (!s.ok()) {
1092 delete frame;
1093 done(s);
1094 return;
1095 }
1096
1097 Executor::Args exec_args;
1098 ExecutorArgsFromOptions(run_opts, frame, &exec_args);
1099
1100 bool allow_dead_tensors = run_opts.allow_dead_tensors;
1101 item->exec->RunAsync(
1102 // Executor args
1103 exec_args,
1104 // Done callback.
1105 [frame, rets, done, allow_dead_tensors](const Status& status) {
1106 Status s = status;
1107 if (s.ok()) {
1108 s = frame->ConsumeRetvals(rets, allow_dead_tensors);
1109 }
1110 delete frame;
1111 done(s);
1112 });
1113 }
1114
Run(const Options & opts,Handle handle,CallFrameInterface * frame,DoneCallback done)1115 void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
1116 CallFrameInterface* frame,
1117 DoneCallback done) {
1118 if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
1119 done(errors::Cancelled(""));
1120 return;
1121 }
1122 LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
1123 if (local_handle == kInvalidLocalHandle || opts.remote_execution) {
1124 done(errors::Unimplemented("Remote calling with CallFrameInterface"));
1125 return;
1126 }
1127
1128 Options run_opts = opts;
1129 if (opts.create_rendezvous) {
1130 Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
1131 run_opts.rendezvous = rendezvous;
1132 run_opts.create_rendezvous = false;
1133 done = std::bind(
1134 [rendezvous](DoneCallback done,
1135 // Begin unbound arguments.
1136 const Status& status) {
1137 rendezvous->Unref();
1138 done(status);
1139 },
1140 std::move(done), std::placeholders::_1);
1141 }
1142
1143 Item* item = nullptr;
1144 Status s = GetOrCreateItem(local_handle, &item);
1145 if (!s.ok()) {
1146 done(s);
1147 return;
1148 }
1149 if (run_opts.runner == nullptr) {
1150 run_opts.runner = &default_runner_;
1151 }
1152 DCHECK(run_opts.runner != nullptr);
1153
1154 Executor::Args exec_args;
1155 ExecutorArgsFromOptions(run_opts, frame, &exec_args);
1156 item->exec->RunAsync(exec_args, std::move(done));
1157 }
1158
IsStateful(const string & func)1159 bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) {
1160 const OpDef* op_def;
1161 const Status s = base_lib_def_->LookUpOpDef(func, &op_def);
1162 return s.ok() && op_def->is_stateful();
1163 }
1164
DebugString(Handle handle)1165 string FunctionLibraryRuntimeImpl::DebugString(Handle handle) {
1166 Item* item = nullptr;
1167 LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
1168 Status s = GetOrCreateItem(local_handle, &item);
1169 if (s.ok()) {
1170 return tensorflow::DebugString(item->graph);
1171 } else {
1172 return s.ToString();
1173 }
1174 }
1175
Clone(std::unique_ptr<FunctionLibraryDefinition> * out_lib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> * out_pflr,FunctionLibraryRuntime ** out_flr)1176 Status FunctionLibraryRuntimeImpl::Clone(
1177 std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
1178 std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
1179 FunctionLibraryRuntime** out_flr) {
1180 TF_RETURN_IF_ERROR(
1181 parent_->Clone(env_, graph_def_version_, optimizer_.options(),
1182 custom_kernel_creator_, out_lib_def, out_pflr));
1183 *out_flr = (*out_pflr)->GetFLR(device_->name());
1184 if (out_flr != nullptr) {
1185 return Status::OK();
1186 } else {
1187 return errors::Internal("Cloning FunctionLibraryRuntime failed.");
1188 }
1189 }
1190
1191 namespace {
1192
1193 struct CustomCreatorSingleton {
1194 mutex mu;
1195 CustomKernelCreator custom_creator = nullptr;
1196
Settensorflow::__anon924a102b0e11::CustomCreatorSingleton1197 void Set(CustomKernelCreator cb) {
1198 mutex_lock l(mu);
1199 custom_creator = std::move(cb);
1200 }
1201
Gettensorflow::__anon924a102b0e11::CustomCreatorSingleton1202 CustomKernelCreator Get() {
1203 mutex_lock l(mu);
1204 return custom_creator;
1205 }
1206 };
1207
GetCustomCreatorSingleton()1208 CustomCreatorSingleton* GetCustomCreatorSingleton() {
1209 static CustomCreatorSingleton* ccs = new CustomCreatorSingleton;
1210 return ccs;
1211 }
1212
1213 } // namespace
1214
RegisterDefaultCustomKernelCreator(CustomKernelCreator cb)1215 void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb) {
1216 GetCustomCreatorSingleton()->Set(std::move(cb));
1217 }
1218
NewFunctionLibraryRuntime(const DeviceMgr * device_mgr,Env * env,Device * device,int graph_def_version,const FunctionLibraryDefinition * lib_def,thread::ThreadPool * thread_pool,const OptimizerOptions & optimizer_options,CustomKernelCreator custom_kernel_creator,ProcessFunctionLibraryRuntime * parent)1219 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
1220 const DeviceMgr* device_mgr, Env* env, Device* device,
1221 int graph_def_version, const FunctionLibraryDefinition* lib_def,
1222 thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
1223 CustomKernelCreator custom_kernel_creator,
1224 ProcessFunctionLibraryRuntime* parent) {
1225 return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl(
1226 device_mgr, env, device, graph_def_version, lib_def, thread_pool,
1227 optimizer_options, std::move(custom_kernel_creator), parent));
1228 }
1229
NewFunctionLibraryRuntime(const DeviceMgr * device_mgr,Env * env,Device * device,int graph_def_version,const FunctionLibraryDefinition * lib_def,thread::ThreadPool * thread_pool,const OptimizerOptions & optimizer_options,ProcessFunctionLibraryRuntime * parent)1230 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
1231 const DeviceMgr* device_mgr, Env* env, Device* device,
1232 int graph_def_version, const FunctionLibraryDefinition* lib_def,
1233 thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
1234 ProcessFunctionLibraryRuntime* parent) {
1235 return NewFunctionLibraryRuntime(device_mgr, env, device, graph_def_version,
1236 lib_def, thread_pool, optimizer_options,
1237 GetCustomCreatorSingleton()->Get(), parent);
1238 }
1239
RemoveDeadNodes(Graph * g)1240 bool RemoveDeadNodes(Graph* g) {
1241 VLOG(2) << "Removing dead nodes";
1242 std::unordered_set<const Node*> nodes;
1243 for (auto n : g->nodes()) {
1244 if (n->IsSource() || n->IsSink() || n->IsControlFlow() ||
1245 n->op_def().is_stateful()) {
1246 nodes.insert(n);
1247 }
1248 }
1249 return PruneForReverseReachability(g, std::move(nodes));
1250 }
1251
1252 namespace {
1253 // If 'edges' contains only 1 non-control edge, returns it. Otherwise,
1254 // returns a nullptr.
GetTheOnlyDataEdge(const EdgeSet & edges)1255 const Edge* GetTheOnlyDataEdge(const EdgeSet& edges) {
1256 const Edge* ret = nullptr;
1257 for (const Edge* e : edges) {
1258 if (e->IsControlEdge() || ret) {
1259 // Don't touch it if there is a control edge.
1260 return nullptr;
1261 }
1262 if (IsRefType(e->src()->output_type(e->src_output()))) {
1263 // Don't touch it if the identity node is effectively de-reffing
1264 // a ref.
1265 return nullptr;
1266 }
1267 if (IsRecv(e->src()) || IsSwitch(e->src())) {
1268 // Don't touch it if the identity is introduced for control flow.
1269 // Recv disables all its successors if it receives a dead signal.
1270 // When Recv has an outgoing control edge, the current executor
1271 // would not disable the destination. The current solution (see
1272 // graph_partition.cc) is to add an identity after Recv and change
1273 // the control edge to be from this identity node. So the identity
1274 // can't be removed.
1275 return nullptr;
1276 }
1277 ret = e;
1278 }
1279 return ret;
1280 }
1281 } // end namespace
1282
RemoveIdentityNodes(Graph * g)1283 bool RemoveIdentityNodes(Graph* g) {
1284 VLOG(2) << "Removing identity nodes";
1285 bool removed_any = false;
1286 gtl::InlinedVector<Node*, 8> matches;
1287 for (Node* n : g->nodes()) {
1288 if (!n->IsIdentity()) continue;
1289 if (!GetTheOnlyDataEdge(n->in_edges())) continue;
1290
1291 // Some identity nodes are used as sink nodes to give names to output
1292 // tensors. These nodes are not going to be executed unless they are in the
1293 // fetch set. But if they are in the fetch set we don't want to remove them.
1294 if (n->out_edges().empty()) continue;
1295
1296 matches.push_back(n);
1297 }
1298 if (!matches.empty()) {
1299 for (Node* n : matches) {
1300 const Edge* in = GetTheOnlyDataEdge(n->in_edges());
1301 for (const Edge* out : n->out_edges()) {
1302 if (out->IsControlEdge()) {
1303 g->AddControlEdge(in->src(), out->dst());
1304 } else {
1305 g->AddEdge(in->src(), in->src_output(), out->dst(), out->dst_input());
1306 }
1307 }
1308 VLOG(2) << "Remove Identity: " << n->DebugString();
1309 g->RemoveNode(n);
1310 removed_any = true;
1311 }
1312 }
1313 return removed_any;
1314 }
1315
RemoveListArrayConverter(Graph * g)1316 bool RemoveListArrayConverter(Graph* g) {
1317 VLOG(2) << "Removing list array converter";
1318 gtl::InlinedVector<Node*, 8> matches;
1319 for (Node* n : g->nodes()) {
1320 if ((n->type_string() == "_ListToArray") ||
1321 (n->type_string() == "_ArrayToList")) {
1322 matches.push_back(n);
1323 }
1324 }
1325 bool removed_any = false;
1326 if (!matches.empty()) {
1327 for (Node* n : matches) {
1328 if (n->num_inputs() != n->num_outputs()) {
1329 continue; // Not expected. Skip.
1330 }
1331 gtl::InlinedVector<Node*, 8> identity_nodes(n->num_inputs(), nullptr);
1332
1333 const auto no_op = [&](StringPiece name) {
1334 return AddNoOp(absl::StrCat(n->name(), "/", name), g);
1335 };
1336
1337 const auto identity = [&](StringPiece name, Endpoint input) {
1338 return AddIdentity(absl::StrCat(n->name(), "/", name), g, input);
1339 };
1340
1341 // Process input edges first.
1342 Node* input_control_node = nullptr;
1343 for (const Edge* e : n->in_edges()) {
1344 if (e->IsControlEdge()) {
1345 if (input_control_node == nullptr) {
1346 // If node "n" has any control dependencies, adds a no-op
1347 // node (input_control_node) which the additional Identity
1348 // nodes depends on and the input_control_node depends on
1349 // the node "n"s control dependencies.
1350 input_control_node = no_op("input_control_node");
1351 }
1352 g->AddControlEdge(e->src(), input_control_node);
1353 } else {
1354 const int index = e->dst_input();
1355 Node** id_node = &identity_nodes[index];
1356 if (*id_node != nullptr) {
1357 LOG(ERROR)
1358 << "RemoveListArrayConverter unexpected duplicated input: "
1359 << e->dst_input();
1360 return removed_any;
1361 }
1362 *id_node = identity("input", {e->src(), e->src_output()});
1363 }
1364 }
1365
1366 // If node "n" has any control dependencies, the added identity
1367 // nodes should have control dependencies on input_control_node.
1368 if (input_control_node != nullptr) {
1369 for (Node* id : identity_nodes) {
1370 g->AddControlEdge(input_control_node, id);
1371 }
1372 }
1373
1374 Node* output_control_node = nullptr;
1375 for (const Edge* e : n->out_edges()) {
1376 if (e->IsControlEdge()) {
1377 if (output_control_node == nullptr) {
1378 // If node "n" is control-depended upon by other nodes,
1379 // adds a no-op node (output_control_node) which those
1380 // nodes will depend on and output_control_node depends on
1381 // all Identity nodes.
1382 output_control_node = no_op("output_control_node");
1383 }
1384 g->AddControlEdge(output_control_node, e->dst());
1385 } else {
1386 Node* id_node = identity_nodes[e->src_output()];
1387 if (id_node == nullptr) {
1388 LOG(ERROR) << "RemoveListArrayConverter unexpected missing input: "
1389 << e->src_output();
1390 return removed_any;
1391 }
1392 CHECK(id_node);
1393 g->AddEdge(id_node, 0, e->dst(), e->dst_input());
1394 }
1395 }
1396
1397 // If any nodes have control dependencies on node "n", those
1398 // nodes should have control dependencies on
1399 // output_control_node.
1400 if (output_control_node != nullptr) {
1401 for (Node* id : identity_nodes) {
1402 g->AddControlEdge(id, output_control_node);
1403 }
1404 }
1405
1406 g->RemoveNode(n);
1407 removed_any = true;
1408 }
1409 }
1410 return removed_any;
1411 }
1412
InstantiateFunctionCall(const NodeDef & call_def,FunctionLibraryRuntime & flr,FunctionLibraryRuntime::Handle * handle)1413 Status InstantiateFunctionCall(const NodeDef& call_def,
1414 FunctionLibraryRuntime& flr,
1415 FunctionLibraryRuntime::Handle* handle) {
1416 const string* func_name;
1417 AttrSlice attrs;
1418
1419 NameAttrList func;
1420 if (call_def.op() == "PartitionedCall" ||
1421 call_def.op() == "StatefulPartitionedCall") {
1422 TF_RETURN_IF_ERROR(GetNodeAttr(call_def, "f", &func));
1423 func_name = &func.name();
1424 attrs = AttrSlice(&func.attr());
1425 } else {
1426 func_name = &call_def.op();
1427 attrs = AttrSlice(call_def);
1428 }
1429
1430 return flr.Instantiate(*func_name, attrs, handle);
1431 }
1432
1433 namespace {
1434
ValidateNoInline(const FunctionBody * fbody)1435 Status ValidateNoInline(const FunctionBody* fbody) {
1436 const auto attr = AttrSlice(&fbody->fdef.attr());
1437 bool noinline = false;
1438 if (GetNodeAttr(attr, kNoInlineAttr, &noinline).ok() && noinline) {
1439 return errors::InvalidArgument(
1440 "Can't inline function marked with '_noinline'");
1441 }
1442 return Status::OK();
1443 }
1444
1445 using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
1446
1447 } // namespace
1448
DebugString() const1449 string InlineFunctionBodyOptions::DebugString() const {
1450 return absl::StrCat("ignore_noinline=", ignore_noinline ? "true" : "false",
1451 ", override_device=", override_device ? "true" : "false",
1452 ", output_control_src=",
1453 output_control_src == OutputControlSrc::kDataOutputs
1454 ? "DataOutputs"
1455 : "ControlOutputs");
1456 }
1457
ValidateInlining(const Node * node,const FunctionBody * fbody,const InlineFunctionBodyOptions & options)1458 Status ValidateInlining(const Node* node, const FunctionBody* fbody,
1459 const InlineFunctionBodyOptions& options) {
1460 // TODO(ezhulenev): Currently common_runtime function inlining can't guarantee
1461 // that all side-effectful ops will be executed after inlining. See Grappler
1462 // function_optimizer for details. Unify all function inlining mechanism.
1463 // Do not inline if `!fbody->control_ret_nodes.empty()`.
1464
1465 const auto num_node_inputs = static_cast<size_t>(node->num_inputs());
1466 const auto num_node_outputs = static_cast<size_t>(node->num_outputs());
1467
1468 if (num_node_inputs != fbody->arg_types.size() ||
1469 num_node_inputs != fbody->arg_nodes.size()) {
1470 return errors::InvalidArgument(
1471 "Node inputs do not match function arguments: inputs=", num_node_inputs,
1472 " arg_types=", fbody->arg_types.size(),
1473 " arg_nodes=", fbody->arg_nodes.size());
1474 }
1475
1476 if (num_node_outputs != fbody->ret_types.size() ||
1477 num_node_outputs != fbody->ret_nodes.size()) {
1478 return errors::InvalidArgument(
1479 "Node outputs do not match function returns: outputs=",
1480 num_node_outputs, " ret_types=", fbody->ret_types.size(),
1481 " ret_nodes=", fbody->ret_nodes.size());
1482 }
1483
1484 for (int i = 0; i < node->num_inputs(); ++i) {
1485 if (node->input_type(i) != fbody->arg_types[i]) {
1486 return errors::InvalidArgument(
1487 "Node input type doesn't match function argument type: ",
1488 node->input_type(i), " != ", fbody->arg_types[i], " @ index=", i);
1489 }
1490 }
1491 for (int i = 0; i < node->num_outputs(); ++i) {
1492 if (node->output_type(i) != fbody->ret_types[i]) {
1493 return errors::InvalidArgument(
1494 "Node output type doesn't match function return type: ",
1495 node->output_type(i), " != ", fbody->ret_types[i], " @ index=", i);
1496 }
1497 }
1498
1499 if (!options.ignore_noinline) {
1500 TF_RETURN_IF_ERROR(ValidateNoInline(fbody));
1501 }
1502
1503 return Status::OK();
1504 }
1505
1506 // Function inlining must preserve function execution semantics with regards to
1507 // side-effects visibility. Tensorflow in Eager mode has an automatic control
1508 // dependencies tracking mechanism, which enforces well-defined execution order
1509 // of all side-effects. Any other frontend (e.g. Swift) must produce graphs
1510 // following the same rules, to ensure that function inlining works correctly.
1511 //
1512 // IMPORTANT: Currently we do not have a true notion of "side-effectful" node,
1513 // we assume that all stateful nodes might have side-effects, though it's not
1514 // true in practice, e.g. `ReadVariableOp` doesn't have an observable
1515 // side-effect.
1516 //
1517 // Automatic control dependency rules in Tensorflow 2.0 (python in eager mode):
1518 //
1519 // 1) When a function has a resource (DT_RESOURCE data type) input argument it
1520 // "captures" the mutable resource. This is implemented by automatically
//    adding an incoming control edge from the previous side-effectful op
1522 // touching that resource, and an outgoing control edge to the next
1523 // side-effectful op using the same resource. This serializes the mutations
1524 // of the resource to make graph execution deterministic.
1525 //
1526 // 2) All stateful ops inside a function body are guaranteed to execute in
1527 // program order, this is achieved by adding control edges between stateful
1528 // ops at graph construction time. Stateful ops (or ops that must execute)
1529 // should be in the function control return set. Having a data edge to the
//    regular function output might not be enough, because after function
1531 // inlining it might happen that data output is unused.
1532 //
1533 // 3) Furthermore, all ops accepting the same resource as an input are
1534 // guaranteed to run in program order. This is also done by adding control
1535 // edges at graph construction time. The last op touching the resource
1536 // must be in a control return set, which will guarantee that all side
1537 // effects to the resource will happen before function completion.
1538 //
1539 // Function inlining must preserve side-effect visibility:
1540 //
1541 // 1) All side-effects to the captured resources, that happened before function
//    call must be visible to the function body nodes using those resources.
1543 //
1544 // 2) All side-effects to the captured resources, that happened inside function
1545 // body, must be visible to every op/function using that resource after the
1546 // function call completed.
1547 //
1548 // To guarantee that these properties are preserved after inlining we:
1549 //
1550 // 1) Create "input_control_node" NoOp. Function call node incoming control
1551 // edges will be forwarded *to* this node. Function inputs (Identity nodes)
1552 // will have a control edge *from* this node. If function body has nodes
1553 // without inputs, they will have a control edge *from* this node.
1554 //
1555 // 2) Create "output_control_node" NoOp. All nodes that have incoming control
1556 // edge *from* the function call node, will be forwarded to this node.
1557 //
1558 // We have two options for choosing which nodes will have a control edge *to*
1559 // the "output control node":
1560 // a) control returns (`control_ret` field in FunctionDef)
1561 // b) data returns (`ret` field in FunctionDef)
1562 //
1563 // We do a) for multi-device function calls in Tensorflow v2 and b)
1564 // for the rest for compatibility with Tensorflow v1.
1565 //
1566 // Following the automatic control dependencies tracking rules, a node that
1567 // has an incoming control edge from the function call node is dependent on
1568 // the side-effects happening inside the function body. The output control
1569 // node will guarantee side-effects execution order.
1570 //
1571 // If function call node doesn't have an outgoing control edge, it means that
1572 // no one is interested in observing side-effects that might have happened.
1573 //
1574 // Function inlining might leave the graph in partially-placed state. Function
1575 // inlining caller must call Placer to guarantee that all nodes are placed.
1576 //
1577 // Function inlining with `options.override_device=true` will leave graph in
1578 // fully placed state, by overriding all inlined nodes devices with the caller
1579 // node device, but it will make functions always single-device. These functions
1580 // after inlining will not be able to handle resources on multiple devices. This
1581 // is currently acceptable for XLA use cases (XLA cluster is always executed on
1582 // a single device).
1583 //
1584 // TODO(ezhulenev): Documentation above is ahead of implementation below.
InlineFunctionBody(const FunctionLibraryDefinition & flib_def,Graph * g,Node * caller,const FunctionBody * fbody,const InlineFunctionBodyOptions & options)1585 Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
1586 Node* caller, const FunctionBody* fbody,
1587 const InlineFunctionBodyOptions& options) {
1588 VLOG(3) << "Inline function call: " << SummarizeNode(*caller) << " ["
1589 << options.DebugString() << "]";
1590 VLOG(4) << "Inlined function definition: " << DebugString(fbody->fdef);
1591
1592 Status validation = ValidateInlining(caller, fbody, options);
1593 if (!validation.ok()) {
1594 LOG(WARNING) << "Inlining mismatch: " << SummarizeNode(*caller) << " vs. "
1595 << DebugString(fbody->graph);
1596 return errors::Internal("Inlining mismatch: ", validation.error_message());
1597 }
1598
1599 // ------------------------------------------------------------------------ //
1600 // Helper functions to create `NoOp` and `Identity` nodes for auxiliary
1601 // control nodes and inlined function inputs and outputs.
1602
1603 // Add a NoOp node for function control inputs/outputs.
1604 const auto no_op = [&](StringPiece name) {
1605 Node* node = AddNoOp(absl::StrCat(caller->name(), "/", name), g);
1606 node->set_requested_device(caller->def().device());
1607 return node;
1608 };
1609
1610 // Add an Identity node for function data inputs/outputs.
1611 const auto identity = [&](StringPiece name, Endpoint input) {
1612 return AddIdentity(absl::StrCat(caller->name(), "/", name), g, input);
1613 };
1614
1615 // ------------------------------------------------------------------------ //
1616 // Input edges. For data edges coming into "caller", we first compute the
1617 // <src>:<src_output> for the i-th input in "inputs".
1618 // If "caller" has any input control dependencies, we add a NoOp
1619 // node "input_control_node", which depends on "caller"'s control inputs.
1620 std::vector<Endpoint> inputs(caller->num_inputs());
1621 Node* input_control_node = nullptr;
1622 for (const Edge* e : caller->in_edges()) {
1623 if (e->IsControlEdge()) {
1624 if (input_control_node == nullptr) {
1625 input_control_node = no_op("input_control_node");
1626 }
1627 g->AddControlEdge(e->src(), input_control_node);
1628 } else {
1629 inputs[e->dst_input()] = {e->src(), e->src_output()};
1630 }
1631 }
1632
1633 // ------------------------------------------------------------------------ //
1634 // Duplicate fbody->graph into 'g'. First, we copy the nodes of
1635 // fbody->graph into 'g' except the source and sink nodes. We copy
1636 // edges among nodes in 'fbody->graph'.
1637 //
1638 // If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we
1639 // remember 'y' in node_map[x->id()].
1640 std::vector<Node*> node_map(fbody->graph->num_node_ids());
1641 for (Node* n : fbody->graph->op_nodes()) {
1642 NodeDef ndef = n->def();
1643 ndef.set_name(strings::StrCat(caller->name(), "/", ndef.name()));
1644 if (options.override_device || ndef.device().empty()) {
1645 ndef.set_device(caller->def().device());
1646 }
1647 for (auto& attr : *ndef.mutable_attr()) {
1648 if (attr.first == "_class") {
1649 attr.second.set_s(
1650 strings::StrCat(caller->name(), "/", attr.second.s()));
1651 }
1652 }
1653 Status added_node;
1654 Node* clone = g->AddNode(ndef, &added_node);
1655 if (options.override_device && !caller->assigned_device_name().empty()) {
1656 clone->set_assigned_device_name(caller->assigned_device_name());
1657 }
1658 TF_CHECK_OK(added_node);
1659 node_map[n->id()] = clone;
1660
1661 // If there is an input control node, and one of:
1662 // a) the node has no data or control inputs, or
1663 // b) the node is a function call or SymbolicGradient,
1664 // then add a control edge from the input control node to the clone.
1665 //
1666 // We must not execute any nodes if the original function call would not
1667 // have executed. This is especially critical when the function call is
1668 // inside a control-flow construct like tf.cond(). Case (a) ensures that
1669 // such nodes do not run.
1670 //
1671 // The purpose of case (b) is to ensure that instances of case (a) created
1672 // by further inlining steps also receive the control dependency.
1673 //
1674 // TODO(ezhulenev): If caller has no control inputs, should we add a control
1675 // edge from one of the inputs to ensure that function body node will
1676 // execute in correct frame?
1677 if (input_control_node) {
1678 bool has_inputs = absl::c_any_of(
1679 n->in_edges(), [](const Edge* e) { return !e->src()->IsSource(); });
1680 if (!has_inputs || flib_def.Find(clone->type_string()) != nullptr ||
1681 clone->type_string() == kGradientOp) {
1682 g->AddControlEdge(input_control_node, clone);
1683 }
1684 }
1685 }
1686 for (const Edge* e : fbody->graph->edges()) {
1687 if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() ||
1688 e->dst()->IsSink()) {
1689 continue;
1690 }
1691 Node* src_copy = node_map[e->src()->id()];
1692 Node* dst_copy = node_map[e->dst()->id()];
1693 g->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
1694 }
1695
1696 // ------------------------------------------------------------------------ //
1697 // Connect input edges.
1698 //
1699 // We create one Identity node for each input. Then, we connect inputs[i] to
1700 // the i-th identity node added. The nodes that previously connected
1701 // to the j-th output of i-th arg node are reconnected to the i-th
1702 // identity node.
1703 //
1704 // The added identity nodes depend on "input_control_node".
1705 for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) {
1706 Node* arg = node_map[fbody->arg_nodes[i]->id()];
1707 Node* n = identity("input", inputs[i]);
1708 if (input_control_node) {
1709 g->AddControlEdge(input_control_node, n);
1710 }
1711 for (const Edge* e : arg->out_edges()) {
1712 if (e->IsControlEdge()) {
1713 g->AddControlEdge(n, e->dst());
1714 } else {
1715 g->AddEdge(n, 0, e->dst(), e->dst_input());
1716 }
1717 }
1718 node_map[fbody->arg_nodes[i]->id()] = n;
1719 g->RemoveNode(arg); // 'arg' is disconnected.
1720 }
1721
1722 // ------------------------------------------------------------------------ //
1723 // Connect output edges.
1724 //
1725 // For i-th return node in fbody->graph, we add in "g" an identity node
1726 // (outputs[i-th]). We then reconnect every incoming edge into the i-th return
1727 // node to the added identity node.
1728 //
1729 // For every data edge coming out of "callee"s i-th output, we reconnect it to
1730 // the i-th identity added above.
1731 //
1732 // If "callee" is control-depended upon by any other nodes, we add a NoOp node
1733 // "output_control_node". "output_control_node" depends on all identity nodes
1734 // added above or on all control return nodes (controlled by
1735 // `options.output_control_src` value). And nodes previously depend on
1736 // "callee" is changed to depend on "output_control_node".
1737 std::vector<Node*> outputs(caller->num_outputs());
1738 for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) {
1739 Node* ret = node_map[fbody->ret_nodes[i]->id()];
1740 Endpoint data; // Data input for the ret node.
1741 for (const Edge* e : ret->in_edges()) {
1742 if (!e->IsControlEdge()) {
1743 data = {e->src(), e->src_output()};
1744 break;
1745 }
1746 }
1747 CHECK(data.node != nullptr);
1748 Node* n = identity("output", data);
1749 outputs[i] = n;
1750 for (const Edge* e : ret->in_edges()) {
1751 if (e->IsControlEdge()) {
1752 g->AddControlEdge(e->src(), n);
1753 }
1754 }
1755 g->RemoveNode(ret); // 'ret' is disconnected.
1756 }
1757 Node* output_control_node = nullptr;
1758 for (const Edge* e : caller->out_edges()) {
1759 if (e->IsControlEdge()) {
1760 if (output_control_node == nullptr) {
1761 output_control_node = no_op("output_control_node");
1762 if (options.output_control_src ==
1763 InlineFunctionBodyOptions::OutputControlSource::kDataOutputs) {
1764 for (Node* n : outputs) {
1765 g->AddControlEdge(n, output_control_node);
1766 }
1767 } else {
1768 for (Node* fbody_node : fbody->control_ret_nodes) {
1769 Node* n = node_map[fbody_node->id()];
1770 g->AddControlEdge(n, output_control_node);
1771 }
1772 }
1773 }
1774 g->AddControlEdge(output_control_node, e->dst());
1775 } else {
1776 g->AddEdge(outputs[e->src_output()], 0, e->dst(), e->dst_input());
1777 }
1778 }
1779 g->RemoveNode(caller); // 'caller' is replaced with inlined nodes.
1780
1781 return Status::OK();
1782 }
1783
IsFunctionCall(const FunctionLibraryDefinition & lib_def,const Node & node)1784 bool IsFunctionCall(const FunctionLibraryDefinition& lib_def,
1785 const Node& node) {
1786 return node.IsPartitionedCall() ||
1787 node.type_string() == FunctionLibraryDefinition::kGradientOp ||
1788 lib_def.Find(node.def().op()) != nullptr;
1789 }
1790
ExpandInlineFunctions(FunctionLibraryRuntime * lib,Graph * graph,const ExpandInlineFunctionsOptions & options)1791 bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
1792 const ExpandInlineFunctionsOptions& options) {
1793 std::vector<std::pair<Node*, const FunctionBody*>> candidates;
1794
1795 const FunctionLibraryDefinition* fld = lib->GetFunctionLibraryDefinition();
1796
1797 for (Node* node : graph->nodes()) {
1798 // Skip nodes that are not function calls or SymbolicGradient calls.
1799 if (!IsFunctionCall(*lib->GetFunctionLibraryDefinition(), *node)) {
1800 continue;
1801 }
1802 // Skip function calls that marked noinline.
1803 bool noinline;
1804 if (fld->GetAttr(*node, kNoInlineAttr, &noinline).ok() && noinline) {
1805 VLOG(3) << "noinline: " << SummarizeNode(*node);
1806 continue;
1807 }
1808 FunctionLibraryRuntime::Handle handle;
1809 Status s = InstantiateFunctionCall(node->def(), *lib, &handle);
1810 if (!s.ok()) {
1811 LOG(ERROR) << "Failed to instantiate a function: " << s.error_message();
1812 continue;
1813 }
1814 const FunctionBody* fbody = lib->GetFunctionBody(handle);
1815 CHECK_NOTNULL(fbody);
1816 candidates.emplace_back(node, fbody);
1817 }
1818
1819 bool inlined_any = false;
1820 for (const auto& p : candidates) {
1821 Status inlined = InlineFunctionBody(*fld, graph, p.first, p.second,
1822 p.first->IsPartitionedCall()
1823 ? options.multi_device_options
1824 : options.native_options);
1825 if (inlined.ok()) {
1826 inlined_any = true;
1827 } else {
1828 VLOG(1) << "Failed to inline function call: node=" << p.first->name()
1829 << " error=" << inlined.error_message();
1830 }
1831 }
1832
1833 // TODO(ezhulenev): Release handles for inlined function calls.
1834
1835 return inlined_any;
1836 }
1837
NewName(const Node * n,bool pretty)1838 string NewName(const Node* n, bool pretty) {
1839 if (pretty) {
1840 return strings::StrCat(n->type_string(), n->id());
1841 } else {
1842 return strings::StrCat("n", n->id());
1843 }
1844 }
1845
// TODO(zhifengc): Maybe this should be the default Graph::AsGraphDef,
// and stash the original NodeDef name as an attr for documentation
// purposes.
ToGraphDef(const Graph * g,GraphDef * gdef,bool pretty)1849 void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) {
1850 // We visit nodes in forward topological sort order, which is a
1851 // possible execution order of the graph.
1852 gtl::InlinedVector<const Edge*, 4> inputs;
1853 gdef->Clear();
1854 gdef->mutable_versions()->CopyFrom(g->versions());
1855
1856 std::vector<Node*> start_nodes;
1857 for (Node* n : g->nodes()) {
1858 if (n->out_edges().empty()) {
1859 start_nodes.push_back(n);
1860 }
1861 }
1862
1863 ReverseDFSFrom(*g, start_nodes, nullptr, [gdef, pretty, &inputs](Node* n) {
1864 if (!n->IsOp()) return;
1865 NodeDef* ndef = gdef->add_node();
1866 ndef->set_name(NewName(n, pretty));
1867 ndef->set_op(n->type_string());
1868 for (const auto& attr : n->attrs()) {
1869 (*ndef->mutable_attr())[attr.first] = attr.second;
1870 }
1871
1872 if (!n->assigned_device_name().empty()) {
1873 ndef->set_device(n->assigned_device_name());
1874 } else {
1875 ndef->set_device(n->requested_device());
1876 }
1877
1878 inputs.clear();
1879 inputs.resize(n->num_inputs());
1880 for (const Edge* e : n->in_edges()) {
1881 if (e->IsControlEdge()) {
1882 inputs.push_back(e);
1883 } else {
1884 if (inputs[e->dst_input()] == nullptr) {
1885 inputs[e->dst_input()] = e;
1886 } else {
1887 LOG(WARNING) << "Malformed graph node. multiple input edges: "
1888 << n->DebugString();
1889 }
1890 }
1891 }
1892 // node->name() is merely NodeDef::name, which are not guaranteed
1893 // to be unique and stable after optimization rewrites. Therefore,
1894 // we use "n<node id>" instead.
1895 for (const Edge* e : inputs) {
1896 if (e == nullptr) {
1897 ndef->add_input("unknown");
1898 continue;
1899 }
1900 const string srcname = NewName(e->src(), pretty);
1901 if (!e->src()->IsOp()) {
1902 } else if (e->IsControlEdge()) {
1903 ndef->add_input(strings::StrCat("^", srcname));
1904 } else if (e->src_output() == 0) {
1905 ndef->add_input(srcname);
1906 } else {
1907 ndef->add_input(strings::StrCat(srcname, ":", e->src_output()));
1908 }
1909 }
1910 });
1911 }
1912
DebugString(const Graph * g)1913 string DebugString(const Graph* g) {
1914 GraphDef gdef;
1915 ToGraphDef(g, &gdef);
1916 return DebugString(gdef);
1917 }
1918
FunctionBody(const FunctionDef & f,DataTypeSlice arg_t,DataTypeSlice ret_t,Graph * g)1919 FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t,
1920 DataTypeSlice ret_t, Graph* g)
1921 : fdef(f),
1922 graph(g),
1923 arg_types(arg_t.begin(), arg_t.end()),
1924 ret_types(ret_t.begin(), ret_t.end()) {
1925 // 1. Find regular Arg/Ret nodes.
1926 this->arg_nodes.resize(arg_types.size());
1927 this->ret_nodes.resize(ret_types.size());
1928 for (Node* n : this->graph->op_nodes()) {
1929 gtl::InlinedVector<Node*, 4>* node_vec;
1930 if (n->type_string() == kRetOp || n->type_string() == kDeviceRetOp) {
1931 node_vec = &this->ret_nodes;
1932 } else if (n->type_string() == kArgOp || n->type_string() == kDeviceArgOp) {
1933 node_vec = &this->arg_nodes;
1934 } else {
1935 continue;
1936 }
1937 int index;
1938 TF_CHECK_OK(GetNodeAttr(n->attrs(), "index", &index));
1939 CHECK_LE(0, index);
1940 CHECK_LT(index, node_vec->size());
1941 (*node_vec)[index] = n;
1942 }
1943 // 2. Find ControlRet nodes that must be always executed.
1944 std::unordered_set<StringPiece, StringPieceHasher> control_ret_node_names;
1945 for (const auto& control_ret : fdef.control_ret()) {
1946 control_ret_node_names.insert(control_ret.second);
1947 }
1948 this->control_ret_nodes.reserve(control_ret_node_names.size());
1949 for (Node* n : this->graph->op_nodes()) {
1950 if (control_ret_node_names.count(n->name()) > 0) {
1951 this->control_ret_nodes.push_back(n);
1952 }
1953 }
1954 }
1955
~FunctionBody()1956 FunctionBody::~FunctionBody() { delete this->graph; }
1957
1958 class SymbolicGradientHelper {
1959 public:
SymbolicGradientHelper(const FunctionBody & f)1960 explicit SymbolicGradientHelper(const FunctionBody& f) : fbody_(&f) {}
1961
~SymbolicGradientHelper()1962 ~SymbolicGradientHelper() { delete gbody_; }
1963
1964 FunctionBody* Compute();
1965
1966 private:
1967 const FunctionBody* fbody_;
1968 FunctionBody* gbody_ = nullptr;
1969
1970 // Makes a copy of fbody_ in gbody_.
1971 void Copy();
1972
1973 TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientHelper);
1974 };
1975
Copy()1976 void SymbolicGradientHelper::Copy() {
1977 const Graph& src = *(fbody_->graph);
1978 gbody_->graph = new Graph(src.op_registry());
1979 Graph* dst = gbody_->graph;
1980
1981 std::vector<Node*> node_map(src.num_node_ids());
1982
1983 // Copy the nodes.
1984 node_map[src.source_node()->id()] = dst->source_node();
1985 node_map[src.sink_node()->id()] = dst->sink_node();
1986 for (Node* n : src.op_nodes()) {
1987 node_map[n->id()] = dst->CopyNode(n);
1988 }
1989
1990 // Copy the edges.
1991 for (const Edge* e : src.edges()) {
1992 Node* src_copy = node_map[e->src()->id()];
1993 Node* dst_copy = node_map[e->dst()->id()];
1994 dst->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
1995 }
1996
1997 // Save inputs in copied graph.
1998 CHECK_EQ(fbody_->arg_types.size(), fbody_->arg_nodes.size());
1999 gbody_->arg_types = fbody_->arg_types;
2000 for (std::size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
2001 gbody_->arg_nodes.push_back(node_map[fbody_->arg_nodes[i]->id()]);
2002 }
2003
2004 // Save outputs in copied graph.
2005 CHECK_EQ(fbody_->ret_types.size(), fbody_->ret_nodes.size());
2006 gbody_->ret_types = fbody_->ret_types;
2007 for (std::size_t i = 0; i < fbody_->ret_nodes.size(); ++i) {
2008 gbody_->ret_nodes.push_back(node_map[fbody_->ret_nodes[i]->id()]);
2009 }
2010 }
2011
Compute()2012 FunctionBody* SymbolicGradientHelper::Compute() {
2013 CHECK(gbody_ == nullptr);
2014 gbody_ = new FunctionBody;
2015
2016 // Copy fbody_ into gbody_.
2017 Copy();
2018
2019 Graph* g = gbody_->graph;
2020
2021 const int num_y = static_cast<int>(gbody_->ret_nodes.size());
2022
2023 // Populate 'y_node_outputs_' with node function body outputs.
2024 // Populate 'y_grad_nodes' with initial gradient nodes for each return node
2025 // of the original function body (these will be 'arg' nodes in the function
2026 // gradient body).
2027 std::vector<NodeOut> y_node_outputs;
2028 y_node_outputs.reserve(num_y);
2029 std::vector<NodeOut> y_grad_node_outputs;
2030 y_grad_node_outputs.reserve(num_y);
2031 for (int i = 0; i < num_y; ++i) {
2032 Node* y = gbody_->ret_nodes[i];
2033 y_node_outputs.push_back({y, 0});
2034 DCHECK_EQ(y->type_string(), kRetOp);
2035 const DataType dtype = y->input_type(0);
2036 const int index = static_cast<int>(gbody_->arg_nodes.size());
2037 Node* dy = AddArg(g, dtype, index);
2038 gbody_->arg_types.push_back(dtype);
2039 gbody_->arg_nodes.push_back(dy);
2040 y_grad_node_outputs.push_back({dy, 0});
2041 }
2042
2043 // Populate 'x_nodes' with function args (excluding 'y_grad_node_outputs').
2044 const size_t num_x = fbody_->arg_nodes.size();
2045 std::vector<NodeOut> x_node_outputs;
2046 x_node_outputs.reserve(num_x);
2047 for (size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
2048 x_node_outputs.push_back({gbody_->arg_nodes[i], 0});
2049 }
2050
2051 // Call AddSymbolicGradients which will add nodes to graph 'g' that
2052 // compute the function gradient (adding an entry in 'x_grad_node_outputs'
2053 // for each node in 'x_node_outputs').
2054 std::vector<NodeOut> x_grad_node_outputs;
2055 TF_CHECK_OK(AddSymbolicGradients(y_node_outputs, x_node_outputs,
2056 y_grad_node_outputs, &x_grad_node_outputs,
2057 g));
2058
2059 // Remove the old return nodes from the function body.
2060 for (Node* n : gbody_->ret_nodes) {
2061 g->RemoveNode(n);
2062 }
2063 gbody_->ret_types = fbody_->arg_types;
2064 // TODO(apassos): use the right dtype for gradients of resource variables
2065 for (int i = 0; i < gbody_->ret_types.size(); ++i) {
2066 if (gbody_->ret_types[i] == DT_RESOURCE) {
2067 gbody_->ret_types[i] = DT_FLOAT;
2068 }
2069 }
2070 gbody_->ret_nodes.clear();
2071 // Add new return nodes to the function gradient body for each node
2072 // in 'x_grad_nodes'.
2073 const int arg_types_size = static_cast<int>(fbody_->arg_types.size());
2074 for (int i = 0; i < arg_types_size; ++i) {
2075 Endpoint grad = {x_grad_node_outputs[i].node, x_grad_node_outputs[i].index};
2076 Node* ret = AddRet(g, grad, i);
2077 gbody_->ret_nodes.push_back(ret);
2078 }
2079
2080 auto ret = gbody_;
2081 gbody_ = nullptr;
2082 return ret;
2083 }
2084
// Returns a newly allocated FunctionBody computing the gradient of `f`.
// The caller owns the result.
FunctionBody* SymbolicGradient(const FunctionBody& f) {
  SymbolicGradientHelper helper(f);
  return helper.Compute();
}
2088
FunctionDefToBodyHelper(const FunctionDef & fdef,const AttrSlice & attrs,const FunctionLibraryDefinition * const lib_def,const std::function<Status (const string &,const OpDef **)> & get_func_sig,FunctionBody ** fbody)2089 Status FunctionDefToBodyHelper(
2090 const FunctionDef& fdef, const AttrSlice& attrs,
2091 const FunctionLibraryDefinition* const lib_def,
2092 const std::function<Status(const string&, const OpDef**)>& get_func_sig,
2093 FunctionBody** fbody) {
2094 // Instantiates the function template into a graph def.
2095 InstantiationResult result;
2096 TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result));
2097
2098 std::unique_ptr<Graph> graph(new Graph(lib_def));
2099 GraphConstructorOptions opts;
2100 opts.allow_internal_ops = true;
2101 opts.expect_device_spec = false;
2102 TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get()));
2103
2104 // Call BuildControlFlowInfo to validate that this function body has
2105 // well-formed control flow.
2106 std::vector<ControlFlowInfo> dummy;
2107 TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy));
2108
2109 *fbody = new FunctionBody(fdef, result.arg_types, result.ret_types,
2110 graph.release());
2111 return Status::OK();
2112 }
2113
2114 } // end namespace tensorflow
2115