• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/function.h"
17 
18 #include <deque>
19 #include <vector>
20 
21 #include "tensorflow/core/common_runtime/device.h"
22 #include "tensorflow/core/common_runtime/executor.h"
23 #include "tensorflow/core/common_runtime/graph_optimizer.h"
24 #include "tensorflow/core/common_runtime/memory_types.h"
25 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/node_def_util.h"
29 #include "tensorflow/core/framework/op.h"
30 #include "tensorflow/core/framework/op_kernel.h"
31 #include "tensorflow/core/framework/versions.pb.h"
32 #include "tensorflow/core/graph/algorithm.h"
33 #include "tensorflow/core/graph/control_flow.h"
34 #include "tensorflow/core/graph/gradients.h"
35 #include "tensorflow/core/graph/graph_constructor.h"
36 #include "tensorflow/core/graph/optimizer_cse.h"
37 #include "tensorflow/core/lib/gtl/map_util.h"
38 #include "tensorflow/core/platform/macros.h"
39 
40 // See core/kernels/function_ops.cc for related kernels.
41 
42 namespace tensorflow {
43 
44 // A few string constant used throughout this module.
45 //
46 // TODO(zhifengc): Dedup some of these constants into
47 // framework/function.h
48 static constexpr const char* const kArgOp = "_Arg";
49 static constexpr const char* const kRetOp = "_Retval";
50 static constexpr const char* const kGradientOp =
51     FunctionLibraryDefinition::kGradientOp;
52 static constexpr const char* const kNodeLabel = "Func";
53 static constexpr const char* const kFuncAttr =
54     FunctionLibraryDefinition::kFuncAttr;
55 
56 // Represents the index-th output of a node.
57 struct Endpoint {
58   Node* node;
59   int index;
60 
61   // Returns the string name represents this endpoint.
nametensorflow::Endpoint62   string name() const {
63     if (index == 0) {
64       return node->name();
65     } else {
66       return strings::StrCat(node->name(), ":", index);
67     }
68   }
69 
dtypetensorflow::Endpoint70   DataType dtype() const { return node->output_type(index); }
71 };
72 
73 struct EndpointHash {
operator ()tensorflow::EndpointHash74   uint64 operator()(const Endpoint& x) const {
75     return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
76                   x.index);
77   }
78 };
79 
80 struct EndpointEq {
operator ()tensorflow::EndpointEq81   bool operator()(const Endpoint& x, const Endpoint& y) const {
82     return (x.node == y.node) && (x.index == y.index);
83   }
84 };
85 
86 // The following Add* routines are used to add a few graph nodes while
87 // functions are transformed.
AddNoOp(Graph * g)88 static Node* AddNoOp(Graph* g) {
89   NodeDef ndef;
90   ndef.set_name(g->NewName(kNodeLabel));
91   ndef.set_op("NoOp");
92   Status s;
93   Node* ret = g->AddNode(ndef, &s);
94   TF_CHECK_OK(s);
95   return ret;
96 }
97 
AddIdentity(Graph * g,Endpoint input)98 static Node* AddIdentity(Graph* g, Endpoint input) {
99   DCHECK_LT(0, input.dtype());
100   NodeDef ndef;
101   ndef.set_name(g->NewName(kNodeLabel));
102   ndef.set_op("Identity");
103   ndef.add_input(input.name());
104   AddNodeAttr("T", BaseType(input.dtype()), &ndef);
105   Status s;
106   Node* ret = g->AddNode(ndef, &s);
107   TF_CHECK_OK(s);
108   g->AddEdge(input.node, input.index, ret, 0);
109   return ret;
110 }
111 
AddArg(Graph * g,DataType dtype,int index)112 static Node* AddArg(Graph* g, DataType dtype, int index) {
113   DCHECK_LT(0, dtype);
114   DCHECK_LT(dtype, DT_FLOAT_REF);
115   NodeDef ndef;
116   ndef.set_name(g->NewName(kNodeLabel));
117   ndef.set_op(kArgOp);
118   AddNodeAttr("T", dtype, &ndef);
119   AddNodeAttr("index", index, &ndef);
120   Status s;
121   Node* ret = g->AddNode(ndef, &s);
122   TF_CHECK_OK(s);
123   return ret;
124 }
125 
AddRet(Graph * g,Endpoint input,int index)126 static Node* AddRet(Graph* g, Endpoint input, int index) {
127   DCHECK_LT(0, input.dtype());
128   DCHECK_LT(input.dtype(), DT_FLOAT_REF);
129   NodeDef ndef;
130   ndef.set_name(g->NewName(kNodeLabel));
131   ndef.set_op(kRetOp);
132   ndef.add_input(input.name());
133   AddNodeAttr("T", input.dtype(), &ndef);
134   AddNodeAttr("index", index, &ndef);
135   Status s;
136   Node* ret = g->AddNode(ndef, &s);
137   TF_CHECK_OK(s);
138   g->AddEdge(input.node, input.index, ret, 0);
139   return ret;
140 }
141 
142 class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
143  public:
144   FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env, Device* device,
145                              int graph_def_version,
146                              const FunctionLibraryDefinition* lib_def,
147                              const OptimizerOptions& optimizer_options,
148                              CustomKernelCreator custom_kernel_creator,
149                              ProcessFunctionLibraryRuntime* parent);
150 
151   ~FunctionLibraryRuntimeImpl() override;
152 
153   Status Instantiate(const string& function_name, AttrSlice attrs,
154                      const InstantiateOptions& options,
155                      Handle* handle) override;
156 
157   Status ReleaseHandle(Handle handle) override;
158 
159   const FunctionBody* GetFunctionBody(Handle handle) override;
160 
161   Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override;
162 
163   void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
164            std::vector<Tensor>* rets, DoneCallback done) override;
165   // NOTE(mrry): This overload is currently only implemented for local function
166   // execution.
167   // TODO(b/70346412): Implement support for remote function execution when
168   // passing a call frame.
169   void Run(const Options& opts, Handle handle, CallFrameInterface* frame,
170            DoneCallback done) override;
171 
172   bool IsStateful(const string& function) override;
173 
GetFunctionLibraryDefinition() const174   const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
175       const override {
176     return base_lib_def_;
177   }
178 
device()179   Device* device() override { return device_; }
env()180   Env* env() override { return env_; }
graph_def_version()181   int graph_def_version() override { return graph_def_version_; }
182 
183   string DebugString(Handle h) override;
184 
185   Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
186                std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
187                FunctionLibraryRuntime** out_flr) override;
188 
189  private:
190   typedef FunctionLibraryRuntimeImpl ME;
191 
192   const DeviceMgr* const device_mgr_;
193   Device* const device_;
194   Env* const env_;
195   const int graph_def_version_;
196   const FunctionLibraryDefinition* const base_lib_def_;
197   GraphOptimizer optimizer_;
198   const CustomKernelCreator custom_kernel_creator_;
199   const string device_name_;
200 
201   std::function<Status(const string&, const OpDef**)> get_func_sig_;
202   std::function<Status(const NodeDef&, OpKernel**)> create_kernel_;
203 
204   mutable mutex mu_;
205 
206   int next_handle_ GUARDED_BY(mu_);
207 
208   // The instantiated and transformed function is encoded as a Graph
209   // object, and an executor is created for the graph.
210   struct Item : public core::RefCounted {
211     const Graph* graph = nullptr;                            // Owned by exec.
212     const FunctionLibraryDefinition* overlay_lib = nullptr;  // Not owned.
213     FunctionBody* func_graph = nullptr;
214     Executor* exec = nullptr;
215 
~Itemtensorflow::FunctionLibraryRuntimeImpl::Item216     ~Item() override {
217       delete this->func_graph;
218       delete this->exec;
219     }
220   };
221   std::unordered_map<Handle, Item*> items_ GUARDED_BY(mu_);
222 
223   ProcessFunctionLibraryRuntime* parent_ = nullptr;  // not owned.
224 
225   Status CreateKernel(const NodeDef& ndef,
226                       const FunctionLibraryDefinition* lib_def,
227                       OpKernel** kernel);
228   Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
229                            const FunctionLibraryDefinition* lib_def,
230                            FunctionBody** fbody);
231   Status CreateItem(Handle handle, Item** item);
232   Status GetOrCreateItem(Handle handle, Item** item);
233   Status InstantiateSymbolicGradient(const NameAttrList& func,
234                                      const FunctionLibraryDefinition* lib_def,
235                                      FunctionBody** g_body);
236   bool IsLocalTarget(const InstantiateOptions& options);
237   AttrValueMap FixAttrs(const AttrSlice& attrs);
238   void RunRemote(const Options& opts, Handle handle,
239                  gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
240                  Executor::Args* exec_args, Item* item, DoneCallback done);
241 
242   TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
243 };
244 
FunctionLibraryRuntimeImpl(const DeviceMgr * dmgr,Env * env,Device * device,int graph_def_version,const FunctionLibraryDefinition * lib_def,const OptimizerOptions & optimizer_options,CustomKernelCreator custom_kernel_creator,ProcessFunctionLibraryRuntime * parent)245 FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
246     const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version,
247     const FunctionLibraryDefinition* lib_def,
248     const OptimizerOptions& optimizer_options,
249     CustomKernelCreator custom_kernel_creator,
250     ProcessFunctionLibraryRuntime* parent)
251     : device_mgr_(dmgr),
252       device_(device),
253       env_(env),
254       graph_def_version_(graph_def_version),
255       base_lib_def_(lib_def),
256       optimizer_(optimizer_options),
257       custom_kernel_creator_(std::move(custom_kernel_creator)),
258       device_name_(device_ == nullptr
259                        ? ProcessFunctionLibraryRuntime::kDefaultFLRDevice
260                        : device_->name()),
261       next_handle_(0),
262       parent_(parent) {
263   get_func_sig_ = [this](const string& op, const OpDef** sig) {
264     return base_lib_def_->LookUpOpDef(op, sig);
265   };
266   create_kernel_ = [this](const NodeDef& ndef, OpKernel** kernel) {
267     return CreateKernel(ndef, kernel);
268   };
269 }
270 
~FunctionLibraryRuntimeImpl()271 FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
272   // The most common patterns of FLR usage don't require the caller to
273   // explicitly release handles. As a result, we try to unref each item until
274   // it's erased.
275   for (auto item : items_) {
276     if (item.second) {
277       while (!item.second->Unref()) {
278       }
279     }
280   }
281 }
282 
283 // An asynchronous op kernel which executes an instantiated function
284 // defined in a library.
285 class CallOp : public AsyncOpKernel {
286  public:
CallOp(FunctionLibraryRuntime::Handle handle,OpKernelConstruction * ctx)287   CallOp(FunctionLibraryRuntime::Handle handle, OpKernelConstruction* ctx)
288       : AsyncOpKernel(ctx), handle_(handle) {}
289 
~CallOp()290   ~CallOp() override {}
291 
ComputeAsync(OpKernelContext * ctx,DoneCallback done)292   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
293     FunctionLibraryRuntime* lib = ctx->function_library();
294     OP_REQUIRES_ASYNC(ctx, lib != nullptr,
295                       errors::Internal("No function library is provided."),
296                       done);
297     FunctionLibraryRuntime::Options opts;
298     opts.step_id = ctx->step_id();
299     opts.rendezvous = ctx->rendezvous();
300     opts.cancellation_manager = ctx->cancellation_manager();
301     opts.step_container = ctx->step_container();
302     opts.stats_collector = ctx->stats_collector();
303     opts.runner = ctx->runner();
304     std::vector<Tensor> args;
305     args.reserve(ctx->num_inputs());
306     for (int i = 0; i < ctx->num_inputs(); ++i) {
307       args.push_back(ctx->input(i));
308     }
309     std::vector<Tensor>* rets = new std::vector<Tensor>;
310     lib->Run(opts, handle_, args, rets,
311              [ctx, done, rets](const Status& status) {
312                if (!status.ok()) {
313                  ctx->SetStatus(status);
314                } else {
315                  const int ret_size = static_cast<int>(rets->size());
316                  CHECK_EQ(ret_size, ctx->num_outputs());
317                  for (int i = 0; i < ret_size; ++i) {
318                    ctx->set_output(i, (*rets)[i]);
319                  }
320                }
321                delete rets;
322                done();
323              });
324   }
325 
326  private:
327   FunctionLibraryRuntime::Handle handle_;
328 
329   TF_DISALLOW_COPY_AND_ASSIGN(CallOp);
330 };
331 
GetFunctionBody(Handle h)332 const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
333   LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h);
334   if (local_handle == kInvalidLocalHandle) {
335     LOG(ERROR) << "Could not find Handle: " << h
336                << " on device: " << device_name_;
337     return nullptr;
338   }
339 
340   mutex_lock l(mu_);
341   CHECK_EQ(1, items_.count(local_handle));
342   return items_[local_handle]->func_graph;
343 }
344 
CreateKernel(const NodeDef & ndef,OpKernel ** kernel)345 Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
346                                                 OpKernel** kernel) {
347   return CreateKernel(ndef, base_lib_def_, kernel);
348 }
349 
CreateKernel(const NodeDef & ndef,const FunctionLibraryDefinition * lib_def,OpKernel ** kernel)350 Status FunctionLibraryRuntimeImpl::CreateKernel(
351     const NodeDef& ndef, const FunctionLibraryDefinition* lib_def,
352     OpKernel** kernel) {
353   // If a custom kernel creator is given, try that.
354   Status s;
355   if (custom_kernel_creator_) {
356     std::unique_ptr<OpKernel> ret;
357     s = custom_kernel_creator_(this, ndef, &ret);
358     if (s.ok()) {
359       *kernel = ret.release();
360       return s;
361     } else {
362       VLOG(2) << "Custom creator error: " << s;
363       // Falls through.
364       s = Status::OK();
365     }
366   }
367 
368   if (lib_def->Find(ndef.op()) == nullptr) {
369     // A primitive operation. Creates the registered kernel.
370     return CreateNonCachedKernel(device_, this, ndef, graph_def_version_,
371                                  kernel);
372   }
373 
374   // Try to instantiate this function for the func/attr. Maybe it's
375   // cached already.
376   InstantiateOptions options;
377   if (lib_def != base_lib_def_) {
378     options.overlay_lib = lib_def;
379   }
380   Handle handle;
381   TF_RETURN_IF_ERROR(
382       Instantiate(ndef.op(), AttrSlice(&ndef.attr()), options, &handle));
383 
384   const FunctionBody* fbody = GetFunctionBody(handle);
385   CHECK_NOTNULL(fbody);
386 
387   // TODO(zhifengc): For now, we assume int32 and resources are always on host
388   // memory and other types are always on device memory. We should do type
389   // inference over function body to derive the correct input/output memory
390   // types.
391   MemoryTypeVector input_memory_types;
392   for (const auto& t : fbody->arg_types) {
393     input_memory_types.push_back(
394         (t == DT_INT32 || t == DT_RESOURCE) ? HOST_MEMORY : DEVICE_MEMORY);
395   }
396   MemoryTypeVector output_memory_types;
397   for (const auto& t : fbody->ret_types) {
398     output_memory_types.push_back(t == DT_INT32 ? HOST_MEMORY : DEVICE_MEMORY);
399   }
400 
401   // Constructs a CallOp kernel for running the instantiated function.
402   auto device_type = DeviceType(device_->attributes().device_type());
403   OpKernelConstruction construction(
404       device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef,
405       &fbody->fdef.signature(), this, fbody->arg_types, input_memory_types,
406       fbody->ret_types, output_memory_types, graph_def_version_, &s);
407   *kernel = new CallOp(handle, &construction);
408   if (!s.ok()) {
409     delete *kernel;
410   }
411   return s;
412 }
413 
FunctionDefToBody(const FunctionDef & fdef,AttrSlice attrs,const FunctionLibraryDefinition * lib_def,FunctionBody ** fbody)414 Status FunctionLibraryRuntimeImpl::FunctionDefToBody(
415     const FunctionDef& fdef, AttrSlice attrs,
416     const FunctionLibraryDefinition* lib_def, FunctionBody** fbody) {
417   if (lib_def == base_lib_def_) {
418     return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig_, fbody);
419   } else {
420     auto get_func_sig = [lib_def](const string& op, const OpDef** sig) {
421       return lib_def->LookUpOpDef(op, sig);
422     };
423     return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody);
424   }
425 }
426 
InstantiateSymbolicGradient(const NameAttrList & func,const FunctionLibraryDefinition * lib_def,FunctionBody ** g_body)427 Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
428     const NameAttrList& func, const FunctionLibraryDefinition* lib_def,
429     FunctionBody** g_body) {
430   const FunctionDef* fdef = lib_def->Find(func.name());
431   if (fdef == nullptr) {
432     // f is a primitive op.
433     gradient::Creator creator;
434     TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
435     if (creator == nullptr) {
436       return errors::InvalidArgument("No gradient is defined for ",
437                                      func.name());
438     }
439     FunctionDef grad_fdef;
440     // TODO(josh11b): Should filter out the attrs from func that aren't used
441     // by the gradient function.
442     TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
443     TF_RETURN_IF_ERROR(
444         FunctionDefToBody(grad_fdef, AttrSlice(&func.attr()), lib_def, g_body));
445   } else {
446     // f is a user-defined function.
447     InstantiateOptions options;
448     if (lib_def != base_lib_def_) {
449       options.overlay_lib = lib_def;
450     }
451     Handle f_handle;
452     TF_RETURN_IF_ERROR(
453         Instantiate(func.name(), AttrSlice(&func.attr()), options, &f_handle));
454     const FunctionBody* f_body = GetFunctionBody(f_handle);
455     CHECK_NOTNULL(f_body);
456     *g_body = SymbolicGradient(*f_body);
457   }
458   return Status::OK();
459 }
460 
IsLocalTarget(const InstantiateOptions & options)461 bool FunctionLibraryRuntimeImpl::IsLocalTarget(
462     const InstantiateOptions& options) {
463   if (device_ == nullptr) return true;
464   if (options.target.empty()) return true;
465   Device* target_device;
466   if (!device_mgr_->LookupDevice(options.target, &target_device).ok()) {
467     return false;
468   }
469   return target_device == device_;
470 }
471 
Instantiate(const string & function_name,AttrSlice attrs,const InstantiateOptions & options,Handle * handle)472 Status FunctionLibraryRuntimeImpl::Instantiate(
473     const string& function_name, AttrSlice attrs,
474     const InstantiateOptions& options, Handle* handle) {
475   if (!IsLocalTarget(options)) {
476     return parent_->Instantiate(function_name, attrs, options, handle);
477   }
478 
479   // Since this is a local target, ensure that the local `device_name_` appears
480   // in the canonical key.
481   InstantiateOptions options_copy(options);
482   options_copy.target = device_name_;
483   const string key = Canonicalize(function_name, attrs, options_copy);
484   *handle = parent_->GetHandle(key);
485   if (*handle != kInvalidHandle) {
486     mutex_lock l(mu_);
487     items_[parent_->GetHandleOnDevice(device_name_, *handle)]->Ref();
488     return Status::OK();
489   }
490 
491   Status s;
492   const FunctionLibraryDefinition* lib_def =
493       options.overlay_lib ? options.overlay_lib : base_lib_def_;
494   FunctionBody* fbody = nullptr;
495   if (function_name == kGradientOp) {
496     const AttrValue* f = attrs.Find(kFuncAttr);
497     if (f == nullptr) {
498       return errors::InvalidArgument("SymbolicGradient is missing attr: f");
499     }
500     const auto& func = f->func();
501     if (func.name() == kGradientOp) {
502       return errors::InvalidArgument("Can't take gradient of SymbolicGradient");
503     }
504     const string grad = lib_def->FindGradient(func.name());
505     if (!grad.empty()) {
506       return Instantiate(grad, AttrSlice(&func.attr()), options, handle);
507     }
508     TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, lib_def, &fbody));
509   } else {
510     const FunctionDef* fdef = lib_def->Find(function_name);
511     if (fdef == nullptr) {
512       return errors::NotFound("Function ", function_name, " is not defined.");
513     }
514     TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, lib_def, &fbody));
515   }
516 
517   {
518     mutex_lock l(mu_);
519     *handle = parent_->GetHandle(key);
520     if (*handle != kInvalidHandle) {
521       delete fbody;
522       items_[parent_->GetHandleOnDevice(device_name_, *handle)]->Ref();
523     } else {
524       *handle = parent_->AddHandle(key, device_name_, next_handle_);
525       Item* item = new Item;
526       item->func_graph = fbody;
527       item->overlay_lib = options.overlay_lib;
528       items_.insert({next_handle_, item});
529       next_handle_++;
530     }
531   }
532   return Status::OK();
533 }
534 
ReleaseHandle(Handle handle)535 Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
536   if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
537     return parent_->ReleaseHandle(handle);
538   }
539 
540   LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
541   mutex_lock l(mu_);
542   CHECK_EQ(1, items_.count(h));
543   Item* item = items_[h];
544   if (item->Unref()) {
545     items_.erase(h);
546     TF_RETURN_IF_ERROR(parent_->RemoveHandle(handle));
547   }
548   return Status::OK();
549 }
550 
DumpGraph(StringPiece label,const Graph * g)551 void DumpGraph(StringPiece label, const Graph* g) {
552   // TODO(zhifengc): Change Graph to record #nodes.
553   VLOG(1) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges "
554           << g->num_edges();
555   if (VLOG_IS_ON(2)) {
556     for (const auto& line : str_util::Split(DebugString(g), '\n')) {
557       VLOG(2) << "|| " << line;
558     }
559   }
560 }
561 
OptimizeGraph(FunctionLibraryRuntime * lib,std::unique_ptr<Graph> * g)562 void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g) {
563   OptimizerOptions opts;
564   opts.set_do_common_subexpression_elimination(true);
565   opts.set_do_function_inlining(true);
566   opts.set_do_constant_folding(true);
567   GraphOptimizer optimizer(opts);
568   optimizer.Optimize(lib, lib->env(), lib->device(), g, /*shape_map=*/nullptr);
569 }
570 
571 namespace {
572 // Removes all stateless nodes that do not contribute to a return
573 // value from the function body.  Unlike `RemoveDeadNodes()`, which is
574 // triggered by `OptimizerOptions.do_function_inlining`, this pass
575 // ignores the SINK node, from which (by definition) all nodes are
576 // reverse reachable.
PruneFunctionBody(Graph * g)577 void PruneFunctionBody(Graph* g) {
578   VLOG(2) << "Pruning function body";
579   std::unordered_set<const Node*> nodes;
580   for (auto n : g->nodes()) {
581     // NOTE(mrry): "_Retval" nodes are stateful, and so will be added
582     // to the seed set of `nodes`.
583     // TODO(mrry): Investigate whether the `n->IsControlFlow()` test is
584     // still needed. It would be preferable to prune entire loops and/or
585     // conditionals if they are not used in the graph.
586     if (n->IsControlFlow() || n->op_def().is_stateful()) {
587       nodes.insert(n);
588     }
589   }
590   bool changed = PruneForReverseReachability(g, std::move(nodes));
591   if (changed) {
592     FixupSourceAndSinkEdges(g);
593   }
594 }
595 }  // namespace
596 
CreateItem(Handle handle,Item ** item)597 Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
598   const FunctionBody* fbody;
599   const FunctionLibraryDefinition* lib_def;
600   {
601     mutex_lock l(mu_);
602     fbody = (*item)->func_graph;
603     lib_def = (*item)->overlay_lib;
604   }
605   if (!lib_def) {
606     lib_def = base_lib_def_;
607   }
608   std::unique_ptr<Graph> g(new Graph(lib_def));
609   CopyGraph(*fbody->graph, g.get());
610 
611   PruneFunctionBody(g.get());
612   optimizer_.Optimize(this, env(), device(), &g, /*shape_map=*/nullptr);
613   TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device()->device_type()),
614                                        device()->name(), g.get()));
615 
616   // Creates an executor based on the g.  This must be done without
617   // holding mu_ because create_kernel_ calls back into the library.
618   LocalExecutorParams params;
619   params.device = device_;
620   params.function_library = this;
621   if (lib_def == base_lib_def_) {
622     params.create_kernel = create_kernel_;
623   } else {
624     params.create_kernel = [this, lib_def](const NodeDef& ndef,
625                                            OpKernel** kernel) {
626       return CreateKernel(ndef, lib_def, kernel);
627     };
628   }
629   params.delete_kernel = [](OpKernel* kernel) {
630     DeleteNonCachedKernel(kernel);
631   };
632   Graph* graph = g.get();
633   Executor* exec;
634   TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(g), &exec));
635 
636   {
637     // Guard item since it is already inserted in items_.
638     mutex_lock l(mu_);
639     if ((*item)->exec) {
640       delete exec;
641     } else {
642       (*item)->graph = graph;
643       (*item)->exec = exec;
644     }
645   }
646   return Status::OK();
647 }
648 
GetOrCreateItem(Handle handle,Item ** item)649 Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
650   LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
651   {
652     mutex_lock l(mu_);
653     if (items_.count(local_handle) == 0) {
654       return errors::NotFound("Function handle ", handle,
655                               " is not valid. Likely an internal error.");
656     }
657     *item = items_[local_handle];
658     if ((*item)->exec != nullptr) {
659       return Status::OK();
660     }
661   }
662   // NOTE: We need to call CreateItem out of mu_ because creating an
663   // executor needs to call CreateKernel.
664   return CreateItem(handle, item);
665 }
666 
RunRemote(const Options & opts,Handle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets,Executor::Args * exec_args,Item * item,DoneCallback done)667 void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
668                                            gtl::ArraySlice<Tensor> args,
669                                            std::vector<Tensor>* rets,
670                                            Executor::Args* exec_args,
671                                            Item* item, DoneCallback done) {
672   DCHECK(exec_args->call_frame == nullptr);
673   string target_device = parent_->GetDeviceName(handle);
674   string source_device = opts.source_device;
675   Rendezvous* rendezvous = opts.rendezvous;
676   DeviceContext* device_context;
677   Status s = parent_->GetDeviceContext(target_device, &device_context);
678   if (!s.ok()) {
679     delete exec_args;
680     done(s);
681     return;
682   }
683   int64 src_incarnation, target_incarnation;
684   s = parent_->GetDeviceIncarnation(source_device, &src_incarnation);
685   s.Update(parent_->GetDeviceIncarnation(target_device, &target_incarnation));
686   if (!s.ok()) {
687     delete exec_args;
688     done(s);
689     return;
690   }
691 
692   const FunctionBody* fbody = GetFunctionBody(handle);
693   FunctionCallFrame* frame =
694       new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
695   exec_args->call_frame = frame;
696   if (!s.ok()) {
697     delete frame;
698     delete exec_args;
699     done(s);
700     return;
701   }
702 
703   // The ProcFLR sends the arguments to the function from the source_device to
704   // the target_device. So here we receive those arguments. Similarly, when the
705   // computation is done and stored in *rets, we send the return values back
706   // to the source_device (caller) so that the ProcFLR can receive them later.
707   std::vector<Tensor>* remote_args = new std::vector<Tensor>;
708   ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
709       source_device, target_device, "arg_", src_incarnation, args.size(),
710       device_context, {}, rendezvous, remote_args,
711       [frame, remote_args, item, source_device, target_device,
712        target_incarnation, rendezvous, device_context, rets, done,
713        exec_args](const Status& status) {
714         Status s = status;
715         if (s.ok()) {
716           s = frame->SetArgs(*remote_args);
717         }
718         if (!s.ok()) {
719           delete frame;
720           delete remote_args;
721           delete exec_args;
722           done(s);
723           return;
724         }
725         item->exec->RunAsync(
726             *exec_args, [item, frame, rets, done, source_device, target_device,
727                          target_incarnation, rendezvous, device_context,
728                          remote_args, exec_args](const Status& status) {
729               Status s = status;
730               if (s.ok()) {
731                 s = frame->ConsumeRetvals(rets);
732               }
733               delete frame;
734               if (!s.ok()) {
735                 delete remote_args;
736                 delete exec_args;
737                 done(s);
738                 return;
739               }
740               s = ProcessFunctionLibraryRuntime::SendTensors(
741                   target_device, source_device, "ret_", target_incarnation,
742                   *rets, device_context, {}, rendezvous);
743               delete remote_args;
744               delete exec_args;
745               done(s);
746             });
747       });
748 }
749 
Run(const Options & opts,Handle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets,DoneCallback done)750 void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
751                                      gtl::ArraySlice<Tensor> args,
752                                      std::vector<Tensor>* rets,
753                                      DoneCallback done) {
754   if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
755     done(errors::Cancelled(""));
756     return;
757   }
758   Options run_opts = opts;
759   if (opts.create_rendezvous) {
760     Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
761     run_opts.rendezvous = rendezvous;
762     run_opts.create_rendezvous = false;
763     done = [done, rendezvous](const Status& status) {
764       rendezvous->Unref();
765       done(status);
766     };
767   }
768   if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
769     parent_->Run(run_opts, handle, args, rets, done);
770     return;
771   }
772 
773   DCHECK(run_opts.runner != nullptr);
774 
775   Executor::Args* exec_args = new Executor::Args;
776   // Inherit the step_id from the caller.
777   exec_args->step_id = run_opts.step_id;
778   exec_args->rendezvous = run_opts.rendezvous;
779   exec_args->stats_collector = run_opts.stats_collector;
780   exec_args->cancellation_manager = run_opts.cancellation_manager;
781   exec_args->step_container = run_opts.step_container;
782   exec_args->runner = *run_opts.runner;
783 
784   Item* item = nullptr;
785   Status s = GetOrCreateItem(handle, &item);
786   if (!s.ok()) {
787     delete exec_args;
788     done(s);
789     return;
790   }
791 
792   if (run_opts.remote_execution) {
793     // NOTE(mrry): `RunRemote()` will set `exec_args->call_frame` for us.
794     RunRemote(run_opts, handle, args, rets, exec_args, item, done);
795     return;
796   }
797 
798   const FunctionBody* fbody = GetFunctionBody(handle);
799   FunctionCallFrame* frame =
800       new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
801   exec_args->call_frame = frame;
802   s = frame->SetArgs(args);
803   if (!s.ok()) {
804     delete frame;
805     delete exec_args;
806     done(s);
807     return;
808   }
809 
810   item->exec->RunAsync(
811       // Executor args
812       *exec_args,
813       // Done callback.
814       [item, frame, rets, done, exec_args](const Status& status) {
815         Status s = status;
816         if (s.ok()) {
817           s = frame->ConsumeRetvals(rets);
818         }
819         delete frame;
820         delete exec_args;
821         done(s);
822       });
823 }
824 
Run(const Options & opts,Handle handle,CallFrameInterface * frame,DoneCallback done)825 void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
826                                      CallFrameInterface* frame,
827                                      DoneCallback done) {
828   if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
829     done(errors::Cancelled(""));
830     return;
831   }
832   if (!parent_->IsInstantiatedOnDevice(device_name_, handle) ||
833       opts.remote_execution) {
834     done(errors::Unimplemented("Remote calling with CallFrameInterface"));
835     return;
836   }
837 
838   Options run_opts = opts;
839   if (opts.create_rendezvous) {
840     Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
841     run_opts.rendezvous = rendezvous;
842     run_opts.create_rendezvous = false;
843     done = std::bind(
844         [rendezvous](DoneCallback done,
845                      // Begin unbound arguments.
846                      const Status& status) {
847           rendezvous->Unref();
848           done(status);
849         },
850         std::move(done), std::placeholders::_1);
851   }
852 
853   Item* item = nullptr;
854   Status s = GetOrCreateItem(handle, &item);
855   if (!s.ok()) {
856     done(s);
857     return;
858   }
859   DCHECK(run_opts.runner != nullptr);
860 
861   Executor::Args* exec_args = new Executor::Args;
862   // Inherit the step_id from the caller.
863   exec_args->step_id = run_opts.step_id;
864   exec_args->rendezvous = run_opts.rendezvous;
865   exec_args->stats_collector = run_opts.stats_collector;
866   exec_args->cancellation_manager = run_opts.cancellation_manager;
867   exec_args->step_container = run_opts.step_container;
868   exec_args->runner = *run_opts.runner;
869   exec_args->call_frame = frame;
870 
871   item->exec->RunAsync(
872       // Executor args
873       *exec_args,
874       // Done callback.
875       std::bind(
876           [item, frame, exec_args](DoneCallback done,
877                                    // Start unbound arguments.
878                                    const Status& status) {
879             delete exec_args;
880             done(status);
881           },
882           std::move(done), std::placeholders::_1));
883 }
884 
IsStateful(const string & func)885 bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) {
886   const OpDef* op_def;
887   const Status s = base_lib_def_->LookUpOpDef(func, &op_def);
888   return s.ok() && op_def->is_stateful();
889 }
890 
DebugString(Handle handle)891 string FunctionLibraryRuntimeImpl::DebugString(Handle handle) {
892   Item* item = nullptr;
893   Status s = GetOrCreateItem(handle, &item);
894   if (s.ok()) {
895     return tensorflow::DebugString(item->graph);
896   } else {
897     return s.ToString();
898   }
899 }
900 
Clone(std::unique_ptr<FunctionLibraryDefinition> * out_lib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> * out_pflr,FunctionLibraryRuntime ** out_flr)901 Status FunctionLibraryRuntimeImpl::Clone(
902     std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
903     std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
904     FunctionLibraryRuntime** out_flr) {
905   TF_RETURN_IF_ERROR(
906       parent_->Clone(env_, graph_def_version_, optimizer_.options(),
907                      custom_kernel_creator_, out_lib_def, out_pflr));
908   *out_flr = (*out_pflr)->GetFLR(device_->name());
909   if (out_flr != nullptr) {
910     return Status::OK();
911   } else {
912     return errors::Internal("Cloning FunctionLibraryRuntime failed.");
913   }
914 }
915 
916 namespace {
917 
918 struct CustomCreatorSingleton {
919   mutex mu;
920   CustomKernelCreator custom_creator = nullptr;
921 
Settensorflow::__anon849d4f9f0e11::CustomCreatorSingleton922   void Set(CustomKernelCreator cb) {
923     mutex_lock l(mu);
924     custom_creator = std::move(cb);
925   }
926 
Gettensorflow::__anon849d4f9f0e11::CustomCreatorSingleton927   CustomKernelCreator Get() {
928     mutex_lock l(mu);
929     return custom_creator;
930   }
931 };
932 
GetCustomCreatorSingleton()933 CustomCreatorSingleton* GetCustomCreatorSingleton() {
934   static CustomCreatorSingleton* ccs = new CustomCreatorSingleton;
935   return ccs;
936 }
937 
938 }  // namespace
939 
RegisterDefaultCustomKernelCreator(CustomKernelCreator cb)940 void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb) {
941   GetCustomCreatorSingleton()->Set(std::move(cb));
942 }
943 
NewFunctionLibraryRuntime(const DeviceMgr * device_mgr,Env * env,Device * device,int graph_def_version,const FunctionLibraryDefinition * lib_def,const OptimizerOptions & optimizer_options,CustomKernelCreator custom_kernel_creator,ProcessFunctionLibraryRuntime * parent)944 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
945     const DeviceMgr* device_mgr, Env* env, Device* device,
946     int graph_def_version, const FunctionLibraryDefinition* lib_def,
947     const OptimizerOptions& optimizer_options,
948     CustomKernelCreator custom_kernel_creator,
949     ProcessFunctionLibraryRuntime* parent) {
950   return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl(
951       device_mgr, env, device, graph_def_version, lib_def, optimizer_options,
952       std::move(custom_kernel_creator), parent));
953 }
954 
NewFunctionLibraryRuntime(const DeviceMgr * device_mgr,Env * env,Device * device,int graph_def_version,const FunctionLibraryDefinition * lib_def,const OptimizerOptions & optimizer_options,ProcessFunctionLibraryRuntime * parent)955 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
956     const DeviceMgr* device_mgr, Env* env, Device* device,
957     int graph_def_version, const FunctionLibraryDefinition* lib_def,
958     const OptimizerOptions& optimizer_options,
959     ProcessFunctionLibraryRuntime* parent) {
960   return NewFunctionLibraryRuntime(device_mgr, env, device, graph_def_version,
961                                    lib_def, optimizer_options,
962                                    GetCustomCreatorSingleton()->Get(), parent);
963 }
964 
RemoveDeadNodes(Graph * g)965 bool RemoveDeadNodes(Graph* g) {
966   VLOG(2) << "Removing dead nodes";
967   std::unordered_set<const Node*> nodes;
968   for (auto n : g->nodes()) {
969     if (n->IsSource() || n->IsSink() || n->IsControlFlow() ||
970         n->op_def().is_stateful()) {
971       nodes.insert(n);
972     }
973   }
974   return PruneForReverseReachability(g, std::move(nodes));
975 }
976 
977 namespace {
978 // If 'edges' contains only 1 non-control edge, returns it. Otherwise,
979 // returns a nullptr.
GetTheOnlyDataEdge(const EdgeSet & edges)980 const Edge* GetTheOnlyDataEdge(const EdgeSet& edges) {
981   const Edge* ret = nullptr;
982   for (const Edge* e : edges) {
983     if (e->IsControlEdge() || ret) {
984       // Don't touch it if there is a control edge.
985       return nullptr;
986     }
987     if (IsRefType(e->src()->output_type(e->src_output()))) {
988       // Don't touch it if the identity node is effectively de-reffing
989       // a ref.
990       return nullptr;
991     }
992     if (IsRecv(e->src()) || IsSwitch(e->src())) {
993       // Don't touch it if the identity is introduced for control flow.
994       // Recv disables all its successors if it receives a dead signal.
995       // When Recv has an outgoing control edge, the current executor
996       // would not disable the destination. The current solution (see
997       // graph_partition.cc) is to add an identity after Recv and change
998       // the control edge to be from this identity node. So the identity
999       // can't be removed.
1000       return nullptr;
1001     }
1002     ret = e;
1003   }
1004   return ret;
1005 }
1006 }  // end namespace
1007 
RemoveIdentityNodes(Graph * g)1008 bool RemoveIdentityNodes(Graph* g) {
1009   VLOG(2) << "Removing identity nodes";
1010   bool removed_any = false;
1011   gtl::InlinedVector<Node*, 8> matches;
1012   for (Node* n : g->nodes()) {
1013     if (!n->IsIdentity()) continue;
1014     if (!GetTheOnlyDataEdge(n->in_edges())) continue;
1015 
1016     // Some identity nodes are used as sink nodes to give names to output
1017     // tensors. These nodes are not going to be executed unless they are in the
1018     // fetch set. But if they are in the fetch set we don't want to remove them.
1019     if (n->out_edges().empty()) continue;
1020 
1021     matches.push_back(n);
1022   }
1023   if (!matches.empty()) {
1024     for (Node* n : matches) {
1025       const Edge* in = GetTheOnlyDataEdge(n->in_edges());
1026       for (const Edge* out : n->out_edges()) {
1027         if (out->IsControlEdge()) {
1028           g->AddControlEdge(in->src(), out->dst());
1029         } else {
1030           g->AddEdge(in->src(), in->src_output(), out->dst(), out->dst_input());
1031         }
1032       }
1033       VLOG(2) << "Remove Identity: " << n->DebugString();
1034       g->RemoveNode(n);
1035       removed_any = true;
1036     }
1037   }
1038   return removed_any;
1039 }
1040 
RemoveListArrayConverter(Graph * g)1041 bool RemoveListArrayConverter(Graph* g) {
1042   VLOG(2) << "Removing list array converter";
1043   gtl::InlinedVector<Node*, 8> matches;
1044   for (Node* n : g->nodes()) {
1045     if ((n->type_string() == "_ListToArray") ||
1046         (n->type_string() == "_ArrayToList")) {
1047       matches.push_back(n);
1048     }
1049   }
1050   bool removed_any = false;
1051   if (!matches.empty()) {
1052     for (Node* n : matches) {
1053       if (n->num_inputs() != n->num_outputs()) {
1054         continue;  // Not expected. Skip.
1055       }
1056       gtl::InlinedVector<Node*, 8> identity_nodes(n->num_inputs(), nullptr);
1057 
1058       // Process input edges first.
1059       Node* input_control_node = nullptr;
1060       for (const Edge* e : n->in_edges()) {
1061         if (e->IsControlEdge()) {
1062           if (input_control_node == nullptr) {
1063             // If node "n" has any control dependencies, adds a no-op
1064             // node (input_control_node) which the additional Identity
1065             // nodes depends on and the input_control_node depends on
1066             // the node "n"s control dependencies.
1067             input_control_node = AddNoOp(g);
1068           }
1069           g->AddControlEdge(e->src(), input_control_node);
1070         } else {
1071           const int index = e->dst_input();
1072           Node** id_node = &identity_nodes[index];
1073           if (*id_node != nullptr) {
1074             LOG(ERROR)
1075                 << "RemoveListArrayConverter unexpected duplicated input: "
1076                 << e->dst_input();
1077             return removed_any;
1078           }
1079           *id_node = AddIdentity(g, {e->src(), e->src_output()});
1080         }
1081       }
1082 
1083       // If node "n" has any control dependencies, the added identity
1084       // nodes should have control dependencies on input_control_node.
1085       if (input_control_node != nullptr) {
1086         for (Node* id : identity_nodes) {
1087           g->AddControlEdge(input_control_node, id);
1088         }
1089       }
1090 
1091       Node* output_control_node = nullptr;
1092       for (const Edge* e : n->out_edges()) {
1093         if (e->IsControlEdge()) {
1094           if (output_control_node == nullptr) {
1095             // If node "n" is control-depended upon by other nodes,
1096             // adds a no-op node (output_control_node) which those
1097             // nodes will depend on and output_control_node depends on
1098             // all Identity nodes.
1099             output_control_node = AddNoOp(g);
1100           }
1101           g->AddControlEdge(output_control_node, e->dst());
1102         } else {
1103           Node* id_node = identity_nodes[e->src_output()];
1104           if (id_node == nullptr) {
1105             LOG(ERROR) << "RemoveListArrayConverter unexpected missing input: "
1106                        << e->src_output();
1107             return removed_any;
1108           }
1109           CHECK(id_node);
1110           g->AddEdge(id_node, 0, e->dst(), e->dst_input());
1111         }
1112       }
1113 
1114       // If any nodes have control dependencies on node "n", those
1115       // nodes should have control dependencies on
1116       // output_control_node.
1117       if (output_control_node != nullptr) {
1118         for (Node* id : identity_nodes) {
1119           g->AddControlEdge(id, output_control_node);
1120         }
1121       }
1122 
1123       g->RemoveNode(n);
1124       removed_any = true;
1125     }
1126   }
1127   return removed_any;
1128 }
1129 
1130 // Returns true iff the function '*fbody' can be inlined at 'node'
1131 // based on the type signature of 'node' and 'fbody'.
ValidateInlining(const Node * node,const FunctionBody * fbody)1132 static bool ValidateInlining(const Node* node, const FunctionBody* fbody) {
1133   if (static_cast<size_t>(node->num_inputs()) != fbody->arg_types.size()) {
1134     return false;
1135   }
1136   if (static_cast<size_t>(node->num_inputs()) != fbody->arg_nodes.size()) {
1137     return false;
1138   }
1139   if (static_cast<size_t>(node->num_outputs()) != fbody->ret_types.size()) {
1140     return false;
1141   }
1142   if (static_cast<size_t>(node->num_outputs()) != fbody->ret_nodes.size()) {
1143     return false;
1144   }
1145   for (int i = 0; i < node->num_inputs(); ++i) {
1146     if (node->input_type(i) != fbody->arg_types[i]) return false;
1147   }
1148   for (int i = 0; i < node->num_outputs(); ++i) {
1149     if (node->output_type(i) != fbody->ret_types[i]) return false;
1150   }
1151   return true;
1152 }
1153 
1154 // Given a "caller" node in "g" that calls the function "fbody", replaces
1155 // "caller" with a copy of fbody->graph and reconnects edges. If
1156 // ValidateInlining() rejects the pair, logs a warning and leaves "g" as is.
//
// Every copied node is renamed "<caller name>/<original name>" and placed on
// the caller's requested device. 'flib_def' is consulted only to recognize
// copied nodes that are themselves function calls (see below).
InlineFunctionBody(const FunctionLibraryDefinition & flib_def,Graph * g,Node * caller,const FunctionBody * fbody)1157 void InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
1158                         Node* caller, const FunctionBody* fbody) {
1159   if (!ValidateInlining(caller, fbody)) {
1160     LOG(WARNING) << "Inlining mismatch: " << caller->DebugString() << " vs. "
1161                  << DebugString(fbody->graph);
1162     return;
1163   }
1164 
1165   // Input edges. For data edges coming into "caller", we first compute the
1166   // <src>:<src_output> for the i-th input in "inputs".
1167   // If "caller" has any input control dependencies, we add a NoOp
1168   // node "input_control_node", which depends on "caller"'s control inputs.
  // NOTE(review): a data slot of "caller" without an incoming edge keeps a
  // default-constructed Endpoint in "inputs"; confirm this cannot happen.
1169   std::vector<Endpoint> inputs(caller->num_inputs());
1170   Node* input_control_node = nullptr;
1171   for (const Edge* e : caller->in_edges()) {
1172     if (e->IsControlEdge()) {
1173       if (input_control_node == nullptr) {
1174         input_control_node = AddNoOp(g);
1175       }
1176       g->AddControlEdge(e->src(), input_control_node);
1177     } else {
1178       inputs[e->dst_input()] = {e->src(), e->src_output()};
1179     }
1180   }
1181 
1182   // Duplicate fbody->graph into 'g'.  First, we copy the nodes of
1183   // fbody->graph into 'g' except the source and sink nodes.  We copy
1184   // edges among nodes in 'fbody->graph'.
1185   //
1186   // If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we
1187   // remember 'y' in node_map[x->id()].
1188   std::vector<Node*> node_map(fbody->graph->num_node_ids());
1189   Status s;
1190   for (Node* n : fbody->graph->op_nodes()) {
1191     NodeDef ndef = n->def();
1192     ndef.set_name(strings::StrCat(caller->name(), "/", ndef.name()));  // Scope under the caller's name.
1193     ndef.set_device(caller->def().device());  // Inherit the caller's requested device.
1194     Node* clone = g->AddNode(ndef, &s);
1195     TF_CHECK_OK(s);  // Abort if 'g' rejects the copied NodeDef.
1196     node_map[n->id()] = clone;
1197 
1198     // If there is an input control node, and one of:
1199     // a) the node has no data or control inputs, or
1200     // b) the node is a function call or SymbolicGradient,
1201     // then add a control edge from the input control node to the clone.
1202     //
1203     // We must not execute any nodes if the original function call would not
1204     // have executed. This is especially critical when the function call is
1205     // inside a control-flow construct like tf.cond(). Case (a) ensures that
1206     // such nodes do not run.
1207     //
1208     // The purpose of case (b) is to ensure that instances of case (a) created
1209     // by further inlining steps also receive the control dependency.
1210     if (input_control_node) {
1211       bool has_inputs = false;
1212       for (const Edge* e : n->in_edges()) {
1213         if (!e->src()->IsSource()) {
1214           has_inputs = true;
1215           break;
1216         }
1217       }
1218       if (!has_inputs || flib_def.Find(clone->type_string()) != nullptr ||
1219           clone->type_string() == "SymbolicGradient") {
1220         g->AddControlEdge(input_control_node, clone);
1221       }
1222     }
1223   }
  // Copy the edges between op nodes. Edges touching fbody's source or sink
  // node are dropped.
1224   for (const Edge* e : fbody->graph->edges()) {
1225     if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() ||
1226         e->dst()->IsSink()) {
1227       continue;
1228     }
1229     Node* src_copy = node_map[e->src()->id()];
1230     Node* dst_copy = node_map[e->dst()->id()];
1231     g->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
1232   }
1233 
1234   // Connect input edges.
1235   //
1236   // We create one Identity node for each input. Then, we connect inputs[i] to
1237   // the i-th identity node added. The nodes that previously connected
1238   // to the j-th output of i-th arg node are reconnected to the i-th
1239   // identity node. The copied arg node is then removed.
1240   //
1241   // The added identity nodes depend on "input_control_node".
1242   for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) {
1243     Node* arg = node_map[fbody->arg_nodes[i]->id()];
1244     Node* n = AddIdentity(g, inputs[i]);
1245     if (input_control_node) {
1246       g->AddControlEdge(input_control_node, n);
1247     }
1248     for (const Edge* e : arg->out_edges()) {
1249       if (e->IsControlEdge()) {
1250         g->AddControlEdge(n, e->dst());
1251       } else {
1252         g->AddEdge(n, 0, e->dst(), e->dst_input());
1253       }
1254     }
1255     node_map[fbody->arg_nodes[i]->id()] = n;  // NOTE(review): node_map is not read for arg nodes after this point; this store appears unused.
1256     g->RemoveNode(arg);  // 'arg' is disconnected.
1257   }
1258 
1259   // Connect output edges.
1260   //
1261   // For the i-th return node in fbody->graph, we add in "g" an identity
1262   // node (outputs[i]). The return node's data input feeds that identity,
1263   // and its control inputs become control inputs of the identity.
1264   //
1265   // For every data edge coming out of "caller"'s i-th output, we
1266   // reconnect it to the i-th identity added above.
1267   //
1268   // If "caller" is control-depended upon by any other nodes, we add a
1269   // NoOp node "output_control_node". "output_control_node" depends on
1270   // all identity nodes added above, and nodes that previously depended on
1271   // "caller" are changed to depend on "output_control_node".
1272   std::vector<Node*> outputs(caller->num_outputs());
1273   for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) {
1274     Node* ret = node_map[fbody->ret_nodes[i]->id()];
1275     Endpoint data;  // Data input for the ret node.
1276     for (const Edge* e : ret->in_edges()) {
1277       if (!e->IsControlEdge()) {
1278         data = {e->src(), e->src_output()};
1279         break;
1280       }
1281     }
1282     CHECK(data.node != nullptr);  // Every ret node must have a data input.
1283     Node* n = AddIdentity(g, data);
1284     outputs[i] = n;
1285     for (const Edge* e : ret->in_edges()) {
1286       if (e->IsControlEdge()) {
1287         g->AddControlEdge(e->src(), n);
1288       }
1289     }
1290     g->RemoveNode(ret);  // 'ret' is disconnected.
1291   }
1292   Node* output_control_node = nullptr;
1293   for (const Edge* e : caller->out_edges()) {
1294     if (e->IsControlEdge()) {
1295       if (output_control_node == nullptr) {
1296         output_control_node = AddNoOp(g);
1297         for (Node* n : outputs) {
1298           g->AddControlEdge(n, output_control_node);
1299         }
1300       }
1301       g->AddControlEdge(output_control_node, e->dst());
1302     } else {
1303       g->AddEdge(outputs[e->src_output()], 0, e->dst(), e->dst_input());
1304     }
1305   }
1306   g->RemoveNode(caller);  // 'caller' is replaced with inlined nodes.
1307 }
1308 
ExpandInlineFunctions(FunctionLibraryRuntime * lib,Graph * graph)1309 bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
1310   std::vector<std::pair<Node*, const FunctionBody*>> candidates;
1311   const FunctionLibraryDefinition* fld = lib->GetFunctionLibraryDefinition();
1312   for (Node* node : graph->nodes()) {
1313     VLOG(3) << "Expanding " << node->DebugString();
1314     bool noinline;
1315     if (fld->GetAttr(*node, kNoInlineAttr, &noinline).ok() && noinline) {
1316       VLOG(3) << "noinline: " << node->DebugString();
1317       continue;
1318     }
1319     FunctionLibraryRuntime::Handle handle;
1320     Status s = lib->Instantiate(node->type_string(), node->attrs(), &handle);
1321     if (!s.ok()) {
1322       // Either "node" is a primitive op, or the instantiation failed.
1323       if (errors::IsNotFound(s)) {
1324         VLOG(3) << "ExpandInlineFunctions " << s;
1325       } else {
1326         LOG(ERROR) << "ExpandInlineFunctions " << s;
1327       }
1328       continue;
1329     }
1330     const FunctionBody* fbody = lib->GetFunctionBody(handle);
1331     CHECK_NOTNULL(fbody);
1332     candidates.push_back({node, fbody});
1333   }
1334   for (const auto& p : candidates) {
1335     InlineFunctionBody(*fld, graph, p.first, p.second);
1336   }
1337   return !candidates.empty();
1338 }
1339 
NewName(const Node * n,bool pretty)1340 string NewName(const Node* n, bool pretty) {
1341   if (pretty) {
1342     return strings::StrCat(n->type_string(), n->id());
1343   } else {
1344     return strings::StrCat("n", n->id());
1345   }
1346 }
1347 
1348 // TODO(zhifengc): Maybe this should be the default Graph::AsGraphDef.
1349 // and stash the original NodeDef name as an attr for documentation
1350 // purpose.
ToGraphDef(const Graph * g,GraphDef * gdef,bool pretty)1351 void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) {
1352   // We visit nodes in forward topological sort order, which is a
1353   // possible execution order of the graph.
1354   gtl::InlinedVector<const Edge*, 4> inputs;
1355   gdef->Clear();
1356   gdef->mutable_versions()->CopyFrom(g->versions());
1357 
1358   std::vector<Node*> start_nodes;
1359   for (Node* n : g->nodes()) {
1360     if (n->out_edges().empty()) {
1361       start_nodes.push_back(n);
1362     }
1363   }
1364 
1365   ReverseDFSFrom(*g, start_nodes, nullptr, [gdef, pretty, &inputs](Node* n) {
1366     if (!n->IsOp()) return;
1367     NodeDef* ndef = gdef->add_node();
1368     ndef->set_name(NewName(n, pretty));
1369     ndef->set_op(n->type_string());
1370     for (const auto& attr : n->attrs()) {
1371       (*ndef->mutable_attr())[attr.first] = attr.second;
1372     }
1373     inputs.clear();
1374     inputs.resize(n->num_inputs());
1375     for (const Edge* e : n->in_edges()) {
1376       if (e->IsControlEdge()) {
1377         inputs.push_back(e);
1378       } else {
1379         if (inputs[e->dst_input()] == nullptr) {
1380           inputs[e->dst_input()] = e;
1381         } else {
1382           LOG(WARNING) << "Malformed graph node. multiple input edges: "
1383                        << n->DebugString();
1384         }
1385       }
1386     }
1387     // node->name() is merely NodeDef::name, which are not guaranteed
1388     // to be unique and stable after optimization rewrites. Therefore,
1389     // we use "n<node id>" instead.
1390     for (const Edge* e : inputs) {
1391       if (e == nullptr) {
1392         ndef->add_input("unknown");
1393         continue;
1394       }
1395       const string srcname = NewName(e->src(), pretty);
1396       if (!e->src()->IsOp()) {
1397       } else if (e->IsControlEdge()) {
1398         ndef->add_input(strings::StrCat("^", srcname));
1399       } else if (e->src_output() == 0) {
1400         ndef->add_input(srcname);
1401       } else {
1402         ndef->add_input(strings::StrCat(srcname, ":", e->src_output()));
1403       }
1404     }
1405   });
1406 }
1407 
DebugString(const Graph * g)1408 string DebugString(const Graph* g) {
1409   GraphDef gdef;
1410   ToGraphDef(g, &gdef);
1411   return DebugString(gdef);
1412 }
1413 
FunctionBody(const FunctionDef & f,DataTypeSlice arg_t,DataTypeSlice ret_t,Graph * g)1414 FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t,
1415                            DataTypeSlice ret_t, Graph* g)
1416     : fdef(f),
1417       graph(g),
1418       arg_types(arg_t.begin(), arg_t.end()),
1419       ret_types(ret_t.begin(), ret_t.end()) {
1420   this->arg_nodes.resize(arg_types.size());
1421   this->ret_nodes.resize(ret_types.size());
1422   for (Node* n : this->graph->op_nodes()) {
1423     gtl::InlinedVector<Node*, 4>* node_vec;
1424     if (n->type_string() == kRetOp) {
1425       node_vec = &this->ret_nodes;
1426     } else if (n->type_string() == kArgOp) {
1427       node_vec = &this->arg_nodes;
1428     } else {
1429       continue;
1430     }
1431     int index;
1432     TF_CHECK_OK(GetNodeAttr(n->attrs(), "index", &index));
1433     CHECK_LE(0, index);
1434     CHECK_LT(index, node_vec->size());
1435     (*node_vec)[index] = n;
1436   }
1437 }
1438 
~FunctionBody()1439 FunctionBody::~FunctionBody() { delete this->graph; }
1440 
1441 class SymbolicGradientHelper {
1442  public:
SymbolicGradientHelper(const FunctionBody & f)1443   explicit SymbolicGradientHelper(const FunctionBody& f) : fbody_(&f) {}
1444 
~SymbolicGradientHelper()1445   ~SymbolicGradientHelper() { delete gbody_; }
1446 
1447   FunctionBody* Compute();
1448 
1449  private:
1450   const FunctionBody* fbody_;
1451   FunctionBody* gbody_ = nullptr;
1452 
1453   // Makes a copy of fbody_ in gbody_.
1454   void Copy();
1455 
1456   TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientHelper);
1457 };
1458 
Copy()1459 void SymbolicGradientHelper::Copy() {
1460   const Graph& src = *(fbody_->graph);
1461   gbody_->graph = new Graph(src.op_registry());
1462   Graph* dst = gbody_->graph;
1463 
1464   std::vector<Node*> node_map(src.num_node_ids());
1465 
1466   // Copy the nodes.
1467   node_map[src.source_node()->id()] = dst->source_node();
1468   node_map[src.sink_node()->id()] = dst->sink_node();
1469   for (Node* n : src.op_nodes()) {
1470     node_map[n->id()] = dst->CopyNode(n);
1471   }
1472 
1473   // Copy the edges.
1474   for (const Edge* e : src.edges()) {
1475     Node* src_copy = node_map[e->src()->id()];
1476     Node* dst_copy = node_map[e->dst()->id()];
1477     dst->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
1478   }
1479 
1480   // Save inputs in copied graph.
1481   CHECK_EQ(fbody_->arg_types.size(), fbody_->arg_nodes.size());
1482   gbody_->arg_types = fbody_->arg_types;
1483   for (std::size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
1484     gbody_->arg_nodes.push_back(node_map[fbody_->arg_nodes[i]->id()]);
1485   }
1486 
1487   // Save outputs in copied graph.
1488   CHECK_EQ(fbody_->ret_types.size(), fbody_->ret_nodes.size());
1489   gbody_->ret_types = fbody_->ret_types;
1490   for (std::size_t i = 0; i < fbody_->ret_nodes.size(); ++i) {
1491     gbody_->ret_nodes.push_back(node_map[fbody_->ret_nodes[i]->id()]);
1492   }
1493 }
1494 
Compute()1495 FunctionBody* SymbolicGradientHelper::Compute() {
1496   CHECK(gbody_ == nullptr);
1497   gbody_ = new FunctionBody;
1498 
1499   // Copy fbody_ into gbody_.
1500   Copy();
1501 
1502   Graph* g = gbody_->graph;
1503 
1504   const int num_y = static_cast<int>(gbody_->ret_nodes.size());
1505 
1506   // Populate 'y_node_outputs_' with node function body outputs.
1507   // Populate 'y_grad_nodes' with initial gradient nodes for each return node of
1508   // the original function body (these will be 'arg' nodes in the function
1509   // gradient body).
1510   std::vector<NodeOut> y_node_outputs;
1511   y_node_outputs.reserve(num_y);
1512   std::vector<NodeOut> y_grad_node_outputs;
1513   y_grad_node_outputs.reserve(num_y);
1514   for (int i = 0; i < num_y; ++i) {
1515     Node* y = gbody_->ret_nodes[i];
1516     y_node_outputs.push_back({y, 0});
1517     DCHECK_EQ(y->type_string(), kRetOp);
1518     const DataType dtype = y->input_type(0);
1519     const int index = static_cast<int>(gbody_->arg_nodes.size());
1520     Node* dy = AddArg(g, dtype, index);
1521     gbody_->arg_types.push_back(dtype);
1522     gbody_->arg_nodes.push_back(dy);
1523     y_grad_node_outputs.push_back({dy, 0});
1524   }
1525 
1526   // Populate 'x_nodes' with function args (excluding 'y_grad_node_outputs').
1527   const size_t num_x = fbody_->arg_nodes.size();
1528   std::vector<NodeOut> x_node_outputs;
1529   x_node_outputs.reserve(num_x);
1530   for (size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
1531     x_node_outputs.push_back({gbody_->arg_nodes[i], 0});
1532   }
1533 
1534   // Call AddSymbolicGradients which will add nodes to graph 'g' that
1535   // compute the function gradient (adding an entry in 'x_grad_node_outputs' for
1536   // each node in 'x_node_outputs').
1537   std::vector<NodeOut> x_grad_node_outputs;
1538   TF_CHECK_OK(AddSymbolicGradients(y_node_outputs, x_node_outputs,
1539                                    y_grad_node_outputs, &x_grad_node_outputs,
1540                                    g));
1541 
1542   // Remove the old return nodes from the function body.
1543   for (Node* n : gbody_->ret_nodes) {
1544     g->RemoveNode(n);
1545   }
1546   gbody_->ret_types = fbody_->arg_types;
1547   gbody_->ret_nodes.clear();
1548   // Add new return nodes to the function gradient body for each node
1549   // in 'x_grad_nodes'.
1550   const int arg_types_size = static_cast<int>(fbody_->arg_types.size());
1551   for (int i = 0; i < arg_types_size; ++i) {
1552     Endpoint grad = {x_grad_node_outputs[i].node, x_grad_node_outputs[i].index};
1553     Node* ret = AddRet(g, grad, i);
1554     gbody_->ret_nodes.push_back(ret);
1555   }
1556 
1557   auto ret = gbody_;
1558   gbody_ = nullptr;
1559   return ret;
1560 }
1561 
// Returns a newly allocated FunctionBody computing the symbolic gradient of
// `f`; the caller takes ownership.
FunctionBody* SymbolicGradient(const FunctionBody& f) {
  SymbolicGradientHelper helper(f);
  return helper.Compute();
}
1565 
FunctionDefToBodyHelper(const FunctionDef & fdef,const AttrSlice & attrs,const FunctionLibraryDefinition * const lib_def,const std::function<Status (const string &,const OpDef **)> & get_func_sig,FunctionBody ** fbody)1566 Status FunctionDefToBodyHelper(
1567     const FunctionDef& fdef, const AttrSlice& attrs,
1568     const FunctionLibraryDefinition* const lib_def,
1569     const std::function<Status(const string&, const OpDef**)>& get_func_sig,
1570     FunctionBody** fbody) {
1571   // Instantiates the function template into a graph def.
1572   InstantiationResult result;
1573   TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result));
1574 
1575   std::unique_ptr<Graph> graph(new Graph(lib_def));
1576   GraphConstructorOptions opts;
1577   opts.allow_internal_ops = true;
1578   opts.expect_device_spec = false;
1579   TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get()));
1580 
1581   // Call BuildControlFlowInfo to validate that this function body has
1582   // well-formed control flow.
1583   // NOTE(skyewm): this is usually done in Partition(), but we don't partition
1584   // function bodies. This should be removed if function bodies ever go through
1585   // the Partition() path.
1586   std::vector<ControlFlowInfo> dummy;
1587   TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy));
1588 
1589   *fbody = new FunctionBody(fdef, result.arg_types, result.ret_types,
1590                             graph.release());
1591   return Status::OK();
1592 }
1593 
1594 }  // end namespace tensorflow
1595