• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/function.h"
17 
18 #include <deque>
19 #include <vector>
20 
21 #include "absl/algorithm/container.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "tensorflow/core/common_runtime/device.h"
25 #include "tensorflow/core/common_runtime/executor.h"
26 #include "tensorflow/core/common_runtime/executor_factory.h"
27 #include "tensorflow/core/common_runtime/gradients.h"
28 #include "tensorflow/core/common_runtime/graph_constructor.h"
29 #include "tensorflow/core/common_runtime/graph_optimizer.h"
30 #include "tensorflow/core/common_runtime/inline_function_utils.h"
31 #include "tensorflow/core/common_runtime/memory_types.h"
32 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
33 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
34 #include "tensorflow/core/common_runtime/single_threaded_executor.h"
35 #include "tensorflow/core/framework/collective.h"
36 #include "tensorflow/core/framework/function.h"
37 #include "tensorflow/core/framework/function_handle_cache.h"
38 #include "tensorflow/core/framework/node_def.pb.h"
39 #include "tensorflow/core/framework/node_def_util.h"
40 #include "tensorflow/core/framework/op.h"
41 #include "tensorflow/core/framework/op_kernel.h"
42 #include "tensorflow/core/framework/versions.pb.h"
43 #include "tensorflow/core/graph/algorithm.h"
44 #include "tensorflow/core/graph/control_flow.h"
45 #include "tensorflow/core/graph/node_builder.h"
46 #include "tensorflow/core/graph/optimizer_cse.h"
47 #include "tensorflow/core/lib/core/threadpool.h"
48 #include "tensorflow/core/lib/gtl/map_util.h"
49 #include "tensorflow/core/platform/macros.h"
50 #include "tensorflow/core/platform/str_util.h"
51 #include "tensorflow/core/profiler/lib/connected_traceme.h"
52 #include "tensorflow/core/profiler/lib/traceme.h"
53 #include "tensorflow/core/protobuf/config.pb.h"
54 
55 // See core/kernels/function_ops.cc for related kernels.
56 
57 namespace tensorflow {
58 
// A few string constants used throughout this module. These alias the
// canonical op-name constants on FunctionLibraryDefinition so the rest of the
// file can refer to them tersely.
static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
static constexpr const char* const kDeviceArgOp =
    FunctionLibraryDefinition::kDeviceArgOp;
static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
static constexpr const char* const kDeviceRetOp =
    FunctionLibraryDefinition::kDeviceRetOp;
static constexpr const char* const kGradientOp =
    FunctionLibraryDefinition::kGradientOp;
// Name prefix used for nodes added by the Add* helpers below.
static constexpr const char* const kNodeLabel = "Func";
static constexpr const char* const kFuncAttr =
    FunctionLibraryDefinition::kFuncAttr;
71 
72 // Represents the index-th output of a node.
73 struct Endpoint {
74   Node* node;
75   int index;
76 
77   // Returns the string name represents this endpoint.
nametensorflow::Endpoint78   string name() const {
79     if (index == 0) {
80       return node->name();
81     } else {
82       return strings::StrCat(node->name(), ":", index);
83     }
84   }
85 
dtypetensorflow::Endpoint86   DataType dtype() const { return node->output_type(index); }
87 };
88 
// Hash functor for Endpoint: hashes the raw Node* pointer bytes, folding the
// output index in as the seed, so distinct (node, index) pairs hash apart.
struct EndpointHash {
  uint64 operator()(const Endpoint& x) const {
    return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
                  x.index);
  }
};
95 
// Equality functor for Endpoint: two endpoints are equal iff they refer to
// the same output slot of the same node.
struct EndpointEq {
  bool operator()(const Endpoint& x, const Endpoint& y) const {
    return (x.node == y.node) && (x.index == y.index);
  }
};
101 
// The following Add* routines are used to add a few graph nodes while
// functions are transformed.
//
// Adds a kArgOp node to `g` that produces the function's index-th input with
// the given dtype. Crashes (TF_CHECK_OK) on failure, since the NodeDef built
// here is expected to always be valid.
static Node* AddArg(Graph* g, DataType dtype, int index) {
  DCHECK_LT(0, dtype);
  DCHECK_LT(dtype, DT_FLOAT_REF);  // only non-reference dtypes are legal here
  NodeDef ndef;
  ndef.set_name(g->NewName(kNodeLabel));
  ndef.set_op(kArgOp);
  AddNodeAttr("T", dtype, &ndef);
  AddNodeAttr("index", index, &ndef);
  Status s;
  Node* ret = g->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  return ret;
}
117 
// Adds a kRetOp node to `g` that returns `input` as the function's index-th
// output, and wires the data edge from the producing endpoint. Crashes
// (TF_CHECK_OK) on failure.
static Node* AddRet(Graph* g, Endpoint input, int index) {
  DCHECK_LT(0, input.dtype());
  DCHECK_LT(input.dtype(), DT_FLOAT_REF);  // reference types cannot be returned
  NodeDef ndef;
  ndef.set_name(g->NewName(kNodeLabel));
  ndef.set_op(kRetOp);
  ndef.add_input(input.name());
  AddNodeAttr("T", input.dtype(), &ndef);
  AddNodeAttr("index", index, &ndef);
  Status s;
  Node* ret = g->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  g->AddEdge(input.node, input.index, ret, 0);
  return ret;
}
133 
// FunctionLibraryRuntime implementation that forwards all the function calls to
// the base runtime implementation, and only overrides FunctionLibraryDefinition
// in calls to Instantiate (if caller doesn't provide the
// InstantiateOptions::lib_def option).
//
// When the function library runtime (FunctionLibraryRuntimeImpl specifically)
// instantiates a function into a Graph object, it also creates an Executor for
// it. That executor has a pointer to the function library runtime instance,
// that is used to instantiate all nested function calls.
//
// The function library definition used to instantiate the function must be
// preserved in the Executor's function library runtime.
//
// IMPORTANT: This runtime is intended for use only in executors created for
// functions instantiated into a graph in FunctionLibraryRuntimeImpl.
class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
 public:
  // Neither `base_flr` nor `lib_def` is owned; both must outlive this object.
  FunctionLibraryRuntimeOverlay(FunctionLibraryRuntime* base_flr,
                                const FunctionLibraryDefinition* lib_def)
      : base_flr_(base_flr), lib_def_(lib_def) {}
  ~FunctionLibraryRuntimeOverlay() override;

  // Forwards to the base runtime, injecting `lib_def_` as the
  // InstantiateOptions::lib_def override when the caller did not set one.
  Status Instantiate(const string& function_name, AttrSlice attrs,
                     const InstantiateOptions& options,
                     Handle* handle) override;

  Status ReleaseHandle(Handle handle) override;

  const FunctionBody* GetFunctionBody(Handle h) override;

  Status GetRetTypes(Handle h, DataTypeVector* ret_types) override;

  // Run/RunSync overloads are pure pass-throughs to the base runtime.
  void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
           std::vector<Tensor>* rets, DoneCallback done) override;

  void Run(const Options& opts, Handle handle, CallFrameInterface* call_frame,
           DoneCallback done) override;

  Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
                 std::vector<Tensor>* rets) override;

  Status RunSync(Options opts, Handle handle,
                 CallFrameInterface* frame) override;

  // Always fails: kernel creation through an overlay is intentionally
  // unsupported (see the implementation for the rationale).
  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                      OpKernel** kernel) override;

  // Consults only `lib_def_`; deliberately does not forward to the base FLR.
  bool IsStateful(const string& function_name) const override;

  // Returns `lib_def_` when set; otherwise the base runtime's definition.
  const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const override;

  // Trivial accessors forwarded to the base runtime.
  Env* env() override;
  const ConfigProto* const config_proto() override;
  Device* device() override;
  const Device* device() const override;
  std::function<void(std::function<void()>)>* runner() override;
  const DeviceMgr* device_mgr() const override;

  string DebugString(Handle handle) override;
  int graph_def_version() const override;

  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
               FunctionLibraryRuntime** out_flr,
               bool skip_flib_def = false) override;

 private:
  FunctionLibraryRuntime* base_flr_;          // not owned
  const FunctionLibraryDefinition* lib_def_;  // not owned
};
205 
// Defaulted: the overlay owns neither base_flr_ nor lib_def_.
FunctionLibraryRuntimeOverlay::~FunctionLibraryRuntimeOverlay() = default;
207 
Instantiate(const string & function_name,AttrSlice attrs,const InstantiateOptions & options,Handle * handle)208 Status FunctionLibraryRuntimeOverlay::Instantiate(
209     const string& function_name, AttrSlice attrs,
210     const InstantiateOptions& options, Handle* handle) {
211   // We automatically set the `lib_def` option for all instantiations, if the
212   // caller doesn't set this option explicitly.
213   if (!options.lib_def && lib_def_) {
214     InstantiateOptions options_copy = options;
215     options_copy.lib_def = lib_def_;
216     return base_flr_->Instantiate(function_name, attrs, options_copy, handle);
217   } else {
218     return base_flr_->Instantiate(function_name, attrs, options, handle);
219   }
220 }
221 
// The methods below are pure pass-throughs: the overlay adds no behavior of
// its own and simply delegates to the wrapped base runtime.

