/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/function.h"

#include <deque>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/gradients.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/optimizer_cse.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"

// See core/kernels/function_ops.cc for related kernels.

namespace tensorflow {
// A few string constants used throughout this module.
//
// TODO(zhifengc): Dedup some of these constants into
// framework/function.h
static constexpr const char* const kArgOp = "_Arg";
static constexpr const char* const kRetOp = "_Retval";
static constexpr const char* const kGradientOp =
    FunctionLibraryDefinition::kGradientOp;
static constexpr const char* const kNodeLabel = "Func";
static constexpr const char* const kFuncAttr =
    FunctionLibraryDefinition::kFuncAttr;
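
// For reference: in an instantiated function body, the i-th argument is an
// "_Arg" node carrying attrs {T=<dtype>, index=i}, and the i-th return value
// is an "_Retval" node with the same attr layout; AddArg() and AddRet() below
// construct exactly these nodes.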

// Represents the index-th output of a node.
struct Endpoint {
  Node* node;
  int index;

  // Returns the string name that represents this endpoint.
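  // For example, {node "x", index 0} is named "x" and {node "x", index 2} is
  // named "x:2", matching the NodeDef input-string syntax.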
  string name() const {
    if (index == 0) {
      return node->name();
    } else {
      return strings::StrCat(node->name(), ":", index);
    }
  }

  DataType dtype() const { return node->output_type(index); }
};

struct EndpointHash {
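  // Hashes the raw bytes of the node pointer, seeding the hash with the
  // output index.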
  uint64 operator()(const Endpoint& x) const {
    return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
                  x.index);
  }
};

struct EndpointEq {
  bool operator()(const Endpoint& x, const Endpoint& y) const {
    return (x.node == y.node) && (x.index == y.index);
  }
};

// The following Add* routines are used to add a few graph nodes while
// functions are transformed.
static Node* AddNoOp(Graph* g) {
  NodeDef ndef;
  ndef.set_name(g->NewName(kNodeLabel));
  ndef.set_op("NoOp");
  Status s;
  Node* ret = g->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  return ret;
}

static Node* AddIdentity(Graph* g, Endpoint input) {
  DCHECK_LT(0, input.dtype());
  NodeDef ndef;
  ndef.set_name(g->NewName(kNodeLabel));
  ndef.set_op("Identity");
  ndef.add_input(input.name());
  AddNodeAttr("T", BaseType(input.dtype()), &ndef);
  Status s;
  Node* ret = g->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  g->AddEdge(input.node, input.index, ret, 0);
  return ret;
}

static Node* AddArg(Graph* g, DataType dtype, int index) {
  DCHECK_LT(0, dtype);
  DCHECK_LT(dtype, DT_FLOAT_REF);
  NodeDef ndef;
  ndef.set_name(g->NewName(kNodeLabel));
  ndef.set_op(kArgOp);
  AddNodeAttr("T", dtype, &ndef);
  AddNodeAttr("index", index, &ndef);
  Status s;
  Node* ret = g->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  return ret;
}

static Node* AddRet(Graph* g, Endpoint input, int index) {
  DCHECK_LT(0, input.dtype());
  DCHECK_LT(input.dtype(), DT_FLOAT_REF);
  NodeDef ndef;
  ndef.set_name(g->NewName(kNodeLabel));
  ndef.set_op(kRetOp);
  ndef.add_input(input.name());
  AddNodeAttr("T", input.dtype(), &ndef);
  AddNodeAttr("index", index, &ndef);
  Status s;
  Node* ret = g->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  g->AddEdge(input.node, input.index, ret, 0);
  return ret;
}

class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
 public:
  FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env, Device* device,
                             int graph_def_version,
                             const FunctionLibraryDefinition* lib_def,
                             const OptimizerOptions& optimizer_options,
                             CustomKernelCreator custom_kernel_creator,
                             ProcessFunctionLibraryRuntime* parent);

  ~FunctionLibraryRuntimeImpl() override;

  Status Instantiate(const string& function_name, AttrSlice attrs,
                     const InstantiateOptions& options,
                     Handle* handle) override;

  Status ReleaseHandle(Handle handle) override;

  const FunctionBody* GetFunctionBody(Handle handle) override;

  Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override;

  void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
           std::vector<Tensor>* rets, DoneCallback done) override;
  // NOTE(mrry): This overload is currently only implemented for local function
  // execution.
  // TODO(b/70346412): Implement support for remote function execution when
  // passing a call frame.
  void Run(const Options& opts, Handle handle, CallFrameInterface* frame,
           DoneCallback done) override;

  bool IsStateful(const string& function) override;

  const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const override {
    return base_lib_def_;
  }

  Device* device() override { return device_; }
  Env* env() override { return env_; }
  int graph_def_version() override { return graph_def_version_; }

  string DebugString(Handle h) override;

  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
               FunctionLibraryRuntime** out_flr) override;

 private:
  typedef FunctionLibraryRuntimeImpl ME;

  const DeviceMgr* const device_mgr_;
  Device* const device_;
  Env* const env_;
  const int graph_def_version_;
  const FunctionLibraryDefinition* const base_lib_def_;
  GraphOptimizer optimizer_;
  const CustomKernelCreator custom_kernel_creator_;
  const string device_name_;

  std::function<Status(const string&, const OpDef**)> get_func_sig_;
  std::function<Status(const NodeDef&, OpKernel**)> create_kernel_;

  mutable mutex mu_;

  int next_handle_ GUARDED_BY(mu_);

  // The instantiated and transformed function is encoded as a Graph
  // object, and an executor is created for the graph.
  struct Item : public core::RefCounted {
    const Graph* graph = nullptr;                            // Owned by exec.
    const FunctionLibraryDefinition* overlay_lib = nullptr;  // Not owned.
    FunctionBody* func_graph = nullptr;
    Executor* exec = nullptr;

    ~Item() override {
      delete this->func_graph;
      delete this->exec;
    }
  };
  std::unordered_map<Handle, Item*> items_ GUARDED_BY(mu_);

  ProcessFunctionLibraryRuntime* parent_ = nullptr;  // not owned.

  Status CreateKernel(const NodeDef& ndef,
                      const FunctionLibraryDefinition* lib_def,
                      OpKernel** kernel);
  Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
                           const FunctionLibraryDefinition* lib_def,
                           FunctionBody** fbody);
  Status CreateItem(Handle handle, Item** item);
  Status GetOrCreateItem(Handle handle, Item** item);
  Status InstantiateSymbolicGradient(const NameAttrList& func,
                                     const FunctionLibraryDefinition* lib_def,
                                     FunctionBody** g_body);
  bool IsLocalTarget(const InstantiateOptions& options);
  AttrValueMap FixAttrs(const AttrSlice& attrs);
  void RunRemote(const Options& opts, Handle handle,
                 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                 Executor::Args* exec_args, Item* item, DoneCallback done);

  TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
};
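
// A minimal usage sketch of the runtime above (illustrative: `flr`, `attrs`
// and the function name "MyFunc" are placeholders; the three-argument
// Instantiate() overload is the convenience form also used by
// ExpandInlineFunctions() below):
//
//   FunctionLibraryRuntime::Handle handle;
//   TF_CHECK_OK(flr->Instantiate("MyFunc", attrs, &handle));
//   FunctionLibraryRuntime::Options opts;
//   std::vector<Tensor> rets;
//   flr->Run(opts, handle, {arg0}, &rets,
//            [](const Status& s) { TF_CHECK_OK(s); });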

FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
    const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version,
    const FunctionLibraryDefinition* lib_def,
    const OptimizerOptions& optimizer_options,
    CustomKernelCreator custom_kernel_creator,
    ProcessFunctionLibraryRuntime* parent)
    : device_mgr_(dmgr),
      device_(device),
      env_(env),
      graph_def_version_(graph_def_version),
      base_lib_def_(lib_def),
      optimizer_(optimizer_options),
      custom_kernel_creator_(std::move(custom_kernel_creator)),
      device_name_(device_ == nullptr
                       ? ProcessFunctionLibraryRuntime::kDefaultFLRDevice
                       : device_->name()),
      next_handle_(0),
      parent_(parent) {
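  // NOTE: these default callbacks consult only base_lib_def_; when an overlay
  // library is in use, dedicated callbacks are built in FunctionDefToBody()
  // and CreateItem() below.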
  get_func_sig_ = [this](const string& op, const OpDef** sig) {
    return base_lib_def_->LookUpOpDef(op, sig);
  };
  create_kernel_ = [this](const NodeDef& ndef, OpKernel** kernel) {
    return CreateKernel(ndef, kernel);
  };
}

FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
  // The most common patterns of FLR usage don't require the caller to
  // explicitly release handles. As a result, we try to unref each item until
  // it's erased.
  for (auto item : items_) {
    if (item.second) {
      while (!item.second->Unref()) {
      }
    }
  }
}

// An asynchronous op kernel which executes an instantiated function
// defined in a library.
class CallOp : public AsyncOpKernel {
 public:
  CallOp(FunctionLibraryRuntime::Handle handle, OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx), handle_(handle) {}

  ~CallOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    FunctionLibraryRuntime* lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library is provided."),
                      done);
    FunctionLibraryRuntime::Options opts;
    opts.step_id = ctx->step_id();
    opts.rendezvous = ctx->rendezvous();
    opts.cancellation_manager = ctx->cancellation_manager();
    opts.step_container = ctx->step_container();
    opts.stats_collector = ctx->stats_collector();
    opts.runner = ctx->runner();
    std::vector<Tensor> args;
    args.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      args.push_back(ctx->input(i));
    }
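    // rets is heap-allocated because it must outlive this stack frame:
    // lib->Run() below is asynchronous, and the callback copies the outputs
    // from *rets into ctx before deleting rets.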
    std::vector<Tensor>* rets = new std::vector<Tensor>;
    lib->Run(opts, handle_, args, rets,
             [ctx, done, rets](const Status& status) {
               if (!status.ok()) {
                 ctx->SetStatus(status);
               } else {
                 const int ret_size = static_cast<int>(rets->size());
                 CHECK_EQ(ret_size, ctx->num_outputs());
                 for (int i = 0; i < ret_size; ++i) {
                   ctx->set_output(i, (*rets)[i]);
                 }
               }
               delete rets;
               done();
             });
  }

 private:
  FunctionLibraryRuntime::Handle handle_;

  TF_DISALLOW_COPY_AND_ASSIGN(CallOp);
};

const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h);
  if (local_handle == kInvalidLocalHandle) {
    LOG(ERROR) << "Could not find Handle: " << h
               << " on device: " << device_name_;
    return nullptr;
  }

  mutex_lock l(mu_);
  CHECK_EQ(1, items_.count(local_handle));
  return items_[local_handle]->func_graph;
}

Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
                                                OpKernel** kernel) {
  return CreateKernel(ndef, base_lib_def_, kernel);
}

Status FunctionLibraryRuntimeImpl::CreateKernel(
    const NodeDef& ndef, const FunctionLibraryDefinition* lib_def,
    OpKernel** kernel) {
  // If a custom kernel creator is given, try that.
  Status s;
  if (custom_kernel_creator_) {
    std::unique_ptr<OpKernel> ret;
    s = custom_kernel_creator_(this, ndef, &ret);
    if (s.ok()) {
      *kernel = ret.release();
      return s;
    } else {
      VLOG(2) << "Custom creator error: " << s;
      // Falls through.
      s = Status::OK();
    }
  }

  if (lib_def->Find(ndef.op()) == nullptr) {
    // A primitive operation. Creates the registered kernel.
    return CreateNonCachedKernel(device_, this, ndef, graph_def_version_,
                                 kernel);
  }

  // Try to instantiate this function for the func/attr. Maybe it's
  // cached already.
  InstantiateOptions options;
  if (lib_def != base_lib_def_) {
    options.overlay_lib = lib_def;
  }
  Handle handle;
  TF_RETURN_IF_ERROR(
      Instantiate(ndef.op(), AttrSlice(&ndef.attr()), options, &handle));

  const FunctionBody* fbody = GetFunctionBody(handle);
  CHECK_NOTNULL(fbody);

  // TODO(zhifengc): For now, we assume int32 and resources are always on host
  // memory and other types are always on device memory. We should do type
  // inference over function body to derive the correct input/output memory
  // types.
  MemoryTypeVector input_memory_types;
  for (const auto& t : fbody->arg_types) {
    input_memory_types.push_back(
        (t == DT_INT32 || t == DT_RESOURCE) ? HOST_MEMORY : DEVICE_MEMORY);
  }
  MemoryTypeVector output_memory_types;
  for (const auto& t : fbody->ret_types) {
    output_memory_types.push_back(t == DT_INT32 ? HOST_MEMORY : DEVICE_MEMORY);
  }

  // Constructs a CallOp kernel for running the instantiated function.
  auto device_type = DeviceType(device_->attributes().device_type());
  OpKernelConstruction construction(
      device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef,
      &fbody->fdef.signature(), this, fbody->arg_types, input_memory_types,
      fbody->ret_types, output_memory_types, graph_def_version_, &s);
  *kernel = new CallOp(handle, &construction);
  if (!s.ok()) {
    delete *kernel;
  }
  return s;
}

Status FunctionLibraryRuntimeImpl::FunctionDefToBody(
    const FunctionDef& fdef, AttrSlice attrs,
    const FunctionLibraryDefinition* lib_def, FunctionBody** fbody) {
  if (lib_def == base_lib_def_) {
    return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig_, fbody);
  } else {
    auto get_func_sig = [lib_def](const string& op, const OpDef** sig) {
      return lib_def->LookUpOpDef(op, sig);
    };
    return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody);
  }
}

Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
    const NameAttrList& func, const FunctionLibraryDefinition* lib_def,
    FunctionBody** g_body) {
  const FunctionDef* fdef = lib_def->Find(func.name());
  if (fdef == nullptr) {
    // f is a primitive op.
    gradient::Creator creator;
    TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
    if (creator == nullptr) {
      return errors::InvalidArgument("No gradient is defined for ",
                                     func.name());
    }
    FunctionDef grad_fdef;
    // TODO(josh11b): Should filter out the attrs from func that aren't used
    // by the gradient function.
    TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
    TF_RETURN_IF_ERROR(
        FunctionDefToBody(grad_fdef, AttrSlice(&func.attr()), lib_def, g_body));
  } else {
    // f is a user-defined function.
    InstantiateOptions options;
    if (lib_def != base_lib_def_) {
      options.overlay_lib = lib_def;
    }
    Handle f_handle;
    TF_RETURN_IF_ERROR(
        Instantiate(func.name(), AttrSlice(&func.attr()), options, &f_handle));
    const FunctionBody* f_body = GetFunctionBody(f_handle);
    CHECK_NOTNULL(f_body);
    *g_body = SymbolicGradient(*f_body);
  }
  return Status::OK();
}

bool FunctionLibraryRuntimeImpl::IsLocalTarget(
    const InstantiateOptions& options) {
  if (device_ == nullptr) return true;
  if (options.target.empty()) return true;
  Device* target_device;
  if (!device_mgr_->LookupDevice(options.target, &target_device).ok()) {
    return false;
  }
  return target_device == device_;
}

Status FunctionLibraryRuntimeImpl::Instantiate(
    const string& function_name, AttrSlice attrs,
    const InstantiateOptions& options, Handle* handle) {
  if (!IsLocalTarget(options)) {
    return parent_->Instantiate(function_name, attrs, options, handle);
  }

  // Since this is a local target, ensure that the local `device_name_` appears
  // in the canonical key.
  InstantiateOptions options_copy(options);
  options_copy.target = device_name_;
  const string key = Canonicalize(function_name, attrs, options_copy);
  *handle = parent_->GetHandle(key);
  if (*handle != kInvalidHandle) {
    mutex_lock l(mu_);
    items_[parent_->GetHandleOnDevice(device_name_, *handle)]->Ref();
    return Status::OK();
  }

  Status s;
  const FunctionLibraryDefinition* lib_def =
      options.overlay_lib ? options.overlay_lib : base_lib_def_;
  FunctionBody* fbody = nullptr;
  if (function_name == kGradientOp) {
    const AttrValue* f = attrs.Find(kFuncAttr);
    if (f == nullptr) {
      return errors::InvalidArgument("SymbolicGradient is missing attr: f");
    }
    const auto& func = f->func();
    if (func.name() == kGradientOp) {
      return errors::InvalidArgument("Can't take gradient of SymbolicGradient");
    }
    const string grad = lib_def->FindGradient(func.name());
    if (!grad.empty()) {
      return Instantiate(grad, AttrSlice(&func.attr()), options, handle);
    }
    TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, lib_def, &fbody));
  } else {
    const FunctionDef* fdef = lib_def->Find(function_name);
    if (fdef == nullptr) {
      return errors::NotFound("Function ", function_name, " is not defined.");
    }
    TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, lib_def, &fbody));
  }

  {
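    // Re-check under the lock: another thread may have instantiated the same
    // function while we were building fbody above; if so, discard our copy
    // and reuse (and Ref) the existing item.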
    mutex_lock l(mu_);
    *handle = parent_->GetHandle(key);
    if (*handle != kInvalidHandle) {
      delete fbody;
      items_[parent_->GetHandleOnDevice(device_name_, *handle)]->Ref();
    } else {
      *handle = parent_->AddHandle(key, device_name_, next_handle_);
      Item* item = new Item;
      item->func_graph = fbody;
      item->overlay_lib = options.overlay_lib;
      items_.insert({next_handle_, item});
      next_handle_++;
    }
  }
  return Status::OK();
}

Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
  if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
    return parent_->ReleaseHandle(handle);
  }

  LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
  mutex_lock l(mu_);
  CHECK_EQ(1, items_.count(h));
  Item* item = items_[h];
  if (item->Unref()) {
    items_.erase(h);
    TF_RETURN_IF_ERROR(parent_->RemoveHandle(handle));
  }
  return Status::OK();
}

void DumpGraph(StringPiece label, const Graph* g) {
  // TODO(zhifengc): Change Graph to record #nodes.
  VLOG(1) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges "
          << g->num_edges();
  if (VLOG_IS_ON(2)) {
    for (const auto& line : str_util::Split(DebugString(g), '\n')) {
      VLOG(2) << "|| " << line;
    }
  }
}

void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g) {
  OptimizerOptions opts;
  opts.set_do_common_subexpression_elimination(true);
  opts.set_do_function_inlining(true);
  opts.set_do_constant_folding(true);
  GraphOptimizer optimizer(opts);
  optimizer.Optimize(lib, lib->env(), lib->device(), g, /*shape_map=*/nullptr);
}

namespace {
// Removes all stateless nodes that do not contribute to a return
// value from the function body. Unlike `RemoveDeadNodes()`, which is
// triggered by `OptimizerOptions.do_function_inlining`, this pass
// ignores the SINK node, from which (by definition) all nodes are
// reverse reachable.
void PruneFunctionBody(Graph* g) {
  VLOG(2) << "Pruning function body";
  std::unordered_set<const Node*> nodes;
  for (auto n : g->nodes()) {
    // NOTE(mrry): "_Retval" nodes are stateful, and so will be added
    // to the seed set of `nodes`.
    // TODO(mrry): Investigate whether the `n->IsControlFlow()` test is
    // still needed. It would be preferable to prune entire loops and/or
    // conditionals if they are not used in the graph.
    if (n->IsControlFlow() || n->op_def().is_stateful()) {
      nodes.insert(n);
    }
  }
  bool changed = PruneForReverseReachability(g, std::move(nodes));
  if (changed) {
    FixupSourceAndSinkEdges(g);
  }
}
}  // namespace

Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
  const FunctionBody* fbody;
  const FunctionLibraryDefinition* lib_def;
  {
    mutex_lock l(mu_);
    fbody = (*item)->func_graph;
    lib_def = (*item)->overlay_lib;
  }
  if (!lib_def) {
    lib_def = base_lib_def_;
  }
  std::unique_ptr<Graph> g(new Graph(lib_def));
  CopyGraph(*fbody->graph, g.get());

  PruneFunctionBody(g.get());
  optimizer_.Optimize(this, env(), device(), &g, /*shape_map=*/nullptr);
  TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device()->device_type()),
                                       device()->name(), g.get()));

  // Creates an executor based on the graph g. This must be done without
  // holding mu_ because create_kernel_ calls back into the library.
  LocalExecutorParams params;
  params.device = device_;
  params.function_library = this;
  if (lib_def == base_lib_def_) {
    params.create_kernel = create_kernel_;
  } else {
    params.create_kernel = [this, lib_def](const NodeDef& ndef,
                                           OpKernel** kernel) {
      return CreateKernel(ndef, lib_def, kernel);
    };
  }
  params.delete_kernel = [](OpKernel* kernel) {
    DeleteNonCachedKernel(kernel);
  };
  Graph* graph = g.get();
  Executor* exec;
  TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(g), &exec));

  {
    // Guard the update since *item is already published in items_: another
    // thread may have created an executor for it concurrently.
    mutex_lock l(mu_);
    if ((*item)->exec) {
      delete exec;
    } else {
      (*item)->graph = graph;
      (*item)->exec = exec;
    }
  }
  return Status::OK();
}

Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
  {
    mutex_lock l(mu_);
    if (items_.count(local_handle) == 0) {
      return errors::NotFound("Function handle ", handle,
                              " is not valid. Likely an internal error.");
    }
    *item = items_[local_handle];
    if ((*item)->exec != nullptr) {
      return Status::OK();
    }
  }
  // NOTE: We need to call CreateItem outside of mu_ because creating an
  // executor needs to call CreateKernel.
  return CreateItem(handle, item);
}

void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
                                           gtl::ArraySlice<Tensor> args,
                                           std::vector<Tensor>* rets,
                                           Executor::Args* exec_args,
                                           Item* item, DoneCallback done) {
  DCHECK(exec_args->call_frame == nullptr);
  string target_device = parent_->GetDeviceName(handle);
  string source_device = opts.source_device;
  Rendezvous* rendezvous = opts.rendezvous;
  DeviceContext* device_context;
  Status s = parent_->GetDeviceContext(target_device, &device_context);
  if (!s.ok()) {
    delete exec_args;
    done(s);
    return;
  }
  int64 src_incarnation, target_incarnation;
  s = parent_->GetDeviceIncarnation(source_device, &src_incarnation);
  s.Update(parent_->GetDeviceIncarnation(target_device, &target_incarnation));
  if (!s.ok()) {
    delete exec_args;
    done(s);
    return;
  }

  const FunctionBody* fbody = GetFunctionBody(handle);
  FunctionCallFrame* frame =
      new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
  exec_args->call_frame = frame;
  if (!s.ok()) {
    delete frame;
    delete exec_args;
    done(s);
    return;
  }

  // The ProcFLR sends the arguments to the function from the source_device to
  // the target_device. So here we receive those arguments. Similarly, when the
  // computation is done and stored in *rets, we send the return values back
  // to the source_device (caller) so that the ProcFLR can receive them later.
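  // (The "arg_" and "ret_" key prefixes used below must agree with the
  // prefixes the ProcFLR uses on the other side of the rendezvous; mismatched
  // prefixes would leave the sends and receives unpaired.)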
  std::vector<Tensor>* remote_args = new std::vector<Tensor>;
  ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
      source_device, target_device, "arg_", src_incarnation, args.size(),
      device_context, {}, rendezvous, remote_args,
      [frame, remote_args, item, source_device, target_device,
       target_incarnation, rendezvous, device_context, rets, done,
       exec_args](const Status& status) {
        Status s = status;
        if (s.ok()) {
          s = frame->SetArgs(*remote_args);
        }
        if (!s.ok()) {
          delete frame;
          delete remote_args;
          delete exec_args;
          done(s);
          return;
        }
        item->exec->RunAsync(
            *exec_args, [item, frame, rets, done, source_device, target_device,
                         target_incarnation, rendezvous, device_context,
                         remote_args, exec_args](const Status& status) {
              Status s = status;
              if (s.ok()) {
                s = frame->ConsumeRetvals(rets);
              }
              delete frame;
              if (!s.ok()) {
                delete remote_args;
                delete exec_args;
                done(s);
                return;
              }
              s = ProcessFunctionLibraryRuntime::SendTensors(
                  target_device, source_device, "ret_", target_incarnation,
                  *rets, device_context, {}, rendezvous);
              delete remote_args;
              delete exec_args;
              done(s);
            });
      });
}

void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
                                     gtl::ArraySlice<Tensor> args,
                                     std::vector<Tensor>* rets,
                                     DoneCallback done) {
  if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
    done(errors::Cancelled(""));
    return;
  }
  Options run_opts = opts;
  if (opts.create_rendezvous) {
    Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
    run_opts.rendezvous = rendezvous;
    run_opts.create_rendezvous = false;
    done = [done, rendezvous](const Status& status) {
      rendezvous->Unref();
      done(status);
    };
  }
  if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
    parent_->Run(run_opts, handle, args, rets, done);
    return;
  }

  DCHECK(run_opts.runner != nullptr);

  Executor::Args* exec_args = new Executor::Args;
  // Inherit the step_id from the caller.
  exec_args->step_id = run_opts.step_id;
  exec_args->rendezvous = run_opts.rendezvous;
  exec_args->stats_collector = run_opts.stats_collector;
  exec_args->cancellation_manager = run_opts.cancellation_manager;
  exec_args->step_container = run_opts.step_container;
  exec_args->runner = *run_opts.runner;

  Item* item = nullptr;
  Status s = GetOrCreateItem(handle, &item);
  if (!s.ok()) {
    delete exec_args;
    done(s);
    return;
  }

  if (run_opts.remote_execution) {
    // NOTE(mrry): `RunRemote()` will set `exec_args->call_frame` for us.
    RunRemote(run_opts, handle, args, rets, exec_args, item, done);
    return;
  }

  const FunctionBody* fbody = GetFunctionBody(handle);
  FunctionCallFrame* frame =
      new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
  exec_args->call_frame = frame;
  s = frame->SetArgs(args);
  if (!s.ok()) {
    delete frame;
    delete exec_args;
    done(s);
    return;
  }

  item->exec->RunAsync(
      // Executor args
      *exec_args,
      // Done callback.
      [item, frame, rets, done, exec_args](const Status& status) {
        Status s = status;
        if (s.ok()) {
          s = frame->ConsumeRetvals(rets);
        }
        delete frame;
        delete exec_args;
        done(s);
      });
}

void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
                                     CallFrameInterface* frame,
                                     DoneCallback done) {
  if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
    done(errors::Cancelled(""));
    return;
  }
  if (!parent_->IsInstantiatedOnDevice(device_name_, handle) ||
      opts.remote_execution) {
    done(errors::Unimplemented("Remote calling with CallFrameInterface"));
    return;
  }

  Options run_opts = opts;
  if (opts.create_rendezvous) {
    Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
    run_opts.rendezvous = rendezvous;
    run_opts.create_rendezvous = false;
    done = std::bind(
        [rendezvous](DoneCallback done,
                     // Begin unbound arguments.
                     const Status& status) {
          rendezvous->Unref();
          done(status);
        },
        std::move(done), std::placeholders::_1);
  }

  Item* item = nullptr;
  Status s = GetOrCreateItem(handle, &item);
  if (!s.ok()) {
    done(s);
    return;
  }
  DCHECK(run_opts.runner != nullptr);

  Executor::Args* exec_args = new Executor::Args;
  // Inherit the step_id from the caller.
  exec_args->step_id = run_opts.step_id;
  exec_args->rendezvous = run_opts.rendezvous;
  exec_args->stats_collector = run_opts.stats_collector;
  exec_args->cancellation_manager = run_opts.cancellation_manager;
  exec_args->step_container = run_opts.step_container;
  exec_args->runner = *run_opts.runner;
  exec_args->call_frame = frame;

  item->exec->RunAsync(
      // Executor args
      *exec_args,
      // Done callback.
      std::bind(
          [item, frame, exec_args](DoneCallback done,
                                   // Start unbound arguments.
                                   const Status& status) {
            delete exec_args;
            done(status);
          },
          std::move(done), std::placeholders::_1));
}

bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) {
  const OpDef* op_def;
  const Status s = base_lib_def_->LookUpOpDef(func, &op_def);
  return s.ok() && op_def->is_stateful();
}

string FunctionLibraryRuntimeImpl::DebugString(Handle handle) {
  Item* item = nullptr;
  Status s = GetOrCreateItem(handle, &item);
  if (s.ok()) {
    return tensorflow::DebugString(item->graph);
  } else {
    return s.ToString();
  }
}

Status FunctionLibraryRuntimeImpl::Clone(
    std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
    FunctionLibraryRuntime** out_flr) {
  TF_RETURN_IF_ERROR(
      parent_->Clone(env_, graph_def_version_, optimizer_.options(),
                     custom_kernel_creator_, out_lib_def, out_pflr));
  *out_flr = (*out_pflr)->GetFLR(device_->name());
  if (*out_flr != nullptr) {
    return Status::OK();
  } else {
    return errors::Internal("Cloning FunctionLibraryRuntime failed.");
  }
}

namespace {

struct CustomCreatorSingleton {
  mutex mu;
  CustomKernelCreator custom_creator = nullptr;

  void Set(CustomKernelCreator cb) {
    mutex_lock l(mu);
    custom_creator = std::move(cb);
  }

  CustomKernelCreator Get() {
    mutex_lock l(mu);
    return custom_creator;
  }
};

CustomCreatorSingleton* GetCustomCreatorSingleton() {
  static CustomCreatorSingleton* ccs = new CustomCreatorSingleton;
  return ccs;
}

}  // namespace

void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb) {
  GetCustomCreatorSingleton()->Set(std::move(cb));
}

std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
    const DeviceMgr* device_mgr, Env* env, Device* device,
    int graph_def_version, const FunctionLibraryDefinition* lib_def,
    const OptimizerOptions& optimizer_options,
    CustomKernelCreator custom_kernel_creator,
    ProcessFunctionLibraryRuntime* parent) {
  return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl(
      device_mgr, env, device, graph_def_version, lib_def, optimizer_options,
      std::move(custom_kernel_creator), parent));
}

std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
    const DeviceMgr* device_mgr, Env* env, Device* device,
    int graph_def_version, const FunctionLibraryDefinition* lib_def,
    const OptimizerOptions& optimizer_options,
    ProcessFunctionLibraryRuntime* parent) {
  return NewFunctionLibraryRuntime(device_mgr, env, device, graph_def_version,
                                   lib_def, optimizer_options,
                                   GetCustomCreatorSingleton()->Get(), parent);
}

bool RemoveDeadNodes(Graph* g) {
  VLOG(2) << "Removing dead nodes";
  std::unordered_set<const Node*> nodes;
  for (auto n : g->nodes()) {
    if (n->IsSource() || n->IsSink() || n->IsControlFlow() ||
        n->op_def().is_stateful()) {
      nodes.insert(n);
    }
  }
  return PruneForReverseReachability(g, std::move(nodes));
}

namespace {
// If 'edges' contains only one non-control edge, returns it. Otherwise,
// returns nullptr.
const Edge* GetTheOnlyDataEdge(const EdgeSet& edges) {
  const Edge* ret = nullptr;
  for (const Edge* e : edges) {
    if (e->IsControlEdge() || ret) {
      // Don't touch it if there is a control edge.
      return nullptr;
    }
    if (IsRefType(e->src()->output_type(e->src_output()))) {
      // Don't touch it if the identity node is effectively de-reffing
      // a ref.
      return nullptr;
    }
    if (IsRecv(e->src()) || IsSwitch(e->src())) {
      // Don't touch it if the identity is introduced for control flow.
      // Recv disables all its successors if it receives a dead signal.
      // When Recv has an outgoing control edge, the current executor
      // would not disable the destination. The current solution (see
      // graph_partition.cc) is to add an identity after Recv and change
      // the control edge to be from this identity node. So the identity
      // can't be removed.
      return nullptr;
    }
    ret = e;
  }
  return ret;
}
}  // end namespace

bool RemoveIdentityNodes(Graph* g) {
  VLOG(2) << "Removing identity nodes";
  bool removed_any = false;
  gtl::InlinedVector<Node*, 8> matches;
  for (Node* n : g->nodes()) {
    if (!n->IsIdentity()) continue;
    if (!GetTheOnlyDataEdge(n->in_edges())) continue;

    // Some identity nodes are used as sink nodes to give names to output
    // tensors. These nodes are not going to be executed unless they are in the
    // fetch set. But if they are in the fetch set we don't want to remove them.
    if (n->out_edges().empty()) continue;

    matches.push_back(n);
  }
  if (!matches.empty()) {
    for (Node* n : matches) {
      const Edge* in = GetTheOnlyDataEdge(n->in_edges());
      for (const Edge* out : n->out_edges()) {
        if (out->IsControlEdge()) {
          g->AddControlEdge(in->src(), out->dst());
        } else {
          g->AddEdge(in->src(), in->src_output(), out->dst(), out->dst_input());
        }
      }
      VLOG(2) << "Remove Identity: " << n->DebugString();
      g->RemoveNode(n);
      removed_any = true;
    }
  }
  return removed_any;
}

bool RemoveListArrayConverter(Graph* g) {
  VLOG(2) << "Removing list array converter";
  gtl::InlinedVector<Node*, 8> matches;
  for (Node* n : g->nodes()) {
    if ((n->type_string() == "_ListToArray") ||
        (n->type_string() == "_ArrayToList")) {
      matches.push_back(n);
    }
  }
  bool removed_any = false;
  if (!matches.empty()) {
    for (Node* n : matches) {
      if (n->num_inputs() != n->num_outputs()) {
        continue;  // Not expected. Skip.
      }
      gtl::InlinedVector<Node*, 8> identity_nodes(n->num_inputs(), nullptr);

      // Process input edges first.
      Node* input_control_node = nullptr;
      for (const Edge* e : n->in_edges()) {
        if (e->IsControlEdge()) {
          if (input_control_node == nullptr) {
            // If node "n" has any control dependencies, add a no-op node
            // (input_control_node) on which the additional Identity nodes
            // depend, and which in turn depends on node "n"'s control
            // dependencies.
            input_control_node = AddNoOp(g);
          }
          g->AddControlEdge(e->src(), input_control_node);
        } else {
          const int index = e->dst_input();
          Node** id_node = &identity_nodes[index];
          if (*id_node != nullptr) {
            LOG(ERROR)
                << "RemoveListArrayConverter unexpected duplicated input: "
                << e->dst_input();
            return removed_any;
          }
          *id_node = AddIdentity(g, {e->src(), e->src_output()});
        }
      }

      // If node "n" has any control dependencies, the added identity
      // nodes should have control dependencies on input_control_node.
      if (input_control_node != nullptr) {
        for (Node* id : identity_nodes) {
          g->AddControlEdge(input_control_node, id);
        }
      }

      Node* output_control_node = nullptr;
      for (const Edge* e : n->out_edges()) {
        if (e->IsControlEdge()) {
          if (output_control_node == nullptr) {
            // If node "n" is control-depended upon by other nodes,
            // adds a no-op node (output_control_node) which those
            // nodes will depend on and output_control_node depends on
            // all Identity nodes.
            output_control_node = AddNoOp(g);
          }
          g->AddControlEdge(output_control_node, e->dst());
        } else {
          Node* id_node = identity_nodes[e->src_output()];
          if (id_node == nullptr) {
            LOG(ERROR) << "RemoveListArrayConverter unexpected missing input: "
                       << e->src_output();
            return removed_any;
          }
          CHECK(id_node);
          g->AddEdge(id_node, 0, e->dst(), e->dst_input());
        }
      }

      // If any nodes have control dependencies on node "n", those
      // nodes should have control dependencies on
      // output_control_node.
      if (output_control_node != nullptr) {
        for (Node* id : identity_nodes) {
          g->AddControlEdge(id, output_control_node);
        }
      }

      g->RemoveNode(n);
      removed_any = true;
    }
  }
  return removed_any;
}

// Returns true iff the function '*fbody' can be inlined at 'node'
// based on the type signature of 'node' and 'fbody'.
static bool ValidateInlining(const Node* node, const FunctionBody* fbody) {
  if (static_cast<size_t>(node->num_inputs()) != fbody->arg_types.size()) {
    return false;
  }
  if (static_cast<size_t>(node->num_inputs()) != fbody->arg_nodes.size()) {
    return false;
  }
  if (static_cast<size_t>(node->num_outputs()) != fbody->ret_types.size()) {
    return false;
  }
  if (static_cast<size_t>(node->num_outputs()) != fbody->ret_nodes.size()) {
    return false;
  }
  for (int i = 0; i < node->num_inputs(); ++i) {
    if (node->input_type(i) != fbody->arg_types[i]) return false;
  }
  for (int i = 0; i < node->num_outputs(); ++i) {
    if (node->output_type(i) != fbody->ret_types[i]) return false;
  }
  return true;
}

// Given a node "caller" in the graph "g" that is a call to the function
// described by "fbody", replaces "caller" with the body of the function
// (fbody->graph) and connects the edges properly.
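//
// Schematically (an illustrative sketch, not a transcript of a real graph),
// a call node
//
//   u --> caller --> v
//
// is replaced by the function body's nodes, copied under the "<caller name>/"
// prefix, with an Identity node spliced in per argument and per return value,
// and NoOp nodes anchoring any incoming/outgoing control edges:
//
//   u --> Identity --> <inlined body nodes> --> Identity --> v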
void InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
                        Node* caller, const FunctionBody* fbody) {
  if (!ValidateInlining(caller, fbody)) {
    LOG(WARNING) << "Inlining mismatch: " << caller->DebugString() << " vs. "
                 << DebugString(fbody->graph);
    return;
  }

  // Input edges. For data edges coming into "caller", we first compute the
  // <src>:<src_output> for the i-th input in "inputs".
  // If "caller" has any input control dependencies, we add a NoOp
  // node "input_control_node", which depends on "caller"'s control inputs.
  std::vector<Endpoint> inputs(caller->num_inputs());
  Node* input_control_node = nullptr;
  for (const Edge* e : caller->in_edges()) {
    if (e->IsControlEdge()) {
      if (input_control_node == nullptr) {
        input_control_node = AddNoOp(g);
      }
      g->AddControlEdge(e->src(), input_control_node);
    } else {
      inputs[e->dst_input()] = {e->src(), e->src_output()};
    }
  }

  // Duplicate fbody->graph into 'g'. First, we copy the nodes of
  // fbody->graph into 'g' except the source and sink nodes. We copy
  // edges among nodes in 'fbody->graph'.
  //
  // If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we
  // remember 'y' in node_map[x->id()].
  std::vector<Node*> node_map(fbody->graph->num_node_ids());
  Status s;
  for (Node* n : fbody->graph->op_nodes()) {
    NodeDef ndef = n->def();
    ndef.set_name(strings::StrCat(caller->name(), "/", ndef.name()));
    ndef.set_device(caller->def().device());
    Node* clone = g->AddNode(ndef, &s);
    TF_CHECK_OK(s);
    node_map[n->id()] = clone;

    // If there is an input control node, and one of:
    // a) the node has no data or control inputs, or
    // b) the node is a function call or SymbolicGradient,
    // then add a control edge from the input control node to the clone.
    //
    // We must not execute any nodes if the original function call would not
    // have executed. This is especially critical when the function call is
    // inside a control-flow construct like tf.cond(). Case (a) ensures that
    // such nodes do not run.
    //
    // The purpose of case (b) is to ensure that instances of case (a) created
    // by further inlining steps also receive the control dependency.
    if (input_control_node) {
      bool has_inputs = false;
      for (const Edge* e : n->in_edges()) {
        if (!e->src()->IsSource()) {
          has_inputs = true;
          break;
        }
      }
      if (!has_inputs || flib_def.Find(clone->type_string()) != nullptr ||
          clone->type_string() == "SymbolicGradient") {
        g->AddControlEdge(input_control_node, clone);
      }
    }
  }
  for (const Edge* e : fbody->graph->edges()) {
    if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() ||
        e->dst()->IsSink()) {
      continue;
    }
    Node* src_copy = node_map[e->src()->id()];
    Node* dst_copy = node_map[e->dst()->id()];
    g->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
  }

  // Connect input edges.
  //
  // We create one Identity node for each input. Then, we connect inputs[i] to
  // the i-th identity node added. The nodes that previously connected
  // to the j-th output of i-th arg node are reconnected to the i-th
  // identity node.
  //
  // The added identity nodes depend on "input_control_node".
  for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) {
    Node* arg = node_map[fbody->arg_nodes[i]->id()];
    Node* n = AddIdentity(g, inputs[i]);
    if (input_control_node) {
      g->AddControlEdge(input_control_node, n);
    }
    for (const Edge* e : arg->out_edges()) {
      if (e->IsControlEdge()) {
        g->AddControlEdge(n, e->dst());
      } else {
        g->AddEdge(n, 0, e->dst(), e->dst_input());
      }
    }
    node_map[fbody->arg_nodes[i]->id()] = n;
    g->RemoveNode(arg);  // 'arg' is disconnected.
  }

  // Connect output edges.
  //
  // For the i-th return node in fbody->graph, we add in "g" an identity
  // node (outputs[i]). We then reconnect every incoming edge into
  // the i-th return node to the added identity node.
  //
  // For every data edge coming out of "caller"'s i-th output, we
  // reconnect it to the i-th identity added above.
  //
  // If "caller" is control-depended upon by any other nodes, we add a
  // NoOp node "output_control_node". "output_control_node" depends on
  // all identity nodes added above, and nodes that previously depended
  // on "caller" are changed to depend on "output_control_node".
  std::vector<Node*> outputs(caller->num_outputs());
  for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) {
    Node* ret = node_map[fbody->ret_nodes[i]->id()];
    Endpoint data;  // Data input for the ret node.
    for (const Edge* e : ret->in_edges()) {
      if (!e->IsControlEdge()) {
        data = {e->src(), e->src_output()};
        break;
      }
    }
    CHECK(data.node != nullptr);
    Node* n = AddIdentity(g, data);
    outputs[i] = n;
    for (const Edge* e : ret->in_edges()) {
      if (e->IsControlEdge()) {
        g->AddControlEdge(e->src(), n);
      }
    }
    g->RemoveNode(ret);  // 'ret' is disconnected.
  }
  Node* output_control_node = nullptr;
  for (const Edge* e : caller->out_edges()) {
    if (e->IsControlEdge()) {
      if (output_control_node == nullptr) {
        output_control_node = AddNoOp(g);
        for (Node* n : outputs) {
          g->AddControlEdge(n, output_control_node);
        }
      }
      g->AddControlEdge(output_control_node, e->dst());
    } else {
      g->AddEdge(outputs[e->src_output()], 0, e->dst(), e->dst_input());
    }
  }
  g->RemoveNode(caller);  // 'caller' is replaced with inlined nodes.
}

bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
  std::vector<std::pair<Node*, const FunctionBody*>> candidates;
  const FunctionLibraryDefinition* fld = lib->GetFunctionLibraryDefinition();
  for (Node* node : graph->nodes()) {
    VLOG(3) << "Expanding " << node->DebugString();
    bool noinline;
    if (fld->GetAttr(*node, kNoInlineAttr, &noinline).ok() && noinline) {
      VLOG(3) << "noinline: " << node->DebugString();
      continue;
    }
    FunctionLibraryRuntime::Handle handle;
    Status s = lib->Instantiate(node->type_string(), node->attrs(), &handle);
    if (!s.ok()) {
      // Either "node" is a primitive op, or the instantiation failed.
      if (errors::IsNotFound(s)) {
        VLOG(3) << "ExpandInlineFunctions " << s;
      } else {
        LOG(ERROR) << "ExpandInlineFunctions " << s;
      }
      continue;
    }
    const FunctionBody* fbody = lib->GetFunctionBody(handle);
    CHECK_NOTNULL(fbody);
    candidates.push_back({node, fbody});
  }
  for (const auto& p : candidates) {
    InlineFunctionBody(*fld, graph, p.first, p.second);
  }
  return !candidates.empty();
}

string NewName(const Node* n, bool pretty) {
  if (pretty) {
    return strings::StrCat(n->type_string(), n->id());
  } else {
    return strings::StrCat("n", n->id());
  }
}

// TODO(zhifengc): Maybe this should be the default Graph::AsGraphDef,
// and stash the original NodeDef name as an attr for documentation
// purposes.
void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) {
  // We visit nodes in forward topological sort order, which is a
  // possible execution order of the graph.
  gtl::InlinedVector<const Edge*, 4> inputs;
  gdef->Clear();
  gdef->mutable_versions()->CopyFrom(g->versions());

  std::vector<Node*> start_nodes;
  for (Node* n : g->nodes()) {
    if (n->out_edges().empty()) {
      start_nodes.push_back(n);
    }
  }

  ReverseDFSFrom(*g, start_nodes, nullptr, [gdef, pretty, &inputs](Node* n) {
    if (!n->IsOp()) return;
    NodeDef* ndef = gdef->add_node();
    ndef->set_name(NewName(n, pretty));
    ndef->set_op(n->type_string());
    for (const auto& attr : n->attrs()) {
      (*ndef->mutable_attr())[attr.first] = attr.second;
    }
    inputs.clear();
    inputs.resize(n->num_inputs());
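    // Data edges fill the first num_inputs() slots, indexed by destination
    // input; control edges are appended after them, matching the NodeDef
    // convention that "^name" control inputs follow all data inputs.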
    for (const Edge* e : n->in_edges()) {
      if (e->IsControlEdge()) {
        inputs.push_back(e);
      } else {
        if (inputs[e->dst_input()] == nullptr) {
          inputs[e->dst_input()] = e;
        } else {
          LOG(WARNING) << "Malformed graph node. multiple input edges: "
                       << n->DebugString();
        }
      }
    }
    // node->name() is merely NodeDef::name, which is not guaranteed
    // to be unique or stable after optimization rewrites. Therefore,
    // we use "n<node id>" instead.
    for (const Edge* e : inputs) {
      if (e == nullptr) {
        ndef->add_input("unknown");
        continue;
      }
      const string srcname = NewName(e->src(), pretty);
      if (!e->src()->IsOp()) {
      } else if (e->IsControlEdge()) {
        ndef->add_input(strings::StrCat("^", srcname));
      } else if (e->src_output() == 0) {
        ndef->add_input(srcname);
      } else {
        ndef->add_input(strings::StrCat(srcname, ":", e->src_output()));
      }
    }
  });
}

string DebugString(const Graph* g) {
  GraphDef gdef;
  ToGraphDef(g, &gdef);
  return DebugString(gdef);
}

FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t,
                           DataTypeSlice ret_t, Graph* g)
    : fdef(f),
      graph(g),
      arg_types(arg_t.begin(), arg_t.end()),
      ret_types(ret_t.begin(), ret_t.end()) {
  this->arg_nodes.resize(arg_types.size());
  this->ret_nodes.resize(ret_types.size());
  for (Node* n : this->graph->op_nodes()) {
    gtl::InlinedVector<Node*, 4>* node_vec;
    if (n->type_string() == kRetOp) {
      node_vec = &this->ret_nodes;
    } else if (n->type_string() == kArgOp) {
      node_vec = &this->arg_nodes;
    } else {
      continue;
    }
    int index;
    TF_CHECK_OK(GetNodeAttr(n->attrs(), "index", &index));
    CHECK_LE(0, index);
    CHECK_LT(index, node_vec->size());
    (*node_vec)[index] = n;
  }
}

FunctionBody::~FunctionBody() { delete this->graph; }

class SymbolicGradientHelper {
 public:
  explicit SymbolicGradientHelper(const FunctionBody& f) : fbody_(&f) {}

  ~SymbolicGradientHelper() { delete gbody_; }

  FunctionBody* Compute();

 private:
  const FunctionBody* fbody_;
  FunctionBody* gbody_ = nullptr;

  // Makes a copy of fbody_ in gbody_.
  void Copy();

  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientHelper);
};

void SymbolicGradientHelper::Copy() {
  const Graph& src = *(fbody_->graph);
  gbody_->graph = new Graph(src.op_registry());
  Graph* dst = gbody_->graph;

  std::vector<Node*> node_map(src.num_node_ids());

  // Copy the nodes.
  node_map[src.source_node()->id()] = dst->source_node();
  node_map[src.sink_node()->id()] = dst->sink_node();
  for (Node* n : src.op_nodes()) {
    node_map[n->id()] = dst->CopyNode(n);
  }

  // Copy the edges.
  for (const Edge* e : src.edges()) {
    Node* src_copy = node_map[e->src()->id()];
    Node* dst_copy = node_map[e->dst()->id()];
    dst->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
  }

  // Save inputs in copied graph.
  CHECK_EQ(fbody_->arg_types.size(), fbody_->arg_nodes.size());
  gbody_->arg_types = fbody_->arg_types;
  for (std::size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
    gbody_->arg_nodes.push_back(node_map[fbody_->arg_nodes[i]->id()]);
  }

  // Save outputs in copied graph.
  CHECK_EQ(fbody_->ret_types.size(), fbody_->ret_nodes.size());
  gbody_->ret_types = fbody_->ret_types;
  for (std::size_t i = 0; i < fbody_->ret_nodes.size(); ++i) {
    gbody_->ret_nodes.push_back(node_map[fbody_->ret_nodes[i]->id()]);
  }
}

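// Builds the gradient function's body. Given the original function
// f(x1, ..., xN) -> (y1, ..., yM), the result computes
// g(x1, ..., xN, dy1, ..., dyM) -> (dx1, ..., dxN): the dy values arrive as
// extra "_Arg" nodes appended below, and the return values are the gradients
// with respect to each original argument.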
FunctionBody* SymbolicGradientHelper::Compute() {
  CHECK(gbody_ == nullptr);
  gbody_ = new FunctionBody;

  // Copy fbody_ into gbody_.
  Copy();

  Graph* g = gbody_->graph;

  const int num_y = static_cast<int>(gbody_->ret_nodes.size());

  // Populate 'y_node_outputs_' with node function body outputs.
  // Populate 'y_grad_nodes' with initial gradient nodes for each return node of
  // the original function body (these will be 'arg' nodes in the function
  // gradient body).
  std::vector<NodeOut> y_node_outputs;
  y_node_outputs.reserve(num_y);
  std::vector<NodeOut> y_grad_node_outputs;
  y_grad_node_outputs.reserve(num_y);
  for (int i = 0; i < num_y; ++i) {
    Node* y = gbody_->ret_nodes[i];
    y_node_outputs.push_back({y, 0});
    DCHECK_EQ(y->type_string(), kRetOp);
    const DataType dtype = y->input_type(0);
    const int index = static_cast<int>(gbody_->arg_nodes.size());
    Node* dy = AddArg(g, dtype, index);
    gbody_->arg_types.push_back(dtype);
    gbody_->arg_nodes.push_back(dy);
    y_grad_node_outputs.push_back({dy, 0});
  }

  // Populate 'x_nodes' with function args (excluding 'y_grad_node_outputs').
  const size_t num_x = fbody_->arg_nodes.size();
  std::vector<NodeOut> x_node_outputs;
  x_node_outputs.reserve(num_x);
  for (size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
    x_node_outputs.push_back({gbody_->arg_nodes[i], 0});
  }

  // Call AddSymbolicGradients which will add nodes to graph 'g' that
  // compute the function gradient (adding an entry in 'x_grad_node_outputs' for
  // each node in 'x_node_outputs').
  std::vector<NodeOut> x_grad_node_outputs;
  TF_CHECK_OK(AddSymbolicGradients(y_node_outputs, x_node_outputs,
                                   y_grad_node_outputs, &x_grad_node_outputs,
                                   g));

  // Remove the old return nodes from the function body.
  for (Node* n : gbody_->ret_nodes) {
    g->RemoveNode(n);
  }
  gbody_->ret_types = fbody_->arg_types;
  gbody_->ret_nodes.clear();
  // Add new return nodes to the function gradient body for each node
  // in 'x_grad_nodes'.
  const int arg_types_size = static_cast<int>(fbody_->arg_types.size());
  for (int i = 0; i < arg_types_size; ++i) {
    Endpoint grad = {x_grad_node_outputs[i].node, x_grad_node_outputs[i].index};
    Node* ret = AddRet(g, grad, i);
    gbody_->ret_nodes.push_back(ret);
  }

  auto ret = gbody_;
  gbody_ = nullptr;
  return ret;
}

FunctionBody* SymbolicGradient(const FunctionBody& f) {
  return SymbolicGradientHelper(f).Compute();
}

Status FunctionDefToBodyHelper(
    const FunctionDef& fdef, const AttrSlice& attrs,
    const FunctionLibraryDefinition* const lib_def,
    const std::function<Status(const string&, const OpDef**)>& get_func_sig,
    FunctionBody** fbody) {
  // Instantiates the function template into a graph def.
  InstantiationResult result;
  TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result));

  std::unique_ptr<Graph> graph(new Graph(lib_def));
  GraphConstructorOptions opts;
  opts.allow_internal_ops = true;
  opts.expect_device_spec = false;
  TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get()));

  // Call BuildControlFlowInfo to validate that this function body has
  // well-formed control flow.
  // NOTE(skyewm): this is usually done in Partition(), but we don't partition
  // function bodies. This should be removed if function bodies ever go through
  // the Partition() path.
  std::vector<ControlFlowInfo> dummy;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy));

  *fbody = new FunctionBody(fdef, result.arg_types, result.ret_types,
                            graph.release());
  return Status::OK();
}

}  // end namespace tensorflow