/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_
#define TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_

#include <vector>

// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/registration/registration.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/protobuf/config.pb.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
#endif  // IS_MOBILE_PLATFORM

namespace tensorflow {

class CancellationManager;
class CollectiveExecutor;
// Forward declaration (referenced by FunctionLibraryRuntime::Options below).
class CoordinationServiceAgent;
class DeviceSet;
class Graph;
class GraphDef;
class OpKernel;
class ProcessFunctionLibraryRuntime;
class ResourceMgr;
class Rendezvous;
class ScopedStepContainer;
class StepStatsCollectorInterface;
class Node;

// FunctionDefHelper::Create is a convenient helper to construct a
// FunctionDef proto.
// E.g.,
//   FunctionDef my_func = FunctionDefHelper::Create(
//     "my_func_name",
//     {"x:T", "y:T" /* one string per argument */},
//     {"z:T" /* one string per return value */},
//     {"T: {float, double}" /* one string per attribute */},
//     {
//        {{"o"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
//        /* one entry per function node */
//     },
//     /* Mapping between function returns and function node outputs. */
//     {{"z", "o:z"}});
//
// For the old Function::Node approach, use FunctionDefHelper::Define()
// E.g.,
//   FunctionDef my_func = FunctionDefHelper::Define(
//     "my_func_name",
//     {"x:T", "y:T" /* one string per argument */},
//     {"z:T" /* one string per return value */},
//     {"T: {float, double}" /* one string per attribute */},
//     {
//        {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
//        /* one entry per function node */
//     });
class FunctionDefHelper {
 public:
  // AttrValueWrapper has copy constructors for the type T so that
  // it's easy to construct a simple AttrValue proto.
  //
  // If T is a string type (const char*, string, or StringPiece), and
  // it starts with "$", we construct an AttrValue of "placeholder".
  //
  // E.g.,
  //   std::pair<string, AttrValueWrapper> x = {"T", "$T"};
  // is a named attr value placeholder.
  struct AttrValueWrapper {
    AttrValue proto;

    AttrValueWrapper() {}

    template <typename T>
    AttrValueWrapper(T val) {  // NOLINT(runtime/explicit)
      SetAttrValue(val, &proto);
    }

   private:
    void InitFromString(StringPiece val);
  };

  // Constructs an AttrValue.func given the "name" and "attrs".
  static AttrValueWrapper FunctionRef(
      const std::string& name,
      gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs);
  static AttrValueWrapper FunctionRef(const std::string& name) {
    return FunctionRef(name, {});
  }
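
  // A hedged usage sketch (illustrative; not part of the original header):
  // a FunctionRef is typically used as the value of a function-typed attr,
  // e.g. an "f" attr on a functional op. The op name "SomeFunctionalOp" and
  // the surrounding names are hypothetical placeholders.
  //
  //   Node call = {{"out"},
  //                "SomeFunctionalOp",  // hypothetical functional op
  //                {"in"},
  //                {{"f", FunctionDefHelper::FunctionRef("my_func_name",
  //                                                      {{"T", "$T"}})}}};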

  // Node is used to construct FunctionDef.Node using initialization
  // lists. E.g.,
  //  Node n = {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}};  // z = x * y
  //
  // If the op has no outputs, then `name` must be specified. E.g.,
  //  Node n = {{}, "AssignVariable", {"resource", "val"}, {{"dtype", DT_FLOAT}},
  //            {"update0"}, "CPU:0", "update1"};
  struct Node {
    // When constructing a NodeDef, the first entry in ret is used as
    // the node name, the remaining values are ignored.
    std::vector<string> ret;
    std::string op;
    std::vector<string> arg;
    std::vector<std::pair<string, AttrValueWrapper>> attr;
    std::vector<string> dep;
    std::string device;

    // Required if the op has zero outputs. Otherwise, ret[0] is used as the
    // node name if `name` is left empty.
    std::string name;

    std::string GetName() const {
      if (!name.empty()) return name;
      CHECK(!ret.empty());
      return ret[0];
    }

    NodeDef ToNodeDef() const;
  };
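
  // A hedged usage sketch (illustrative; not part of the original header):
  // build a Node with initialization lists and convert it to a NodeDef. The
  // names "prod", "a", and "b" are hypothetical.
  //
  //   FunctionDefHelper::Node n = {{"prod"}, "Mul", {"a", "b"},
  //                                {{"T", DT_FLOAT}}};
  //   NodeDef ndef = n.ToNodeDef();  // A node named "prod" running op "Mul".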

  // Creates a FunctionDef from the given parameters. Node inputs must use
  // function encoding (node_name:output_name[:output_index]).
  // - `ret_def` holds a mapping from the function output names from `out_def`
  //   to the node outputs from `node_def`.
  // - `control_ret_def` holds a mapping from the function control
  //   output names to the nodes from `node_def`.
  static FunctionDef Create(
      const std::string& function_name, gtl::ArraySlice<string> in_def,
      gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
      gtl::ArraySlice<Node> node_def,
      gtl::ArraySlice<std::pair<string, string>> ret_def,
      gtl::ArraySlice<std::pair<string, string>> control_ret_def);

  // Creates a FunctionDef from the given parameters. Node inputs must use
  // function encoding (node_name:output_name[:output_index]).
  // - `ret_def` holds a mapping from the function output names from `out_def`
  //   to the node outputs from `node_def`.
  static FunctionDef Create(const std::string& function_name,
                            gtl::ArraySlice<string> in_def,
                            gtl::ArraySlice<string> out_def,
                            gtl::ArraySlice<string> attr_def,
                            gtl::ArraySlice<Node> node_def,
                            gtl::ArraySlice<std::pair<string, string>> ret_def);

  // TODO(josh11b): Get rid of these and transition to the one above.
  static FunctionDef Define(const std::string& function_name,
                            gtl::ArraySlice<string> arg_def,
                            gtl::ArraySlice<string> ret_def,
                            gtl::ArraySlice<string> attr_def,
                            gtl::ArraySlice<Node> node_def);

  // Defines an anonymous function. I.e., its name is not relevant.
  static FunctionDef Define(gtl::ArraySlice<string> arg_def,
                            gtl::ArraySlice<string> ret_def,
                            gtl::ArraySlice<string> attr_def,
                            gtl::ArraySlice<Node> node_def);

  // Helpers to construct a constant scalar.
  template <typename T>
  static Node Const(const std::string& name, const T& val) {
    Node n = {{name}, "Const"};
    const DataType dtype = DataTypeToEnum<T>::value;
    n.attr.push_back({"dtype", dtype});
    Tensor t(dtype, TensorShape({}));
    t.scalar<T>()() = val;
    n.attr.push_back({"value", t});
    return n;
  }

  template <typename T>
  static Node Const(const std::string& name, gtl::ArraySlice<T> vals) {
    Node n = {{name}, "Const"};
    const DataType dtype = DataTypeToEnum<T>::value;
    n.attr.push_back({"dtype", dtype});
    int64_t num = vals.size();
    Tensor t(dtype, TensorShape({num}));
    for (size_t i = 0; i < vals.size(); ++i) {
      t.flat<T>()(i) = vals[i];
    }
    n.attr.push_back({"value", t});
    return n;
  }
};
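
// A hedged end-to-end sketch (illustrative; not part of the original header):
// define z = x * 2 with Create() and the Const() helper. The function name
// "TimesTwo" and node names are hypothetical; "two:output:0" and "z:z:0" use
// the function encoding described above (Const's output is named "output",
// Mul's output is named "z").
//
//   FunctionDef times_two = FunctionDefHelper::Create(
//       "TimesTwo",
//       {"x: float"},                  // one string per argument
//       {"z: float"},                  // one string per return value
//       {},                            // no attrs
//       {
//           FunctionDefHelper::Const("two", 2.0f),
//           {{"z"}, "Mul", {"x", "two:output:0"}, {{"T", DT_FLOAT}}},
//       },
//       {{"z", "z:z:0"}});             // function output -> node output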

template <>
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(const char* val) {
  InitFromString(val);
}

template <>
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(
    const std::string& val) {
  InitFromString(val);
}

template <>
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) {
  InitFromString(val);
}

// Instantiate a function.
//
// "fdef" encodes a TF function with some attrs in fdef.signature.attr
// containing placeholders.  InstantiateFunction binds these
// placeholders and produces an instantiated function encoded in
// "result.gdef". The value to substitute a placeholder is given by
// "attr_values", which is a map from a placeholder name to an attr
// value.
//
// InstantiateFunction calls "get_function" to find signatures of other
// functions and primitive ops.

// GetFunctionSignature(func name, opdef) returns OK if the func name is found
// and opdef is filled with a pointer to the corresponding signature
// (an OpDef proto). Otherwise, returns an error.
typedef std::function<Status(const string&, const OpDef**)>
    GetFunctionSignature;

struct InstantiationResult {
  DataTypeVector arg_types;
  DataTypeVector ret_types;
  std::vector<NodeDef> nodes;
};
Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
                           GetFunctionSignature get_function,
                           InstantiationResult* result);
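
// A minimal sketch (illustrative; assumes `fdef` only references primitive
// ops): resolve signatures against the global op registry and instantiate
// with the placeholder T bound to DT_FLOAT.
//
//   GetFunctionSignature get_sig = [](const string& op, const OpDef** sig) {
//     return OpRegistry::Global()->LookUpOpDef(op, sig);
//   };
//   AttrValueMap attr_values;
//   SetAttrValue(DT_FLOAT, &attr_values["T"]);
//   InstantiationResult result;
//   TF_RETURN_IF_ERROR(InstantiateFunction(fdef, AttrSlice(&attr_values),
//                                          get_sig, &result));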

// Returns a debug string for a function definition.
//
// The returned text is multiple-line. It is intended to be
// human-readable rather than being friendly to parsers. It is _NOT_
// intended to be the canonical string representation of "func_def".
// Particularly, it may not include all information presented in
// "func_def" (e.g., comments, description of the function arguments,
// etc.)
std::string DebugString(const FunctionDef& func_def);
std::string DebugString(const GraphDef& instantiated_func_def);
std::string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes);

// Returns a debug string for a top level graph (the main program and
// its supporting functions defined in its library).
std::string DebugStringWhole(const GraphDef& gdef);

// Returns true if f1 == f2. Compares all fields, including descriptions. Order
// of NodeDefs doesn't matter.
bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2);

// Returns a hash of `fdef` that is consistent with the FunctionDefsEqual
// method. In other words, if two fdefs compare equal, their hash values will
// be the same.
uint64 FunctionDefHash(const FunctionDef& fdef);

class CallFrameInterface {
 public:
  virtual ~CallFrameInterface() {}

  virtual size_t num_args() const = 0;
  virtual size_t num_retvals() const = 0;

  virtual Status GetArg(int index, const Tensor** val) = 0;

  // Optimized implementation of `GetArg()` that allows the caller to take
  // ownership of the tensor. This method may only be called once per
  // value of `index` and `CallFrameInterface` instance.
  //
  // REQUIRES: `this->CanConsumeArg(index) == true`.
  virtual void ConsumeArg(int index, Tensor* val) {
    LOG(ERROR) << "This `CallFrameInterface` implementation does not support "
                  "consuming arguments.";
  }
  virtual bool CanConsumeArg(int index) const { return false; }

  virtual Status SetRetval(int index, const Tensor& val) = 0;
};

// Represents a function call frame. I.e., the data structure used to
// pass arguments to a function and retrieve its results.
//
// Runtime must arrange accesses to one FunctionCallFrame s.t.
//   1. SetArgs() happens before any GetArg();
//   2. GetRetvals happens after all SetRetval();
class FunctionCallFrame : public CallFrameInterface {
 public:
  FunctionCallFrame(DataTypeSlice arg_types, DataTypeSlice ret_types);
  ~FunctionCallFrame() override;

  // Caller methods.
  Status SetArgs(gtl::ArraySlice<Tensor> args);
  Status GetRetvals(std::vector<Tensor>* rets) const;

  // Moves the return values from the frame to rets. If allow_dead_tensors is
  // false it will fail if any of the retvals do not have a value.
  Status ConsumeRetvals(std::vector<Tensor>* rets, bool allow_dead_tensors);

  size_t num_args() const override { return arg_types_.size(); }
  size_t num_retvals() const override { return ret_types_.size(); }

  // Callee methods.
  Status GetArg(int index, const Tensor** val) override;
  Status SetRetval(int index, const Tensor& val) override;

 private:
  DataTypeVector arg_types_;
  DataTypeVector ret_types_;
  gtl::InlinedVector<Tensor, 4> args_;
  struct Retval {
    bool has_val = false;
    Tensor val;
  };
  gtl::InlinedVector<Retval, 4> rets_;

  TF_DISALLOW_COPY_AND_ASSIGN(FunctionCallFrame);
};
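
// A hedged caller-side sketch (illustrative; not part of the original
// header): `x` and `y` are hypothetical float Tensors, and the function is
// assumed to take two floats and return one.
//
//   FunctionCallFrame frame({DT_FLOAT, DT_FLOAT}, {DT_FLOAT});
//   TF_RETURN_IF_ERROR(frame.SetArgs({x, y}));
//   // ... the runtime executes the function body, reading arguments with
//   // GetArg()/ConsumeArg() and writing results with SetRetval() ...
//   std::vector<Tensor> rets;
//   TF_RETURN_IF_ERROR(frame.GetRetvals(&rets));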

// Language agnostic stack traces.
class AbstractStackTrace {
 public:
  struct TracePrintingOptions {
    // Show inline the contents of each stack line.
    bool show_line_contents = false;

    // Drop the common largest prefix of all filenames in stack frames.
    bool filter_common_prefix = false;

    // Do not show internal frames.
    bool drop_internal_frames = false;
  };

  virtual ~AbstractStackTrace() {}

  // The returned span is alive as long as the AbstractStackTrace is alive.
  virtual absl::Span<StackFrame const> ToFrames() const = 0;

  // Returns the last stack frame from user code, attempting to ignore the
  // framework code. Returns an empty frame if no such stack frame was found.
  virtual StackFrame LastUserFrame() const = 0;
  virtual std::string ToString(const TracePrintingOptions& opts) const = 0;
};

using StackTracesMap =
    std::unordered_map<std::string,
                       std::shared_ptr<tensorflow::AbstractStackTrace>>;

// Helper to maintain a map between function names in a given
// FunctionDefLibrary and function definitions.
//
// This class is thread-safe.
class FunctionLibraryDefinition : public OpRegistryInterface {
 public:
  // Ops created for function arguments bear the name given by `kArgOp`; those
  // created for return values bear the name given by `kRetOp`.
  static constexpr const char* const kArgOp = "_Arg";
  static constexpr const char* const kDeviceArgOp = "_DeviceArg";
  static constexpr const char* const kRetOp = "_Retval";
  static constexpr const char* const kDeviceRetOp = "_DeviceRetval";
  static constexpr const char* const kIntsOnDeviceAttr =
      "experimental_ints_on_device";
  static constexpr const char* const kSharedRendezvousAttr =
      "shared_rendezvous";

  static constexpr const char* const kGradientOp = "SymbolicGradient";
  static constexpr const char* const kFuncAttr = "f";

  // Note: This constructor grabs `lib_def`'s lock in shared mode.
  FunctionLibraryDefinition(const FunctionLibraryDefinition& lib_def);
  FunctionLibraryDefinition(const OpRegistryInterface* default_registry,
                            const FunctionDefLibrary& lib_def = {});
  ~FunctionLibraryDefinition() override;

  FunctionLibraryDefinition& operator=(const FunctionLibraryDefinition&) =
      delete;

  // Returns True if the library contains `func`, False otherwise.
  bool Contains(const std::string& func) const;

  // Returns nullptr if "func" is not defined in "lib_def". Otherwise,
  // returns its definition proto.
  //
  // NB: This function returns a borrowed pointer, which can be invalidated by
  // a subsequent call to `ReplaceFunction()` with the given name.
  const FunctionDef* Find(const std::string& func) const TF_LOCKS_EXCLUDED(mu_);

  // Adds function definition 'fdef' to this function library.
  // Returns status 'ok' on success, or error otherwise. This is a no-op if
  // 'fdef' already exists in this function library.
  // If 'fdef' is successfully added to the library, it will be accessible
  // from 'LookUp' and included in the proto returned by 'ToProto'.
  // This operation is atomic.
  //
  // Associates `graph` with a function `func_name`. Lifetime assumption:
  // `graph` has to outlive all instantiated graphs.
  Status AddFunctionDef(const FunctionDef& fdef,
                        const StackTracesMap& stack_traces = {})
      TF_LOCKS_EXCLUDED(mu_);

  // Adds gradient definition 'grad' to this function library.
  // This is a no-op if 'grad' already exists in this function library.
  // If 'grad' is successfully added, it will be accessible via 'FindGradient'
  // and included in the proto returned by 'ToProto'.
  // This operation is atomic.
  Status AddGradientDef(const GradientDef& grad) TF_LOCKS_EXCLUDED(mu_);

  // Replaces the function corresponding to `func` with `fdef`. Returns
  // a non-OK status if "func" was not found in the library, OK otherwise.
  // Please be careful when replacing function: make sure all previous pointers
  // returned by `Find()` are no longer in use.
  Status ReplaceFunction(const std::string& func, const FunctionDef& fdef,
                         const StackTracesMap& stack_traces = {})
      TF_LOCKS_EXCLUDED(mu_);

  // Replaces the gradient corresponding to `grad.function_name()`. Returns
  // a non-OK status if "grad.function_name()" was not found in the library, OK
  // otherwise.
  Status ReplaceGradient(const GradientDef& grad) TF_LOCKS_EXCLUDED(mu_);

  // Removes the function corresponding to 'func'. Returns a non-OK status if
  // 'func' was not found in the library, OK otherwise.
  // Please be careful when removing function: make sure there are no other
  // nodes using the function, and all previous pointers returned by `Find()`
  // are no longer in use.
  Status RemoveFunction(const std::string& func) TF_LOCKS_EXCLUDED(mu_);

  // Removes all the functions and gradient functions.
  void Clear() TF_LOCKS_EXCLUDED(mu_);

  // Adds the functions and gradients in 'other' to this function library.
  // Duplicate functions and gradients are ignored.
  // This operation is atomic.
  Status AddLibrary(const FunctionLibraryDefinition& other)
      TF_LOCKS_EXCLUDED(mu_);

  // Adds the functions and gradients in 'lib_def' to this function library.
  // Duplicate functions and gradients are ignored.
  // This operation is atomic.
  Status AddLibrary(const FunctionDefLibrary& lib_def) TF_LOCKS_EXCLUDED(mu_);

  // If the gradient function for 'func' is specified explicitly in
  // the library, returns the gradient function name.  Otherwise,
  // returns an empty string.
  std::string FindGradient(const std::string& func) const
      TF_LOCKS_EXCLUDED(mu_);

  // OpRegistryInterface method. Useful for constructing a Graph.
  //
  // If "op" is defined in the library, returns its signature.
  // Otherwise, assume "op" is a primitive op and returns its op
  // signature and shape inference function.
  //
  // NB: This function outputs a borrowed pointer, which can be invalidated by
  // a subsequent call to `ReplaceFunction()` with the given name.
  Status LookUp(const std::string& op_type_name,
                const OpRegistrationData** op_reg_data) const override
      TF_LOCKS_EXCLUDED(mu_);

  // Generates new function name with the specified prefix that is unique
  // across this library.
  std::string UniqueFunctionName(StringPiece prefix) const
      TF_LOCKS_EXCLUDED(mu_);

  // Given a node def 'ndef', inspects attributes of the callee
  // function to derive the attribute 'value' for 'attr'. Returns OK
  // iff the attribute is given by the function's definition.
  // TODO(irving): Remove; keep only the const Node& version.
  template <typename T>
  Status GetAttr(const NodeDef& ndef, const std::string& attr, T* value) const;

  // Given a node, inspects attributes of the callee function to derive the
  // attribute 'value' for 'attr'. Returns OK iff the attribute is given by the
  // function's definition.
  template <typename T>
  Status GetAttr(const Node& node, const std::string& attr, T* value) const;

  // Returns a proto representation of the state of this function library.
  FunctionDefLibrary ToProto() const TF_LOCKS_EXCLUDED(mu_);

  size_t num_functions() const {
    tf_shared_lock l(mu_);
    return function_defs_.size();
  }

  // Returns all the function names in the FunctionLibraryDefinition.
  std::vector<string> ListFunctionNames() const TF_LOCKS_EXCLUDED(mu_);

  const OpRegistryInterface* default_registry() const {
    return default_registry_;
  }
  void set_default_registry(const OpRegistryInterface* registry) {
    default_registry_ = registry;
  }

  // Returns a copy of `*this` with only the subset of functions that are
  // reachable from the nodes of `graph` or `func`.
  FunctionLibraryDefinition ReachableDefinitions(const GraphDef& graph) const;
  FunctionLibraryDefinition ReachableDefinitions(const FunctionDef& func) const;

  // Copies the function named `func` from `other` to this
  // FunctionLibraryDefinition.
  // REQUIRES: `this->default_registry() == other.default_registry()`.
  // Returns OK on success, or error otherwise. This is a no-op if a function
  // name `func` already exists in this function library, and has the same
  // implementation as in `other`. If the implementations conflict, an invalid
  // argument error is returned.
  Status CopyFunctionDefFrom(const std::string& func,
                             const FunctionLibraryDefinition& other)
      TF_LOCKS_EXCLUDED(mu_);

  // Returns the stack traces recorded for the given function, or a reference
  // to an empty map if none were found.
  const StackTracesMap& GetStackTraces(const std::string& func_name) const {
    tf_shared_lock l(mu_);
    std::shared_ptr<FunctionDefAndOpRegistration> entry = FindHelper(func_name);
    if (entry) {
      return entry->stack_traces;
    }
    static const auto* empty_map = new StackTracesMap;
    return *empty_map;
  }

 private:
  // Shape inference for functions is handled separately by ShapeRefiner.

  struct FunctionDefAndOpRegistration {
    explicit FunctionDefAndOpRegistration(
        const FunctionDef& fdef_in, const StackTracesMap& stack_traces = {});

    const FunctionDef fdef;
    const OpRegistrationData op_registration_data;
    const StackTracesMap stack_traces;
  };

  std::shared_ptr<FunctionDefAndOpRegistration> FindHelper(
      const string& func) const TF_SHARED_LOCKS_REQUIRED(mu_);
  std::string FindGradientHelper(const std::string& func) const
      TF_SHARED_LOCKS_REQUIRED(mu_);

  Status AddHelper(std::shared_ptr<FunctionDefAndOpRegistration> registration,
                   bool* added) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Same as AddFunctionDef/AddGradientDef except these methods set
  // `added` to true if the `fdef`/`grad` were actually added to this.
  Status AddFunctionDefHelper(const FunctionDef& fdef,
                              const StackTracesMap& stack_traces, bool* added)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  Status AddGradientDefHelper(const GradientDef& grad, bool* added)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Helper function for GetAttr. Returns the FunctionDef* to get the
  // attr from.
  const FunctionDef* GetAttrImpl(const NodeDef& ndef) const
      TF_LOCKS_EXCLUDED(mu_);

  // Remove all functions in `funcs` and all gradients of functions in
  // `funcs_with_grads` from this library.
  Status Remove(const std::vector<string>& funcs,
                const std::vector<string>& funcs_with_grads)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Remove `func` from the library. Returns non-OK Status unless `func` is in
  // the library. This should only be called when there is a guarantee that the
  // function being removed hasn't been retrieved with `Find`.
  Status RemoveFunctionHelper(const std::string& func)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Remove gradient of function `func` from the library. Returns non-OK Status
  // unless `func` has a gradient.
  Status RemoveGradient(const std::string& func)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  mutable mutex mu_;
  const OpRegistryInterface* default_registry_;
  gtl::FlatMap<string, std::shared_ptr<FunctionDefAndOpRegistration>>
      function_defs_ TF_GUARDED_BY(mu_);
  gtl::FlatMap<string, string> func_grad_ TF_GUARDED_BY(mu_);
};
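
// A hedged usage sketch (illustrative; not part of the original header):
// create a library backed by the global op registry, add a FunctionDef built
// with FunctionDefHelper, and look it up. "my_func_name" echoes the
// FunctionDefHelper example above; `my_func` is a hypothetical FunctionDef.
//
//   FunctionLibraryDefinition lib_def(OpRegistry::Global(),
//                                     FunctionDefLibrary());
//   TF_RETURN_IF_ERROR(lib_def.AddFunctionDef(my_func));
//   const FunctionDef* fdef = lib_def.Find("my_func_name");  // borrowed
//   DCHECK(fdef != nullptr);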

// Forward declare. Defined in common_runtime/function.h
struct FunctionBody;

// Forward declare. Defined in common_runtime/device.h
class Device;
// Forward declare. Defined in common_runtime/device_mgr.h
class DeviceMgr;

// Index of an _Arg node.
struct FunctionArgIndex {
  explicit FunctionArgIndex(const int index) : index(index) {}
  FunctionArgIndex(const int index, const int sub_index)
      : index(index), sub_index(sub_index) {}

  // The value of the attribute "Index" of the _Arg node.
  int index;
  // Set only when the _Arg node represents multiple arguments (e.g. an _Arg
  // node is replicated to multiple devices/subgraphs). Use sub-index to
  // distinguish arguments with the same index.
  int sub_index = -1;
};

class FunctionLibraryRuntime {
 public:
  virtual ~FunctionLibraryRuntime() {}

  // Instantiate a function with the given "attrs".
  //
  // Returns OK and fills in "handle" if the instantiation succeeds.
  // Otherwise returns an error and "handle" is undefined.
  struct InstantiateOptions {
    // The canonical device name of the device on which the function
    // should be instantiated. If empty, the function will be
    // instantiated on the local device.
    std::string target;

    // Should the function be instantiated as a multi-device function?
    bool is_multi_device_function = false;

    // If true, graph passes will be skipped when instantiating the function
    // since they have already run on the main function side.
    bool is_component_function = false;

    // For multi-device functions, a vector of canonical device names for
    // function's inputs. The device of resource inputs must be the device
    // backing the resource, not the CPU device backing the resource handle.
    // Must have the same length as number of inputs to the function.
    std::vector<string> input_devices;

    // For multi-device functions, a vector of canonical device names for the
    // function's outputs.
    //
    // (a) If specified (must have the same length as the number of outputs):
    //
    // The specified devices will be assigned to the Retval nodes inserted into
    // the function body graph in place of function outputs. An output device
    // may be given as the empty string, in which case the Retval device
    // assignment will be inferred later, when the function graph is placed
    // before partitioning (this is required for resource outputs). The Placer
    // will respect colocation constraints.
    //
    // (b) If not specified:
    //
    // The function runtime will infer the Retval device by following input
    // edges until it reaches a node with a device specification. This device
    // specification must identify a unique device; a general specification
    // like "job:foo" matching multiple devices will result in an error.
    //
    // IMPORTANT: Resource outputs
    //
    // Multi-device functions might return resources on a device different from
    // the function call device. If an output device is not specified for a
    // resource output, and the node producing that resource is a function
    // call, the runtime will leave the device specification empty and rely on
    // the Placer to infer the correct device.
    std::vector<string> output_devices;

    // If set, it indicates the original output indices of a component
    // function.
    absl::optional<std::vector<int>> ret_indices = absl::nullopt;

    // Maps from a CompositeDevice name to a list of underlying physical
    // devices.
    absl::flat_hash_map<string, const std::vector<string>*> composite_devices;

    // This interface is EXPERIMENTAL and subject to change.
    //
    // For multi-device functions, a mapping from _Arg node index to type and
    // shape for input resources.
    // REQUIRES: if input_resource_dtypes_and_shapes.count(i) > 0 then i-th
    // argument type must be DT_RESOURCE.
    std::unordered_map<int, DtypeAndPartialTensorShape>
        input_resource_dtypes_and_shapes;

    // This interface is EXPERIMENTAL and subject to change.
    //
    // If non-null, the runtime will use `lib_def` to resolve function(s) named
    // in `function_name` and `attrs`. Otherwise, the runtime will use its
    // internal library.
    //
    // NOTE(mrry): If provided, all functions defined in `lib_def` must be
    // self-contained, and cannot refer to functions defined in other libraries.
    const FunctionLibraryDefinition* lib_def = nullptr;

    // This interface is EXPERIMENTAL and subject to change.
    //
    // If non-empty, the runtime will use `state_handle` to identify
    // cached state related to the instantiated function. Two functions
    // of the same name and attrs, instantiated with the same
    // `state_handle`, will share the same handle and the same
    // state (in stateful kernels); two functions with different
    // values for `state_handle` will have independent state.
    std::string state_handle;

    // This interface is EXPERIMENTAL and subject to change.
    //
    // Instantiates the function using an executor of the given type. If empty,
    // the default TensorFlow executor will be used.
    std::string executor_type;

    // If true, the runtime will attempt to create kernels for the function at
    // instantiation time, rather than on the first run. This can be used to
    // surface errors earlier.
    bool create_kernels_eagerly = false;

    // This interface is EXPERIMENTAL and subject to change.
    //
    // Instantiates the function with the provided config_proto.
    ConfigProto config_proto;

    // If provided, this optimization function will be invoked before
    // the placer for multi-device functions.
    std::function<Status(std::vector<string> /*ret_node_names*/,
                         std::vector<string> /*keep_node_names*/,
                         FunctionLibraryDefinition*, const DeviceSet&,
                         Device* /*cpu_device*/, std::unique_ptr<Graph>*)>
        optimize_graph_fn;

    // If set, partitioned functions will be added to `graph_collector`.
    // `graph_collector` must be alive during the call to Instantiate.
    GraphCollector* graph_collector = nullptr;

    // Indicates whether the multi-device function backend should default the
    // placement of ops without a requested device to `target`.
    bool default_device_to_target = true;

    // If true, the optimized Graph will be stored so that
    // `FunctionLibraryRuntime::DebugString(handle)` contains the optimized
    // Graph. Otherwise, the unoptimized function Graph will be returned.
    bool include_optimized_graph_in_debug_string = false;

    // If true, the function library runtime caches the function instantiation.
    bool use_function_cache = false;
  };
  typedef uint64 Handle;
  virtual Status Instantiate(const std::string& function_name, AttrSlice attrs,
                             const InstantiateOptions& options,
                             Handle* handle) = 0;
  Status Instantiate(const std::string& function_name, AttrSlice attrs,
                     Handle* handle) {
    auto opts = absl::make_unique<InstantiateOptions>();
    return Instantiate(function_name, attrs, *opts, handle);
  }
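
  // A hedged sketch (illustrative; not part of the original header):
  // instantiate "my_func_name" on a specific device and surface kernel
  // creation errors eagerly. `flr` is a hypothetical FunctionLibraryRuntime*.
  //
  //   FunctionLibraryRuntime::InstantiateOptions opts;
  //   opts.target = "/job:localhost/replica:0/task:0/device:CPU:0";
  //   opts.create_kernels_eagerly = true;
  //   FunctionLibraryRuntime::Handle handle;
  //   TF_RETURN_IF_ERROR(
  //       flr->Instantiate("my_func_name", AttrSlice(), opts, &handle));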

  // Releases state associated with the handle.
  virtual Status ReleaseHandle(Handle handle) = 0;

  // Returns the function body for the instantiated function given its
  // handle 'h'. Returns nullptr if "h" is not found.
  //
  // *this keeps the ownership of the returned object, which remains alive
  // as long as *this.
  virtual const FunctionBody* GetFunctionBody(Handle h) = 0;

  // Returns the return types for the function identified by handle `h`.
  virtual Status GetRetTypes(Handle h, DataTypeVector* ret_types) = 0;

  // Asynchronously invokes the instantiated function identified by
  // "handle".
  //
  // If function execution succeeds, "done" is called with OK and
  // "*rets" is filled with the function's return values. Otherwise,
  // "done" is called with an error status.
  //
  // Does not take ownership of "rets".
  // In the cross-process scenario, runner isn't used for making the Async
  // RPC calls.
  struct Options {
    Options() {}
    explicit Options(const int64_t step_id) : step_id(step_id) {}
    // Choose a step ID that is guaranteed not to clash with any
    // Session-generated step ID. DirectSession only generates
    // non-negative step IDs (contiguous, starting from 0), and
    // MasterSession generates 56-bit random step IDs whose MSB is
    // always 0, so a negative random step ID should suffice.
    const int64 step_id = -std::abs(static_cast<int64>(random::New64()));

    // op_id of the function running in eager mode. Set when we want to copy
    // remote outputs lazily. All components of a remote multi-device function
    // should use the same op_id, in order to correctly map remote output
    // tensors to the remote TensorHandles in the default device.
    absl::optional<int64> op_id = absl::nullopt;

    RendezvousInterface* rendezvous = nullptr;
    CancellationManager* cancellation_manager = nullptr;
    CollectiveExecutor* collective_executor = nullptr;
    ScopedStepContainer* step_container = nullptr;
    StepStatsCollectorInterface* stats_collector = nullptr;
    CoordinationServiceAgent* coordination_service_agent = nullptr;

    std::function<void(std::function<void()>)>* runner = nullptr;

    // Parameters for remote function execution.
    bool remote_execution = false;
    std::string source_device = "";  // Fully specified device name.

    // Allocator attributes specifying where the args are / rets should be put.
    // These should either be {} or match the length of args / retvals. If {},
    // the default allocator attributes will be assumed for all args / retvals.
    std::vector<AllocatorAttributes> args_alloc_attrs;
    std::vector<AllocatorAttributes> rets_alloc_attrs;

    // If true, we create a new IntraProcessRendezvous, else use the existing
    // one.
    bool create_rendezvous = false;

    // If True, allow returning dead tensors.
    bool allow_dead_tensors = false;

    // If True, hint that all kernels should be treated as "inexpensive", and
    // hence executed on the scheduling thread.
    bool run_all_kernels_inline = false;

    // If not null, use this thread pool for intra op scheduling.
    thread::ThreadPoolInterface* user_intra_op_threadpool = nullptr;

    // Returns a human readable representation of this.
    std::string DebugString() const;
  };
  typedef std::function<void(const Status&)> DoneCallback;
  virtual void Run(const Options& opts, Handle handle,
                   gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                   DoneCallback done) = 0;
  virtual void Run(const Options& opts, Handle handle,
                   CallFrameInterface* call_frame, DoneCallback done) = 0;

  virtual Status RunSync(Options opts, Handle handle,
                         gtl::ArraySlice<Tensor> args,
                         std::vector<Tensor>* rets) = 0;
  virtual Status RunSync(Options opts, Handle handle,
                         CallFrameInterface* call_frame) = 0;

  // Creates a "kernel" for the given NodeProperties "props".
  //
  // If succeeds, returns OK and the caller takes the ownership of the
  // returned "*kernel". Otherwise, returns an error.
  virtual Status CreateKernel(
      const std::shared_ptr<const NodeProperties>& props,
      OpKernel** kernel) = 0;

  // Returns true iff the function named `function_name` is stateful.
  //
  // NOTE(mrry): This method assumes that the runtime is associated with a
  // default function library, and looks up `function_name` in that library.
  // It does not support overriding the function library.
  virtual bool IsStateful(const std::string& function_name) const = 0;

  // Returns the device on which the function executes.
  virtual Device* device() = 0;
  virtual const Device* device() const = 0;

  // Returns the default runner in which the ops should be launched. If the
  // device on which the function executes has a private thread pool, returns
  // a runner on the device-local thread pool.
  virtual std::function<void(std::function<void()>)>* runner() = 0;

  // Get the DeviceMgr from which the device was obtained.
  virtual const DeviceMgr* device_mgr() const = 0;

  // Returns the function library definition that backs this runtime.
  //
  // NOTE(mrry): The returned library definition is the default function
  // library for this runtime. The caller may override the function library
  // used by the runtime to instantiate functions, which will not be reflected
  // in the return value of this function.
  virtual const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const = 0;

  // Returns the environment on which the function executes.
  virtual Env* env() = 0;

  // Returns the ConfigProto passed to the session used to create the function.
  virtual const ConfigProto* const config_proto() = 0;

  // Returns a debug string showing the definition of the function of
  // 'handle'.
  virtual std::string DebugString(Handle handle) = 0;

  // Returns the graph version number.
  virtual int graph_def_version() const = 0;

  typedef uint64 LocalHandle;

  // Creates a copy of ProcessFunctionLibraryRuntime (transferring ownership to
  // the caller), FunctionLibraryRuntime (owned by the returned
  // ProcessFunctionLibraryRuntime), FunctionLibraryDefinition (transferring
  // ownership to the caller). Note that both the ProcessFunctionLibraryRuntime
  // and FunctionLibraryRuntime borrow a pointer to the
  // FunctionLibraryDefinition and so the FunctionLibraryDefinition should
  // outlive both.
  //
  // The `skip_flib_def` argument controls whether the method should clone the
  // FunctionLibraryDefinition (default behavior) or return an empty function
  // library. The latter is used by tf.data, which manages
  // FunctionLibraryDefinitions for its functions independently (and passes
  // these into the FunctionLibraryRuntime through an overlay), to avoid linear
  // runtime w.r.t. the number of functions in the current function library.
  virtual Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
                       std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
                       FunctionLibraryRuntime** out_flr,
                       bool skip_flib_def = false) = 0;

  // Returns the name of the executor class (in the sense of
  // `ExecutorFactory::GetFactory()`) that will be used based on the given
  // dynamic `options` and static `attrs`. If none is specified, this method
  // will return an empty string, which leaves the decision up to the runtime.
  static std::string ExecutorType(const InstantiateOptions& options,
                                  AttrSlice attrs);
};
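
// A hedged end-to-end sketch (illustrative; not part of the original header):
// run a previously instantiated function synchronously, then release it.
// `flr`, `handle`, and the argument tensor `arg` are hypothetical and follow
// on from the Instantiate() sketch above.
//
//   FunctionLibraryRuntime::Options run_opts;
//   run_opts.create_rendezvous = true;
//   std::vector<Tensor> rets;
//   TF_RETURN_IF_ERROR(flr->RunSync(run_opts, handle, {arg}, &rets));
//   TF_RETURN_IF_ERROR(flr->ReleaseHandle(handle));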

// Returns the device of the `arg_index`-th function input. Updates
// `composite_devices` if the input device is a composite device.
std::string GetFunctionResourceInputDevice(
    const Tensor& input, const int arg_index, const FunctionDef& function_def,
    absl::flat_hash_map<string, std::vector<string>>* composite_devices);

// Returns a canonicalized string for the instantiation of the function of the
// given "name", attributes "attrs", and "options".
//
// The returned string is guaranteed to be stable within one address space. But
// it may change as the implementation evolves. Therefore, it should not be
// persisted or compared across address spaces.
std::string Canonicalize(
    const std::string& funcname, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options);
std::string Canonicalize(const std::string& funcname, AttrSlice attrs);

const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1;

class CustomKernelCreator {
 public:
  virtual ~CustomKernelCreator() {}

  // Given a NodeDef 'node_def' and the function library runtime 'flr',
  // validate if the class supports creating such a kernel.
  virtual bool CanCreateKernel(
      const FunctionLibraryRuntime& flr,
      const std::shared_ptr<const NodeProperties>& props) const = 0;

  // Given a supported NodeDef, returns a kernel that computes the node.
  virtual Status CreateKernel(
      FunctionLibraryRuntime* flr,
      const std::shared_ptr<const NodeProperties>& props,
      std::unique_ptr<OpKernel>* kernel) const = 0;
};

typedef
#if !defined(IS_MOBILE_PLATFORM)
    absl::variant<Tensor, eager::RemoteTensorHandle*>
        FunctionArg;
#else
    absl::variant<Tensor>
        FunctionArg;
#endif

// Either a local tensor or the shape of a remote tensor.
typedef absl::variant<Tensor, TensorShape> FunctionRet;

// Used to instantiate and run functions in a distributed system.
class DistributedFunctionLibraryRuntime {
 public:
  virtual ~DistributedFunctionLibraryRuntime() {}

  // Instantiate a function on a remote target specified in `options.target`,
  // by sending the name and definition of the function to the remote worker.
  // The local `handle` is filled for the instantiated function data and can
  // be used for subsequent run function calls on the remote target.
  virtual void Instantiate(
      const std::string& function_name,
      const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
      const FunctionLibraryRuntime::InstantiateOptions& options,
      FunctionLibraryRuntime::LocalHandle* handle,
      FunctionLibraryRuntime::DoneCallback done) = 0;

  // Run an instantiated remote function (specified by `handle`) with a list of
  // input Tensors in `args` and get its output Tensors in `rets`. The input
  // tensor data will be sent with the function execution request, and must be
  // available on the current caller side.
  // opts.runner isn't used for execution.
  virtual void Run(const FunctionLibraryRuntime::Options& opts,
                   FunctionLibraryRuntime::LocalHandle handle,
                   gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                   FunctionLibraryRuntime::DoneCallback done) = 0;

  // Run an instantiated remote function (specified by `handle`) with a list of
  // input Tensors or RemoteTensorHandles as `args` and get its output Tensors
  // or TensorShapes in `rets`. When using RemoteTensorHandles as function
  // inputs or TensorShapes as outputs, the corresponding tensor data will be
  // resolved on the remote worker, so it is not required to be locally
  // available on the caller side. Using RemoteTensorHandle inputs is not
  // supported in TensorFlow v1 runtime.
  virtual void Run(const FunctionLibraryRuntime::Options& opts,
                   FunctionLibraryRuntime::LocalHandle handle,
                   gtl::ArraySlice<FunctionArg> args,
                   std::vector<FunctionRet>* rets,
                   FunctionLibraryRuntime::DoneCallback done) = 0;

  // Clean up a previously instantiated function on remote worker.
  virtual void CleanUp(uint64 step_id,
                       FunctionLibraryRuntime::LocalHandle handle,
                       FunctionLibraryRuntime::DoneCallback done) = 0;

  // DeviceMgr with *all* available devices (i.e., local and remote).
  virtual DeviceMgr* remote_device_mgr() const = 0;
};

// Extracts the actual type from "attr_values" based on its definition
// "arg_def".
//
// If "arg_def" is a N*T type, *is_type_list is set to false, and
// *dtypes is set to be a vector of size N and each element is T.
//
// If "arg_def" is a list(type), *is_type_list is set to true, and
// *dtypes is set to be a vector of types specified in attrs for
// arg_def.
//
// Otherwise (arg_def is a simple type T), *is_type_list is set to
// false, and *dtypes is set to a single element vector, whose only
// element is T.
Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
                  bool* is_type_list, DataTypeVector* dtypes);
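
// A worked example (illustrative; not part of the original header): for an
// arg_def declared as "x: N*T" with attrs {N = 3, T = DT_FLOAT}, ArgNumType
// sets *is_type_list to false and *dtypes to {DT_FLOAT, DT_FLOAT, DT_FLOAT};
// for "x: list(type)" with the corresponding type-list attr set to
// [DT_INT32, DT_FLOAT], it sets *is_type_list to true and *dtypes to
// {DT_INT32, DT_FLOAT}.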

// To register a gradient function for a builtin op, one should use
//   REGISTER_OP_GRADIENT(<op_name>, <c++ grad factory>);
//
// Typically, the c++ grad factory is a plain function that can be
// converted into ::tensorflow::gradient::Creator, which is
//   std::function<Status(const AttrSlice&, FunctionDef*)>.
//
// A ::tensorflow::gradient::Creator should populate the FunctionDef* with a
// definition of a brain function which computes the gradient for the
// <op_name> when the <op_name> is instantiated with the given attrs.
//
// E.g.,
//
// Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
//   bool transpose_a;
//   TF_RETURN_IF_ERROR(attrs.Get("transpose_a", &transpose_a));
//   bool transpose_b;
//   TF_RETURN_IF_ERROR(attrs.Get("transpose_b", &transpose_b));
//   DataType dtype;
//   TF_RETURN_IF_ERROR(attrs.Get("dtype", &dtype));
//   if (!transpose_a && !transpose_b) {
//     *g = FunctionDefHelper::Define(
//       "MatMulGrad",
//       {"x:T", "y:T", "dz:T"},     // Inputs to this function
//       {"dx:T", "dy:T"},           // Outputs from this function
//       {"T: {float, double}"},     // Attributes needed by this function
//       {
//         {{"x_t"}, "Transpose", {"x"}, {{"T", "$T"}}},
//         {{"y_t"}, "Transpose", {"y"}, {{"T", "$T"}}},
//         {{"dx"}, "MatMul", {"dz", "y_t"}, {{"T", "$T"}}},
//         {{"dy"}, "MatMul", {"x_t", "dz"}, {{"T", "$T"}}},
//       });
//   } else {
//     ... ...
//   }
//   return Status::OK();
// }
//
// NOTE: $T is substituted with the type variable "T" when the
// gradient function MatMulGrad is instantiated.
//
// TODO(zhifengc): Better documentation somewhere.

// Macros to define a gradient function factory for a primitive
// operation.
#define REGISTER_OP_GRADIENT(name, fn) \
  REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, fn)

#define REGISTER_OP_NO_GRADIENT(name) \
  REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, nullptr)

#define REGISTER_OP_GRADIENT_UNIQ_HELPER(ctr, name, fn) \
  REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)

#define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)      \
  static bool unused_grad_##ctr TF_ATTRIBUTE_UNUSED = \
      SHOULD_REGISTER_OP_GRADIENT &&                  \
      ::tensorflow::gradient::RegisterOp(name, fn)
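
// For illustration (not part of the original header): wiring the MatMulGrad
// factory sketched above to the "MatMul" op, and declaring an op with no
// gradient. Each macro expands to a ::tensorflow::gradient::RegisterOp call
// that runs at static initialization time.
//
//   REGISTER_OP_GRADIENT("MatMul", MatMulGrad);
//   REGISTER_OP_NO_GRADIENT("Shape");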

namespace gradient {
// Register a gradient creator for the "op".
typedef std::function<Status(const AttrSlice& attrs, FunctionDef*)> Creator;
bool RegisterOp(const std::string& op, Creator func);

// Returns OK if the gradient creator for the "op" is found (which may be
// nullptr if REGISTER_OP_NO_GRADIENT is used).
Status GetOpGradientCreator(const std::string& op, Creator* creator);
}  // end namespace gradient

// Declare explicit instantiations of GetAttr
#define GET_ATTR(T)                                          \
  extern template Status FunctionLibraryDefinition::GetAttr( \
      const Node&, const string&, T*) const;                 \
  extern template Status FunctionLibraryDefinition::GetAttr( \
      const NodeDef&, const string&, T*) const;
GET_ATTR(string)
GET_ATTR(bool)
#undef GET_ATTR

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_