Status FunctionLibraryRuntimeOverlay::ReleaseHandle(Handle handle) {
  return base_flr_->ReleaseHandle(handle);
}

const FunctionBody* FunctionLibraryRuntimeOverlay::GetFunctionBody(Handle h) {
  return base_flr_->GetFunctionBody(h);
}

Status FunctionLibraryRuntimeOverlay::GetRetTypes(Handle h,
                                                  DataTypeVector* ret_types) {
  return base_flr_->GetRetTypes(h, ret_types);
}

void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
                                        gtl::ArraySlice<Tensor> args,
                                        std::vector<Tensor>* rets,
                                        DoneCallback done) {
  base_flr_->Run(opts, handle, args, rets, std::move(done));
}

void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
                                        CallFrameInterface* call_frame,
                                        DoneCallback done) {
  base_flr_->Run(opts, handle, call_frame, std::move(done));
}

Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
                                              gtl::ArraySlice<Tensor> args,
                                              std::vector<Tensor>* rets) {
  return base_flr_->RunSync(std::move(opts), handle, args, rets);
}

Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
                                              CallFrameInterface* call_frame) {
  return base_flr_->RunSync(std::move(opts), handle, call_frame);
}
258 
// Intentionally unimplemented; always returns an Internal error.
Status FunctionLibraryRuntimeOverlay::CreateKernel(
    const std::shared_ptr<const NodeProperties>&, OpKernel**) {
  // We don't have access to base_lib_def_ in base function library runtime (aka
  // FunctionLibraryRuntimeImpl), so to make sure we do not create a kernel with
  // the wrong lib_def we just disable creation of new kernels through overlays.
  //
  // When we call Instantiate from the base runtime with the lib_def option,
  // the base runtime implementation is responsible for correctly passing it
  // through to all kernel constructions.
  return errors::Internal(
      "Overlay function library runtime doesn't support kernel creation.");
}
271 
IsStateful(const string & function_name) const272 bool FunctionLibraryRuntimeOverlay::IsStateful(
273     const string& function_name) const {
274   // Important: we do not forward lookup to the base FLR.
275   const OpDef* op_def;
276   const Status s = lib_def_->LookUpOpDef(function_name, &op_def);
277   return s.ok() && op_def->is_stateful();
278 }
279 
// Trivial accessors: all environment/device state comes from the base runtime.

Env* FunctionLibraryRuntimeOverlay::env() { return base_flr_->env(); }

const ConfigProto* const FunctionLibraryRuntimeOverlay::config_proto() {
  return base_flr_->config_proto();
}

Device* FunctionLibraryRuntimeOverlay::device() { return base_flr_->device(); }

const Device* FunctionLibraryRuntimeOverlay::device() const {
  return base_flr_->device();
}

std::function<void(std::function<void()>)>*
FunctionLibraryRuntimeOverlay::runner() {
  return base_flr_->runner();
}

const DeviceMgr* FunctionLibraryRuntimeOverlay::device_mgr() const {
  return base_flr_->device_mgr();
}
300 
// The one accessor the overlay actually overrides: prefer the overlay library
// definition when present, falling back to the base runtime's otherwise.
const FunctionLibraryDefinition*
FunctionLibraryRuntimeOverlay::GetFunctionLibraryDefinition() const {
  return lib_def_ ? lib_def_ : base_flr_->GetFunctionLibraryDefinition();
}

string FunctionLibraryRuntimeOverlay::DebugString(Handle handle) {
  return base_flr_->DebugString(handle);
}

int FunctionLibraryRuntimeOverlay::graph_def_version() const {
  return base_flr_->graph_def_version();
}
313 
// Clones the *base* runtime; the overlay itself is not reproduced.
Status FunctionLibraryRuntimeOverlay::Clone(
    std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
    FunctionLibraryRuntime** out_flr, bool skip_flib_def) {
  // NOTE(ezhulenev): The cloned FunctionLibraryRuntime will be missing the
  // FunctionLibraryDefinition override, but that's ok because we anyway do not
  // copy / clone instantiated items from the base FLR.
  return base_flr_->Clone(out_lib_def, out_pflr, out_flr, skip_flib_def);
}
323 
// The per-device function library runtime: instantiates functions into Graph
// objects, builds executors for them, and runs them (locally or via the
// parent ProcessFunctionLibraryRuntime for non-local targets).
class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
 public:
  FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env,
                             const ConfigProto* config, Device* device,
                             int graph_def_version,
                             const FunctionLibraryDefinition* lib_def,
                             thread::ThreadPool* default_thread_pool,
                             const OptimizerOptions& optimizer_options,
                             const SessionMetadata* session_metadata,
                             ProcessFunctionLibraryRuntime* parent);

  ~FunctionLibraryRuntimeImpl() override;

  Status Instantiate(const string& function_name, AttrSlice attrs,
                     const InstantiateOptions& options,
                     Handle* handle) override;

  Status ReleaseHandle(Handle handle) override;

  const FunctionBody* GetFunctionBody(Handle handle) override;

  Status GetRetTypes(Handle handle, DataTypeVector* ret_types) override;

  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                      OpKernel** kernel) override;

  void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
           std::vector<Tensor>* rets, DoneCallback done) override;
  void Run(const Options& opts, Handle handle, CallFrameInterface* frame,
           DoneCallback done) override;
  Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
                 std::vector<Tensor>* rets) override;
  Status RunSync(Options opts, Handle handle,
                 CallFrameInterface* call_frame) override;

  bool IsStateful(const string& function) const override;

  const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const override {
    return base_lib_def_;
  }

  Device* device() override { return device_; }
  const Device* device() const override { return device_; }

  std::function<void(std::function<void()>)>* runner() override {
    return &default_runner_;
  }

  const DeviceMgr* device_mgr() const override { return device_mgr_; }
  Env* env() override { return env_; }
  const ConfigProto* const config_proto() override { return config_; }
  int graph_def_version() const override { return graph_def_version_; }

  string DebugString(Handle h) override;

  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
               FunctionLibraryRuntime** out_flr,
               bool skip_flib_def = false) override;

 private:
  typedef FunctionLibraryRuntimeImpl ME;

  // Construction-time state; none of the raw pointers below are owned.
  const DeviceMgr* const device_mgr_;
  Device* const device_;
  Env* const env_;
  const ConfigProto* const config_;
  const int graph_def_version_;
  const FunctionLibraryDefinition* const base_lib_def_;
  GraphOptimizer optimizer_;
  const SessionMetadata* const session_metadata_;
  Executor::Args::Runner default_runner_;
  const string device_name_;

  // Cached callbacks bound in the constructor, passed to helpers that need
  // signature lookup / kernel creation against this runtime.
  std::function<Status(const string&, const OpDef**)> get_func_sig_;
  std::function<Status(const std::shared_ptr<const NodeProperties>&,
                       OpKernel**)>
      create_kernel_;

  // Protects next_handle_ and items_.
  mutable mutex mu_;

  int next_handle_ TF_GUARDED_BY(mu_);

  // The instantiated and transformed function is encoded as a Graph
  // object, and an executor is created for the graph.
  struct Item {
    uint64 instantiation_counter = 0;
    std::unique_ptr<const Graph> graph = nullptr;
    const FunctionLibraryDefinition* lib_def = nullptr;  // Not owned.
    FunctionBody* func_graph = nullptr;
    Executor* exec = nullptr;
    FunctionLibraryRuntimeOverlay* overlay_flr = nullptr;
    string executor_type;

    // Item owns func_graph, exec, and overlay_flr (but not lib_def).
    ~Item() {
      delete this->func_graph;
      delete this->exec;
      delete this->overlay_flr;
    }
  };
  // Held behind a unique_ptr so the destructor can release it explicitly and
  // ReleaseHandle can detect teardown (see ~FunctionLibraryRuntimeImpl).
  std::unique_ptr<absl::flat_hash_map<Handle, std::unique_ptr<Item>>> items_
      TF_GUARDED_BY(mu_);
  std::unique_ptr<FunctionHandleCache> function_handle_cache_;
  ProcessFunctionLibraryRuntime* parent_ = nullptr;  // not owned.

  // Overloads the CreateKernel method, providing a FunctionLibraryRuntime
  // to use for kernel creation and execution. In particular, this method can
  // accept a FunctionLibraryRuntimeOverlay that overlays a different
  // FunctionLibraryDefinition.
  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                      FunctionLibraryRuntime* flr, OpKernel** kernel);
  Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
                           const FunctionLibraryDefinition* lib_def,
                           std::unique_ptr<FunctionBody>* fbody);
  Status CreateItem(Item** item);
  Status GetOrCreateItem(LocalHandle local_handle, Item** item);
  Status InstantiateSymbolicGradient(const NameAttrList& func,
                                     const FunctionLibraryDefinition* lib_def,
                                     std::unique_ptr<FunctionBody>* g_body);
  bool IsLocalTarget(const InstantiateOptions& options) const;
  AttrValueMap FixAttrs(const AttrSlice& attrs);
  void RunRemote(const Options& opts, Handle handle,
                 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                 Item* item, DoneCallback done);

  Status PrepareRunSync(
      Handle handle, Options* run_opts, Item** out_item,
      std::unique_ptr<PrivateIntraProcessRendezvous>* out_rendezvous);

  void ExecutorArgsFromOptions(const FunctionLibraryRuntime::Options& run_opts,
                               CallFrameInterface* frame,
                               Executor::Args* exec_args);

  TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
};
460 
// Constructor: captures all not-owned collaborators and binds the cached
// signature-lookup / kernel-creation callbacks against this instance.
FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
    const DeviceMgr* dmgr, Env* env, const ConfigProto* config, Device* device,
    int graph_def_version, const FunctionLibraryDefinition* lib_def,
    thread::ThreadPool* default_thread_pool,
    const OptimizerOptions& optimizer_options,
    const SessionMetadata* session_metadata,
    ProcessFunctionLibraryRuntime* parent)
    : device_mgr_(dmgr),
      device_(device),
      env_(env),
      config_(config),
      graph_def_version_(graph_def_version),
      base_lib_def_(lib_def),
      optimizer_(optimizer_options),
      session_metadata_(session_metadata),
      default_runner_(nullptr),
      device_name_(device_ == nullptr
                       ? ProcessFunctionLibraryRuntime::kDefaultFLRDevice
                       : device_->name()),
      next_handle_(0),
      items_(absl::make_unique<
             absl::flat_hash_map<Handle, std::unique_ptr<Item>>>()),
      function_handle_cache_(absl::make_unique<FunctionHandleCache>(this)),
      parent_(parent) {
  get_func_sig_ = [this](const string& op, const OpDef** sig) {
    return base_lib_def_->LookUpOpDef(op, sig);
  };
  create_kernel_ = [this](const std::shared_ptr<const NodeProperties>& props,
                          OpKernel** kernel) {
    return CreateKernel(props, kernel);
  };
  // Pick the thread pool backing default_runner_: the device's own pool wins,
  // then the caller-provided default pool; with neither, default_runner_
  // stays null.
  thread::ThreadPool* pool = nullptr;
  if (device_ != nullptr) {
    pool = device_->tensorflow_device_thread_pool();
  }
  if (pool == nullptr) {
    pool = default_thread_pool;
  }
  if (pool != nullptr) {
    default_runner_ = [pool](Executor::Args::Closure c) {
      pool->Schedule(std::move(c));
    };
  }
}
505 
FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
  // Deleting the items_ list will delete all the function handles registered in
  // this object. A function may contain a few sub-functions which have also
  // been registered in this object. Deleting the parent function will call
  // ReleaseHandle in this class again for each of the sub-functions. These
  // circular calls may cause segfault since the items_ may have already been
  // partially deleted when releasing handles of sub-functions. Explicitly
  // release items_ here and check it in ReleaseHandle to avoid this.
  items_.reset();
}
516 
// An asynchronous op kernel which executes an instantiated function
// defined in a library.
class CallOp : public AsyncOpKernel {
 public:
  // `handle` identifies the already-instantiated function to run.
  CallOp(FunctionLibraryRuntime::Handle handle, OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx), handle_(handle) {}

  ~CallOp() override {
    // TODO(iga): Release the cached handle_
  }

  // Forwards all kernel inputs to the function runtime and copies the
  // function's results to the kernel outputs when the run completes.
  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    FunctionLibraryRuntime* lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library is provided."),
                      done);
    // Propagate the execution context of the calling op into the function run.
    FunctionLibraryRuntime::Options opts;
    opts.rendezvous = ctx->rendezvous();
    opts.cancellation_manager = ctx->cancellation_manager();
    opts.step_container = ctx->step_container();
    opts.stats_collector = ctx->stats_collector();
    opts.runner = ctx->runner();
    opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
    opts.collective_executor = ctx->collective_executor();
    std::vector<Tensor> args;
    args.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      args.push_back(ctx->input(i));
    }
    // Heap-allocated so it outlives this stack frame; the completion callback
    // below owns and deletes it.
    std::vector<Tensor>* rets = new std::vector<Tensor>;
    profiler::TraceMe trace_me(
        [&] {
          return absl::StrCat("CallOp #parent_step_id=", ctx->step_id(),
                              ",function_step_id=", opts.step_id, "#");
        },
        /*level=*/2);
    lib->Run(opts, handle_, args, rets,
             [ctx, done, rets](const Status& status) {
               if (!status.ok()) {
                 ctx->SetStatus(status);
               } else {
                 const int ret_size = static_cast<int>(rets->size());
                 CHECK_EQ(ret_size, ctx->num_outputs());
                 for (int i = 0; i < ret_size; ++i) {
                   ctx->set_output(i, (*rets)[i]);
                 }
               }
               delete rets;
               done();
             });
  }

 private:
  FunctionLibraryRuntime::Handle handle_;

  TF_DISALLOW_COPY_AND_ASSIGN(CallOp);
};
574 
// Returns the FunctionBody for `h`, or nullptr (with an error log) when the
// handle is not registered on this runtime's device.
const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h);
  if (local_handle == kInvalidLocalHandle) {
    LOG(ERROR) << "Could not find Handle: " << h
               << " on device: " << device_name_;
    return nullptr;
  }

  tf_shared_lock l(mu_);
  auto iter = items_->find(local_handle);
  CHECK(iter != items_->end());  // a valid local handle must have an Item
  return iter->second->func_graph;
}
588 
// Fills `ret_types` with the return types of the function behind `h`,
// delegating to the parent runtime for multi-device functions.
Status FunctionLibraryRuntimeImpl::GetRetTypes(Handle h,
                                               DataTypeVector* ret_types) {
  if (parent_->IsMultiDevice(h)) {
    return parent_->GetRetTypes(h, ret_types);
  }
  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h);
  if (local_handle == kInvalidLocalHandle) {
    return errors::InvalidArgument("Handle ", h, " not found.");
  }
  // Local handle is valid, so GetFunctionBody is expected to succeed here.
  const FunctionBody* fbody = GetFunctionBody(h);
  *ret_types = fbody->ret_types;
  return Status::OK();
}
602 
// Public CreateKernel: resolves ops against this runtime's own library.
Status FunctionLibraryRuntimeImpl::CreateKernel(
    const std::shared_ptr<const NodeProperties>& props, OpKernel** kernel) {
  return CreateKernel(props, this, kernel);
}
607 
CreateKernel(const std::shared_ptr<const NodeProperties> & props,FunctionLibraryRuntime * flr,OpKernel ** kernel)608 Status FunctionLibraryRuntimeImpl::CreateKernel(
609     const std::shared_ptr<const NodeProperties>& props,
610     FunctionLibraryRuntime* flr, OpKernel** kernel) {
611   // If a custom kernel creator is given, try that.
612   Status s;
613   const CustomKernelCreator* custom_kernel_creator =
614       GetDefaultCustomKernelCreator();
615   if (custom_kernel_creator &&
616       custom_kernel_creator->CanCreateKernel(*flr, props)) {
617     std::unique_ptr<OpKernel> ret;
618     s = custom_kernel_creator->CreateKernel(flr, props, &ret);
619     if (s.ok()) {
620       *kernel = ret.release();
621     } else {
622       VLOG(2) << "Custom creator error: " << s;
623     }
624     return s;
625   }
626 
627   const FunctionLibraryDefinition* lib_def =
628       flr->GetFunctionLibraryDefinition();
629   if (lib_def->Find(props->node_def.op()) == nullptr) {
630     // A primitive operation. Creates the registered kernel.
631     return CreateNonCachedKernel(device_, flr, props, graph_def_version_,
632                                  kernel);
633   }
634 
635   // Try to instantiate this function for the func/attr. Maybe it's
636   // cached already.
637   InstantiateOptions options;
638   if (lib_def != base_lib_def_) {
639     options.lib_def = lib_def;
640   }
641   Handle handle;
642   TF_RETURN_IF_ERROR(Instantiate(props->node_def.op(),
643                                  AttrSlice(&props->node_def.attr()), options,
644                                  &handle));
645 
646   const FunctionBody* fbody = GetFunctionBody(handle);
647   CHECK_NOTNULL(fbody);
648 
649   // TODO(zhifengc): For now, we assume int32 and resources are always on host
650   // memory and other types are always on device memory. We should do type
651   // inference over function body to derive the correct input/output memory
652   // types.
653   MemoryTypeVector input_memory_types;
654   for (const auto& t : fbody->arg_types) {
655     input_memory_types.push_back(MTypeFromDType(t));
656   }
657   MemoryTypeVector output_memory_types;
658   for (const auto& t : fbody->ret_types) {
659     output_memory_types.push_back(MTypeFromDType(t));
660   }
661 
662   // Constructs a CallOp kernel for running the instantiated function.
663   auto device_type = DeviceType(device_->attributes().device_type());
664   auto new_props = std::make_shared<NodeProperties>(
665       &fbody->fdef.signature(), props->node_def, fbody->arg_types,
666       fbody->ret_types);
667   OpKernelConstruction construction(
668       device_type, device_, device_->GetAllocator(AllocatorAttributes()), flr,
669       device_->resource_manager(), props, input_memory_types,
670       output_memory_types, graph_def_version_, &s);
671   if (s.ok()) {
672     *kernel = new CallOp(handle, &construction);
673   }
674   return s;
675 }
676 
FunctionDefToBody(const FunctionDef & fdef,AttrSlice attrs,const FunctionLibraryDefinition * lib_def,std::unique_ptr<FunctionBody> * fbody)677 Status FunctionLibraryRuntimeImpl::FunctionDefToBody(
678     const FunctionDef& fdef, AttrSlice attrs,
679     const FunctionLibraryDefinition* lib_def,
680     std::unique_ptr<FunctionBody>* fbody) {
681   if (lib_def == base_lib_def_) {
682     return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig_, fbody);
683   } else {
684     auto get_func_sig = [lib_def](const string& op, const OpDef** sig) {
685       return lib_def->LookUpOpDef(op, sig);
686     };
687     return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody);
688   }
689 }
690 
// Builds the FunctionBody of the symbolic gradient of `func`. For primitive
// ops a registered gradient-creator produces the gradient FunctionDef; for
// user-defined functions the function is instantiated and differentiated via
// SymbolicGradient.
Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
    const NameAttrList& func, const FunctionLibraryDefinition* lib_def,
    std::unique_ptr<FunctionBody>* g_body) {
  const FunctionDef* fdef = lib_def->Find(func.name());
  if (fdef == nullptr) {
    // f is a primitive op.
    gradient::Creator creator;
    TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
    if (creator == nullptr) {
      return errors::InvalidArgument("No gradient is defined for ",
                                     func.name());
    }
    FunctionDef grad_fdef;
    // TODO(josh11b): Should filter out the attrs from func that aren't used
    // by the gradient function.
    TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
    TF_RETURN_IF_ERROR(
        FunctionDefToBody(grad_fdef, AttrSlice(&func.attr()), lib_def, g_body));
  } else {
    // f is a user-defined function.
    InstantiateOptions options;
    // Only override lib_def when it differs from the base definition.
    if (lib_def != base_lib_def_) {
      options.lib_def = lib_def;
    }
    Handle f_handle;
    TF_RETURN_IF_ERROR(
        Instantiate(func.name(), AttrSlice(&func.attr()), options, &f_handle));
    const FunctionBody* f_body = GetFunctionBody(f_handle);
    CHECK_NOTNULL(f_body);
    *g_body = SymbolicGradient(*f_body);
  }
  return Status::OK();
}
724 
IsLocalTarget(const InstantiateOptions & options) const725 bool FunctionLibraryRuntimeImpl::IsLocalTarget(
726     const InstantiateOptions& options) const {
727   if (device_ == nullptr) return true;
728   if (options.target.empty()) return true;
729   if (options.is_multi_device_function) return false;
730   Device* target_device;
731   if (!device_mgr_->LookupDevice(options.target, &target_device).ok()) {
732     VLOG(1) << "Not instantiating function in FLR because failed to "
733             << "find device " << options.target << " in device manager";
734     return false;
735   }
736   if (target_device != device_) {
737     VLOG(1) << "Not instantiating function in FLR because target device "
738             << options.target
739             << " is different from FLR's device: " << device_->DebugString();
740     return false;
741   }
742   return true;
743 }
744 
// Instantiates `function_name` with `attrs` on this runtime's device and
// returns a handle in `*handle`. If an identical instantiation (same
// canonical key) already exists, its reference count is bumped and the
// existing handle is reused.
Status FunctionLibraryRuntimeImpl::Instantiate(
    const string& function_name, AttrSlice attrs,
    const InstantiateOptions& options, Handle* handle) {
  // Non-local targets are instantiated by the process-level runtime.
  if (!IsLocalTarget(options)) {
    return parent_->Instantiate(function_name, attrs, options, handle);
  }

  if (options.use_function_cache) {
    // Delegate to the per-FLR handle cache; clear the flag so the nested
    // call does not re-enter this branch.
    InstantiateOptions options_copy(options);
    options_copy.use_function_cache = false;
    return function_handle_cache_->Instantiate(function_name, attrs,
                                               options_copy, handle);
  }

  // Since this is a local target, ensure that the local `device_name_` appears
  // in the canonical key.
  InstantiateOptions options_copy(options);
  options_copy.target = device_name_;
  const string key = Canonicalize(function_name, attrs, options_copy);

  {
    // Fast path: a function with this key was already instantiated.
    mutex_lock l(mu_);
    *handle = parent_->GetHandle(key);
    if (*handle != kInvalidHandle) {
      FunctionLibraryRuntime::LocalHandle handle_on_device =
          parent_->GetHandleOnDevice(device_name_, *handle);
      if (handle_on_device == kInvalidLocalHandle) {
        return errors::Internal("LocalHandle not found for handle ", *handle,
                                ".");
      }
      auto item_handle = items_->find(handle_on_device);
      if (item_handle == items_->end()) {
        return errors::Internal("LocalHandle ", handle_on_device,
                                " for handle ", *handle,
                                " not found in items.");
      }
      ++item_handle->second->instantiation_counter;
      return Status::OK();
    }
  }

  // Slow path: build the function body. This runs unlocked because it can be
  // expensive and may recursively re-enter Instantiate() (e.g. gradients).
  const FunctionLibraryDefinition* lib_def =
      options.lib_def ? options.lib_def : base_lib_def_;
  std::unique_ptr<FunctionBody> fbody;
  if (function_name == kGradientOp) {
    // SymbolicGradient: instantiate the gradient of the function named in
    // attr "f".
    const AttrValue* f = attrs.Find(kFuncAttr);
    if (f == nullptr) {
      return errors::InvalidArgument("SymbolicGradient is missing attr: f");
    }
    const auto& func = f->func();
    if (func.name() == kGradientOp) {
      return errors::InvalidArgument("Can't take gradient of SymbolicGradient");
    }
    // A registered custom gradient function takes precedence over symbolic
    // differentiation.
    const string grad = lib_def->FindGradient(func.name());
    if (!grad.empty()) {
      return Instantiate(grad, AttrSlice(&func.attr()), options, handle);
    }
    TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, lib_def, &fbody));
  } else {
    const FunctionDef* fdef = lib_def->Find(function_name);
    if (fdef == nullptr) {
      return errors::NotFound("Function ", function_name, " is not defined.");
    }
    TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, lib_def, &fbody));
  }

  LocalHandle local_handle;
  {
    mutex_lock l(mu_);
    // Re-check under the lock: another thread may have instantiated the same
    // key while we were building the body above.
    *handle = parent_->GetHandle(key);
    if (*handle != kInvalidHandle) {
      local_handle = parent_->GetHandleOnDevice(device_name_, *handle);
      ++(*items_)[local_handle]->instantiation_counter;
    } else {
      *handle = parent_->AddHandle(key, device_name_, next_handle_);
      Item* item = new Item;
      item->func_graph = fbody.release();
      item->instantiation_counter = 1;
      item->executor_type = ExecutorType(options, attrs);
      if (options.lib_def) {
        // Execution-time op lookups must also go through the overlay library.
        item->overlay_flr =
            new FunctionLibraryRuntimeOverlay(this, options.lib_def);
      }
      local_handle = next_handle_++;
      items_->emplace(local_handle, std::unique_ptr<Item>(item));
    }
  }

  if (options.create_kernels_eagerly) {
    // Force executor/kernel construction now instead of on first Run().
    Item* item;
    TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, &item));
  }

  return Status::OK();
}
840 
// Decrements the instantiation count for `handle`; when the count reaches
// zero, the corresponding Item (graph + executor) is destroyed and the
// handle is removed from the parent runtime.
Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
  LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
  if (h == kInvalidLocalHandle) {
    // Not a handle owned by this FLR; let the process-level runtime release
    // it.
    return parent_->ReleaseHandle(handle);
  }
  std::unique_ptr<Item> item_to_delete;
  Status parent_status;
  {
    mutex_lock l(mu_);
    // Return directly if all items has already been released.
    if (items_ == nullptr) return Status::OK();

    auto it = items_->find(h);
    if (it == items_->end()) {
      return errors::Internal(
          "Inconsistent FunctionLibraryRuntime. Expected to find an item for "
          "handle ",
          h, " but found none");
    }
    std::unique_ptr<Item>& item = it->second;
    --item->instantiation_counter;
    if (item->instantiation_counter == 0) {
      // We don't simply erase h's item because that would trigger
      // item destruction while holding mu_. Item destruction can
      // trigger graph destruction. If the graph contains kernels like
      // CallOp or PartitionCallOp, their destructors will release cached
      // function handles, resulting in deadlock here.
      item_to_delete = std::move(item);
      items_->erase(h);
      parent_status = parent_->RemoveHandle(handle);
    }
  }
  // `item_to_delete` (if populated) is destroyed here, after mu_ has been
  // released — see the deadlock note above.
  return parent_status;
}
875 
876 namespace {
877 
878 // Removes all stateless nodes that do not contribute to a return
879 // value from the function body. Unlike `RemoveDeadNodes()`, which is
880 // triggered by `OptimizerOptions.do_function_inlining`, this pass
881 // ignores the SINK node, from which (by definition) all nodes are
882 // reverse reachable, and preserves all nodes that are reachable from
883 // control output nodes.
884 //
885 // TODO(ezhulenev, skyewm): Function body should not have special treatment of
886 // stateful ops, graph should encode nodes that must execute with `control_ret`
887 // and `control_output`.
PruneFunctionBody(const FunctionDef & fdef,Graph * g)888 void PruneFunctionBody(const FunctionDef& fdef, Graph* g) {
889   VLOG(2) << "Pruning function body: function_name=" << fdef.signature().name();
890 
891   // `control_ret` nodes must be always executed.
892   std::unordered_set<StringPiece, StringPieceHasher> control_ret_nodes;
893   for (const auto& control_ret : fdef.control_ret()) {
894     control_ret_nodes.insert(control_ret.second);
895   }
896 
897   std::unordered_set<const Node*> nodes;
898   for (auto n : g->nodes()) {
899     // NOTE(mrry): "_Retval" nodes are stateful, and so will be added
900     // to the seed set of `nodes`. "_Arg" nodes are also stateful, but we
901     // specifically exclude them as seeds, to avoid unconditionally executing
902     // unused argument nodes (e.g. in a function like `lambda x, y: y`).
903     // TODO(mrry): Investigate whether the `n->IsControlFlow()` test is
904     // still needed. It would be preferable to prune entire loops and/or
905     // conditionals if they are not used in the graph.
906     if (n->IsControlFlow() ||
907         (n->op_def().is_stateful() && n->type_string() != kArgOp) ||
908         (control_ret_nodes.find(n->name()) != control_ret_nodes.end())) {
909       nodes.insert(n);
910     }
911   }
912   bool changed = PruneForReverseReachability(g, std::move(nodes));
913   if (changed) {
914     FixupSourceAndSinkEdges(g);
915   }
916 }
917 
918 constexpr int kMaxNodesForSingleThreadedExecutor = 32;
919 
920 // Returns true if the given operation is suitable to execute via
921 // SingleThreadedExecutor. This is an intentional subset of the ops which
922 // technically can be run via single-threaded execution to avoid issues with
923 // recursion or function invocation.
924 //
925 // SingleThreadedExecutor runs asynchronous kernels synchronously: this can lead
926 // to deadlocks. This function attempts to exclude all async kernels in lieu of
927 // kernel instantiation.
IsOpSingleThreadedExecutorCompatible(const Node & n)928 bool IsOpSingleThreadedExecutorCompatible(const Node& n) {
929   if (n.IsFunctionCall() || n.IsPartitionedCall() || n.IsIfNode() ||
930       n.IsWhileNode() || n.IsCaseNode()) {
931     return false;
932   }
933   if (n.IsControlFlow()) {
934     return false;
935   }
936   if (n.IsSend() || n.IsHostSend() || n.IsRecv() || n.IsHostRecv()) {
937     return false;
938   }
939   if (n.IsCollective()) {
940     return false;
941   }
942   for (DataType dt : n.output_types()) {
943     if (IsRefType(dt)) {
944       return false;
945     }
946   }
947   std::string lower = str_util::Lowercase(n.op_def().name());
948   if (str_util::StrContains(lower, "pyfunc") ||
949       str_util::StrContains(lower, "queue") ||
950       str_util::StrContains(lower, "rpc")) {
951     return false;
952   }
953 
954   return true;
955 }
956 
957 // Returns true if the given Graph is safe & efficient to run via the single
958 // threaded executor. The single-threaded executor has lower dispatch overhead
959 // for simple functions.
960 //
961 // This currently specializes for the case of a single operation, as created
962 // via eager execution.
IsSingleThreadedExecutorCompatible(const Graph * g)963 bool IsSingleThreadedExecutorCompatible(const Graph* g) {
964   // TODO(b/187729969): Temporarily disabled due to b/187306798.
965   return false;
966 
967   // Not worth analyzing large graphs.
968   if (g->num_nodes() > kMaxNodesForSingleThreadedExecutor) {
969     return false;
970   }
971 
972   int count = 0;
973   for (Node* n : g->nodes()) {
974     if (!IsOpSingleThreadedExecutorCompatible(*n)) {
975       return false;
976     }
977     if (n->op_def().name() == "_Arg" || n->op_def().name() == "_Retval" ||
978         n->op_def().name() == "NoOp") {
979       continue;
980     }
981 
982     count += 1;
983   }
984 
985   if (count == 1) {
986     return true;
987   }
988 
989   return false;
990 }
991 
992 }  // namespace
993 
// Builds the optimized graph and executor for `*item`. Invoked lazily on
// first execution via GetOrCreateItem(), or eagerly from Instantiate() when
// `create_kernels_eagerly` is set.
Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) {
  const FunctionBody* fbody;
  FunctionLibraryRuntime* flr;
  string executor_type;
  {
    // Snapshot the fields we need under the lock; optimization and executor
    // creation below must run unlocked.
    tf_shared_lock l(mu_);
    fbody = (*item)->func_graph;
    flr = (*item)->overlay_flr
              ? static_cast<FunctionLibraryRuntime*>((*item)->overlay_flr)
              : static_cast<FunctionLibraryRuntime*>(this);
    executor_type = (*item)->executor_type;
  }
  const FunctionLibraryDefinition* lib_def =
      flr->GetFunctionLibraryDefinition();
  // Work on a copy so the Item's pristine function graph is preserved.
  auto g = absl::make_unique<Graph>(lib_def);
  CopyGraph(*fbody->graph, g.get());

  PruneFunctionBody(fbody->fdef, g.get());
  optimizer_.Optimize(this, env(), device(), &g, /*shape_map=*/nullptr);
  TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device()->device_type()),
                                       device()->name(), g.get()));

  // Creates an executor based on the g. This must be done without
  // holding mu_ because create_kernel_ calls back into the library.
  LocalExecutorParams params;
  params.device = device_;
  params.function_library = flr;
  if (flr == this) {
    params.create_kernel = create_kernel_;
  } else {
    // Overlay case: route kernel creation through the overlay FLR so nested
    // functions resolve against the overlay library.
    params.create_kernel =
        [this, flr](const std::shared_ptr<const NodeProperties>& props,
                    OpKernel** kernel) {
          return CreateKernel(props, flr, kernel);
        };
  }
  params.delete_kernel = [](OpKernel* kernel) {
    DeleteNonCachedKernel(kernel);
  };
  params.session_metadata = session_metadata_;
  std::unique_ptr<Executor> exec;

  if (executor_type.empty() && IsSingleThreadedExecutorCompatible(g.get())) {
    executor_type = "SINGLE_THREADED_EXECUTOR";
  }
  TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, *g, &exec));
  {
    // Guard item since it is already inserted in items_.
    mutex_lock l(mu_);
    if ((*item)->exec == nullptr) {
      // Another thread may have built the executor concurrently; only the
      // first one installs its graph and executor.
      (*item)->graph = std::move(g);
      (*item)->exec = exec.release();
    }
  }
  return Status::OK();
}
1050 
// Resolves `local_handle` to its Item, building the Item's executor on
// first use.
Status FunctionLibraryRuntimeImpl::GetOrCreateItem(LocalHandle local_handle,
                                                   Item** item) {
  {
    tf_shared_lock l(mu_);
    auto iter = items_->find(local_handle);
    if (iter == items_->end()) {
      return errors::Internal("Local function handle ", local_handle,
                              " is not valid. Likely an internal error.");
    }
    *item = iter->second.get();
    if ((*item)->exec != nullptr) {
      // Executor already built; nothing more to do.
      return Status::OK();
    }
  }
  // NOTE: We need to call CreateItem out of mu_ because creating an
  // executor needs to call CreateKernel.
  return CreateItem(item);
}
1069 
ExecutorArgsFromOptions(const FunctionLibraryRuntime::Options & run_opts,CallFrameInterface * frame,Executor::Args * exec_args)1070 void FunctionLibraryRuntimeImpl::ExecutorArgsFromOptions(
1071     const FunctionLibraryRuntime::Options& run_opts, CallFrameInterface* frame,
1072     Executor::Args* exec_args) {
1073   // Inherit the step_id from the caller.
1074   exec_args->step_id = run_opts.step_id;
1075   exec_args->rendezvous = run_opts.rendezvous;
1076   exec_args->stats_collector = run_opts.stats_collector;
1077   exec_args->cancellation_manager = run_opts.cancellation_manager;
1078   exec_args->step_container = run_opts.step_container;
1079   if (run_opts.runner) {
1080     exec_args->runner = *run_opts.runner;
1081   } else {
1082     exec_args->runner = default_runner_;
1083   }
1084   exec_args->collective_executor = run_opts.collective_executor;
1085   exec_args->call_frame = frame;
1086   exec_args->run_all_kernels_inline = run_opts.run_all_kernels_inline;
1087   exec_args->user_intra_op_threadpool = run_opts.user_intra_op_threadpool;
1088   exec_args->coordination_service_agent = run_opts.coordination_service_agent;
1089 }
1090 
// Runs the function behind `handle` on this device on behalf of a remote
// caller: receives the arguments from the source device over the rendezvous,
// executes the body, then sends the return values back to the source device.
void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
                                           gtl::ArraySlice<Tensor> args,
                                           std::vector<Tensor>* rets,
                                           Item* item, DoneCallback done) {
  string target_device = parent_->GetDeviceName(handle);
  string source_device = opts.source_device;
  RendezvousInterface* rendezvous = opts.rendezvous;
  DeviceContext* device_context;
  Status s = parent_->GetDeviceContext(target_device, &device_context);
  if (!s.ok()) {
    done(s);
    return;
  }
  int64_t src_incarnation, target_incarnation;
  s = parent_->GetDeviceIncarnation(source_device, &src_incarnation);
  s.Update(parent_->GetDeviceIncarnation(target_device, &target_incarnation));
  if (!s.ok()) {
    done(s);
    return;
  }

  // `frame` and `exec_args` are heap-allocated because they must outlive
  // this call; they are deleted in the async callbacks below.
  const FunctionBody* fbody = GetFunctionBody(handle);
  FunctionCallFrame* frame =
      new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
  Executor::Args* exec_args = new Executor::Args;
  ExecutorArgsFromOptions(opts, frame, exec_args);

  std::vector<AllocatorAttributes> args_alloc_attrs, rets_alloc_attrs;
  args_alloc_attrs.reserve(fbody->arg_types.size());
  rets_alloc_attrs.reserve(fbody->ret_types.size());
  // Note: Functions assume that int32's are always on host memory.
  for (const auto& arg_type : fbody->arg_types) {
    AllocatorAttributes arg_alloc_attrs;
    if (MTypeFromDType(arg_type) == HOST_MEMORY) {
      arg_alloc_attrs.set_on_host(true);
    }
    args_alloc_attrs.push_back(arg_alloc_attrs);
  }
  for (const auto& ret_type : fbody->ret_types) {
    AllocatorAttributes ret_alloc_attrs;
    if (MTypeFromDType(ret_type) == HOST_MEMORY) {
      ret_alloc_attrs.set_on_host(true);
    }
    rets_alloc_attrs.push_back(ret_alloc_attrs);
  }

  bool allow_dead_tensors = opts.allow_dead_tensors;

  // The ProcFLR sends the arguments to the function from the source_device to
  // the target_device. So here we receive those arguments. Similarly, when the
  // computation is done and stored in *rets, we send the return values back
  // to the source_device (caller) so that the ProcFLR can receive them later.
  std::vector<Tensor>* remote_args = new std::vector<Tensor>;
  ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
      source_device, target_device, "arg_", src_incarnation, args.size(),
      device_context, args_alloc_attrs, rendezvous, remote_args,
      [frame, remote_args, item, source_device, target_device,
       target_incarnation, rendezvous, device_context, rets, done, exec_args,
       rets_alloc_attrs, allow_dead_tensors](const Status& status) {
        Status s = status;
        if (s.ok()) {
          s = frame->SetArgs(*remote_args);
        }
        if (!s.ok()) {
          // Receiving or binding the arguments failed: release all
          // heap-allocated state before reporting the error.
          delete frame;
          delete remote_args;
          delete exec_args;
          done(s);
          return;
        }
        item->exec->RunAsync(
            *exec_args,
            [frame, rets, done, source_device, target_device,
             target_incarnation, rendezvous, device_context, remote_args,
             rets_alloc_attrs, allow_dead_tensors](const Status& status) {
              Status s = status;
              if (s.ok()) {
                s = frame->ConsumeRetvals(rets, allow_dead_tensors);
              }
              delete frame;
              if (!s.ok()) {
                delete remote_args;
                done(s);
                return;
              }
              // Ship the results back to the caller's device.
              s = ProcessFunctionLibraryRuntime::SendTensors(
                  target_device, source_device, "ret_", target_incarnation,
                  *rets, device_context, rets_alloc_attrs, rendezvous);
              delete remote_args;
              done(s);
            });
        // NOTE(review): exec_args is freed as soon as RunAsync returns, while
        // execution is still in flight — presumably the executor copies what
        // it needs from Args; confirm against executor.h.
        delete exec_args;
      });
}
1185 
// Asynchronously runs the function behind `handle` with `args`, storing the
// results in `*rets` and invoking `done` on completion. Delegates to the
// process-level runtime for non-local handles and to RunRemote() for
// remote-execution requests.
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
                                     gtl::ArraySlice<Tensor> args,
                                     std::vector<Tensor>* rets,
                                     DoneCallback done) {
  if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
    done(errors::Cancelled("Function was cancelled before it was started"));
    return;
  }
  Options run_opts = opts;
  if (opts.create_rendezvous) {
    // Create a private rendezvous for this call; it is destroyed when `done`
    // fires.
    auto* rendezvous = new PrivateIntraProcessRendezvous(device_mgr_);
    run_opts.rendezvous = rendezvous;
    run_opts.create_rendezvous = false;
    done = [done = std::move(done), rendezvous](const Status& status) mutable {
      delete rendezvous;
      done(status);
    };
  }

  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
  if (local_handle == kInvalidLocalHandle) {
    // Not instantiated locally; delegate to the process-level runtime.
    parent_->Run(run_opts, handle, args, rets, done);
    return;
  }

  if (run_opts.runner == nullptr) {
    run_opts.runner = &default_runner_;
  }
  DCHECK(run_opts.runner != nullptr);

  Item* item = nullptr;
  Status s = GetOrCreateItem(local_handle, &item);
  if (!s.ok()) {
    done(s);
    return;
  }

  if (run_opts.remote_execution) {
    // NOTE(mrry): `RunRemote()` will set `exec_args->call_frame` for us.
    RunRemote(run_opts, handle, args, rets, item, std::move(done));
    return;
  }

  // `frame` is heap-allocated because it must outlive this call; it is
  // deleted in the RunAsync done-callback below.
  const FunctionBody* fbody = GetFunctionBody(handle);
  FunctionCallFrame* frame =
      new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
  s = frame->SetArgs(args);
  if (!s.ok()) {
    delete frame;
    done(s);
    return;
  }

  profiler::TraceMeProducer activity(
      // To TraceMeConsumers in ExecutorState::Process/Finish.
      [&opts] {
        return profiler::TraceMeEncode("FunctionRun",
                                       {{"id", opts.step_id}, {"_r", 1}});
      },
      profiler::ContextType::kTfExecutor, opts.step_id,
      profiler::TraceMeLevel::kInfo);

  Executor::Args exec_args;
  ExecutorArgsFromOptions(run_opts, frame, &exec_args);

  bool allow_dead_tensors = run_opts.allow_dead_tensors;
  item->exec->RunAsync(
      // Executor args
      exec_args,
      // Done callback.
      [frame, rets, done, allow_dead_tensors](const Status& status) {
        Status s = status;
        if (s.ok()) {
          s = frame->ConsumeRetvals(rets, allow_dead_tensors);
        }
        delete frame;
        done(s);
      });
}
1265 
// Asynchronously runs the function behind `handle` using a caller-provided
// call frame. Remote execution is not supported through this overload.
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
                                     CallFrameInterface* frame,
                                     DoneCallback done) {
  if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
    done(errors::Cancelled(""));
    return;
  }

  Options run_opts = opts;
  if (opts.create_rendezvous) {
    // Create a private rendezvous for this call; it is destroyed when `done`
    // fires.
    auto* rendezvous = new PrivateIntraProcessRendezvous(device_mgr_);
    run_opts.rendezvous = rendezvous;
    run_opts.create_rendezvous = false;
    done = [done = std::move(done), rendezvous](const Status& status) mutable {
      delete rendezvous;
      done(status);
    };
  }

  LocalHandle local_handle = parent_->GetHandleOnDevice(
      device_name_, handle, /*include_multi_device=*/true);
  if (local_handle == kInvalidLocalHandle) {
    // Not instantiated locally; delegate to the process-level runtime.
    parent_->Run(run_opts, handle, frame, done);
    return;
  }

  if (opts.remote_execution) {
    // NOTE(mrry): This bit is only set for a local function when `parent_`
    // calls back into this class, and the current implementation of
    // `ProcessFunctionLibraryRuntime` currently always uses the vector-based
    // `args`/`rets` interface.
    done(errors::Unimplemented("Remote calling with CallFrameInterface"));
    return;
  }

  Item* item = nullptr;
  Status s = GetOrCreateItem(local_handle, &item);
  if (!s.ok()) {
    done(s);
    return;
  }
  if (run_opts.runner == nullptr) {
    run_opts.runner = &default_runner_;
  }
  DCHECK(run_opts.runner != nullptr);

  profiler::TraceMeProducer activity(
      // To TraceMeConsumers in ExecutorState::Process/Finish.
      [&opts] {
        return profiler::TraceMeEncode("FunctionRun",
                                       {{"id", opts.step_id}, {"_r", 1}});
      },
      profiler::ContextType::kTfExecutor, opts.step_id,
      profiler::TraceMeLevel::kInfo);

  Executor::Args exec_args;
  ExecutorArgsFromOptions(run_opts, frame, &exec_args);
  item->exec->RunAsync(exec_args, std::move(done));
}
1325 
// Shared setup for the synchronous Run paths: checks for cancellation,
// rejects remote execution, creates a private rendezvous when requested, and
// resolves `handle` to a local Item. On success `*out_item` is either the
// resolved Item, or nullptr when the handle is not local (the caller must
// then delegate to `parent_`).
Status FunctionLibraryRuntimeImpl::PrepareRunSync(
    Handle handle, Options* run_opts, Item** out_item,
    std::unique_ptr<PrivateIntraProcessRendezvous>* out_rendezvous) {
  if (run_opts->cancellation_manager &&
      run_opts->cancellation_manager->IsCancelled()) {
    return errors::Cancelled("");
  }

  if (run_opts->remote_execution) {
    // NOTE(mrry): This bit is only set for a local function when `parent_`
    // calls back into this class, and the current implementation of
    // `ProcessFunctionLibraryRuntime` currently always uses the asynchronous
    // Run() method.
    return errors::Unimplemented("Remote calling with RunSync()");
  }

  if (run_opts->create_rendezvous) {
    // The caller keeps `*out_rendezvous` alive for the duration of the run.
    *out_rendezvous =
        absl::make_unique<PrivateIntraProcessRendezvous>(device_mgr_);
    run_opts->rendezvous = out_rendezvous->get();
    run_opts->create_rendezvous = false;
  }

  LocalHandle local_handle = parent_->GetHandleOnDevice(
      device_name_, handle, /*include_multi_device=*/true);
  if (local_handle == kInvalidLocalHandle) {
    // Not local: signal the caller to delegate to the parent runtime.
    *out_item = nullptr;
    return Status::OK();
  }

  TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, out_item));

  if (run_opts->runner == nullptr) {
    run_opts->runner = &default_runner_;
  }
  DCHECK(run_opts->runner != nullptr);

  return Status::OK();
}
1365 
// Synchronously runs the function behind `handle` with `args`; blocks until
// the executor finishes, then moves the return values into `*rets`.
Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
                                           gtl::ArraySlice<Tensor> args,
                                           std::vector<Tensor>* rets) {
  Item* item = nullptr;
  std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous;
  TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous));
  if (item == nullptr) {
    // Not instantiated locally; delegate to the process-level runtime.
    return parent_->RunSync(opts, handle, args, rets);
  }

  Executor::Args exec_args;
  const FunctionBody* fbody = GetFunctionBody(handle);
  FunctionCallFrame frame(fbody->arg_types, fbody->ret_types);
  TF_RETURN_IF_ERROR(frame.SetArgs(args));
  ExecutorArgsFromOptions(opts, &frame, &exec_args);

  TF_RETURN_IF_ERROR(item->exec->Run(exec_args));
  return frame.ConsumeRetvals(rets, opts.allow_dead_tensors);
}
1385 
// Synchronously runs the function behind `handle` using a caller-provided
// call frame; blocks until the executor finishes.
Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
                                           CallFrameInterface* call_frame) {
  Item* item = nullptr;
  std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous;
  TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous));
  if (item == nullptr) {
    // Not instantiated locally; delegate to the process-level runtime.
    return parent_->RunSync(opts, handle, call_frame);
  }

  Executor::Args exec_args;
  ExecutorArgsFromOptions(opts, call_frame, &exec_args);
  return item->exec->Run(exec_args);
}
1399 
IsStateful(const string & func) const1400 bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) const {
1401   const OpDef* op_def;
1402   const Status s = base_lib_def_->LookUpOpDef(func, &op_def);
1403   return s.ok() && op_def->is_stateful();
1404 }
1405 
DebugString(Handle handle)1406 string FunctionLibraryRuntimeImpl::DebugString(Handle handle) {
1407   Item* item = nullptr;
1408   LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
1409   Status s = GetOrCreateItem(local_handle, &item);
1410   if (s.ok()) {
1411     if (item->graph) {
1412       return tensorflow::DebugString(item->graph.get());
1413     } else {
1414       return tensorflow::DebugString(item->func_graph->graph);
1415     }
1416   } else {
1417     return s.ToString();
1418   }
1419 }
1420 
Clone(std::unique_ptr<FunctionLibraryDefinition> * out_lib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> * out_pflr,FunctionLibraryRuntime ** out_flr,bool skip_flib_def)1421 Status FunctionLibraryRuntimeImpl::Clone(
1422     std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
1423     std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
1424     FunctionLibraryRuntime** out_flr, bool skip_flib_def) {
1425   TF_RETURN_IF_ERROR(parent_->Clone(env_, graph_def_version_,
1426                                     optimizer_.options(), out_lib_def, out_pflr,
1427                                     skip_flib_def));
1428   *out_flr = (*out_pflr)->GetFLR(device_->name());
1429   if (*out_flr != nullptr) {
1430     return Status::OK();
1431   } else {
1432     return errors::Internal("Cloning FunctionLibraryRuntime failed.");
1433   }
1434 }
1435 
1436 namespace {
1437 
// Process-wide storage for the custom kernel creator callback. Access is
// serialized by `mu`; the singleton does not own `custom_creator`.
struct CustomCreatorSingleton {
  mutex mu;
  CustomKernelCreator* custom_creator = nullptr;  // Not owned.

  // Installs (or clears, with nullptr) the callback.
  void Set(CustomKernelCreator* cb) {
    mutex_lock l(mu);
    custom_creator = cb;
  }

  // Returns the currently installed callback, or nullptr.
  CustomKernelCreator* Get() {
    mutex_lock l(mu);
    return custom_creator;
  }
};
1452 
// Returns the process-wide singleton. Intentionally leaked to avoid
// destruction-order issues at process shutdown.
CustomCreatorSingleton* GetCustomCreatorSingleton() {
  static CustomCreatorSingleton* ccs = new CustomCreatorSingleton;
  return ccs;
}
1457 
1458 }  // namespace
1459 
// Returns the globally registered custom kernel creator, or nullptr if none
// has been registered.
const CustomKernelCreator* GetDefaultCustomKernelCreator() {
  return GetCustomCreatorSingleton()->Get();
}
1463 
RegisterDefaultCustomKernelCreator(CustomKernelCreator * c)1464 void RegisterDefaultCustomKernelCreator(CustomKernelCreator* c) {
1465   GetCustomCreatorSingleton()->Set(c);
1466 }
1467 
NewFunctionLibraryRuntime(const DeviceMgr * device_mgr,Env * env,const ConfigProto * config,Device * device,int graph_def_version,const FunctionLibraryDefinition * lib_def,thread::ThreadPool * thread_pool,const OptimizerOptions & optimizer_options,const SessionMetadata * session_metadata,ProcessFunctionLibraryRuntime * parent)1468 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
1469     const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
1470     Device* device, int graph_def_version,
1471     const FunctionLibraryDefinition* lib_def, thread::ThreadPool* thread_pool,
1472     const OptimizerOptions& optimizer_options,
1473     const SessionMetadata* session_metadata,
1474     ProcessFunctionLibraryRuntime* parent) {
1475   return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl(
1476       device_mgr, env, config, device, graph_def_version, lib_def, thread_pool,
1477       optimizer_options, session_metadata, parent));
1478 }
1479 
// Builds the function body of the symbolic gradient of a function: given a
// body computing y = f(x), Compute() produces a new body whose graph contains
// gradient ops computing dx from (x, dy).
class SymbolicGradientHelper {
 public:
  // `f` is borrowed and must outlive this helper.
  explicit SymbolicGradientHelper(const FunctionBody& f) : fbody_(&f) {}
  ~SymbolicGradientHelper() = default;

  // Builds and returns the gradient function body. The result takes the
  // original arguments followed by one gradient input per original return
  // value, and returns one gradient output per original argument.
  std::unique_ptr<FunctionBody> Compute();

 private:
  // The function whose gradient is being computed (not owned).
  const FunctionBody* fbody_;

  // Makes a copy of fbody_ in gbody.
  void Copy(FunctionBody* gbody);

  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientHelper);
};
1495 
Copy(FunctionBody * gbody)1496 void SymbolicGradientHelper::Copy(FunctionBody* gbody) {
1497   const Graph& src = *(fbody_->graph);
1498   gbody->graph = new Graph(src.op_registry());
1499   Graph* dst = gbody->graph;
1500 
1501   std::vector<Node*> node_map(src.num_node_ids());
1502 
1503   // Copy just the fdef attributes (copy '_noinline' and other similar flags to
1504   // the gradient function body).
1505   *(gbody->fdef.mutable_attr()) = fbody_->fdef.attr();
1506 
1507   // Copy the nodes.
1508   node_map[src.source_node()->id()] = dst->source_node();
1509   node_map[src.sink_node()->id()] = dst->sink_node();
1510   for (Node* n : src.op_nodes()) {
1511     node_map[n->id()] = dst->CopyNode(n);
1512   }
1513 
1514   // Copy the edges.
1515   for (const Edge* e : src.edges()) {
1516     Node* src_copy = node_map[e->src()->id()];
1517     Node* dst_copy = node_map[e->dst()->id()];
1518     dst->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
1519   }
1520 
1521   // Save inputs in copied graph.
1522   CHECK_EQ(fbody_->arg_types.size(), fbody_->arg_nodes.size());
1523   gbody->arg_types = fbody_->arg_types;
1524   for (std::size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
1525     gbody->arg_nodes.push_back(node_map[fbody_->arg_nodes[i]->id()]);
1526   }
1527 
1528   // Save outputs in copied graph.
1529   CHECK_EQ(fbody_->ret_types.size(), fbody_->ret_nodes.size());
1530   gbody->ret_types = fbody_->ret_types;
1531   for (std::size_t i = 0; i < fbody_->ret_nodes.size(); ++i) {
1532     gbody->ret_nodes.push_back(node_map[fbody_->ret_nodes[i]->id()]);
1533   }
1534 }
1535 
// Builds the gradient function body: copies the original body, appends one
// gradient-input arg per original return value, calls AddSymbolicGradients
// to grow the graph with gradient ops, then replaces the return nodes with
// one per original argument's gradient.
std::unique_ptr<FunctionBody> SymbolicGradientHelper::Compute() {
  // Ownership of gbody is transferred to the unique_ptr returned below.
  FunctionBody* gbody = new FunctionBody;
  Copy(gbody);  // copy fbody_ into gbody.

  Graph* g = gbody->graph;

  const int num_y = static_cast<int>(gbody->ret_nodes.size());

  // Populate 'y_node_outputs_' with node function body outputs.
  // Populate 'y_grad_nodes' with initial gradient nodes for each return node
  // of the original function body (these will be 'arg' nodes in the function
  // gradient body).
  std::vector<NodeOut> y_node_outputs;
  y_node_outputs.reserve(num_y);
  std::vector<NodeOut> y_grad_node_outputs;
  y_grad_node_outputs.reserve(num_y);
  for (int i = 0; i < num_y; ++i) {
    Node* y = gbody->ret_nodes[i];
    y_node_outputs.push_back({y, 0});
    DCHECK_EQ(y->type_string(), kRetOp);
    const DataType dtype = y->input_type(0);
    // Each gradient input dy_i is added as a new arg appended after the
    // original arguments (index = current arg count).
    const int index = static_cast<int>(gbody->arg_nodes.size());
    Node* dy = AddArg(g, dtype, index);
    gbody->arg_types.push_back(dtype);
    gbody->arg_nodes.push_back(dy);
    y_grad_node_outputs.push_back({dy, 0});
  }

  // Populate 'x_nodes' with function args (excluding 'y_grad_node_outputs').
  // Note: iterates over fbody_'s (original) arg count, so the dy args just
  // appended to gbody->arg_nodes are not included.
  const size_t num_x = fbody_->arg_nodes.size();
  std::vector<NodeOut> x_node_outputs;
  x_node_outputs.reserve(num_x);
  for (size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
    x_node_outputs.push_back({gbody->arg_nodes[i], 0});
  }

  // Call AddSymbolicGradients which will add nodes to graph 'g' that
  // compute the function gradient (adding an entry in 'x_grad_node_outputs'
  // for each node in 'x_node_outputs').
  std::vector<NodeOut> x_grad_node_outputs;
  TF_CHECK_OK(AddSymbolicGradients(y_node_outputs, x_node_outputs,
                                   y_grad_node_outputs, &x_grad_node_outputs,
                                   g));

  // Remove the old return nodes from the function body.
  for (Node* n : gbody->ret_nodes) {
    g->RemoveNode(n);
  }
  // The gradient body returns one value per original argument, typed like
  // that argument.
  gbody->ret_types = fbody_->arg_types;
  // TODO(apassos): use the right dtype for gradients of  resource variables
  for (int i = 0; i < gbody->ret_types.size(); ++i) {
    if (gbody->ret_types[i] == DT_RESOURCE) {
      gbody->ret_types[i] = DT_FLOAT;
    }
  }
  gbody->ret_nodes.clear();
  // Add new return nodes to the function gradient body for each node
  // in 'x_grad_nodes'.
  const int arg_types_size = static_cast<int>(fbody_->arg_types.size());
  for (int i = 0; i < arg_types_size; ++i) {
    Endpoint grad = {x_grad_node_outputs[i].node, x_grad_node_outputs[i].index};
    Node* ret = AddRet(g, grad, i);
    gbody->ret_nodes.push_back(ret);
  }

  return std::unique_ptr<FunctionBody>(gbody);
}
1603 
SymbolicGradient(const FunctionBody & f)1604 std::unique_ptr<FunctionBody> SymbolicGradient(const FunctionBody& f) {
1605   return SymbolicGradientHelper(f).Compute();
1606 }
1607 
1608 }  // end namespace tensorflow
1609