1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_
18 
19 #include <vector>
20 
21 // clang-format off
22 // Required for IS_MOBILE_PLATFORM
23 #include "tensorflow/core/platform/platform.h"
24 // clang-format on
25 
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/types/optional.h"
28 #include "absl/types/variant.h"
29 #include "tensorflow/core/framework/attr_value.pb.h"
30 #include "tensorflow/core/framework/attr_value_util.h"
31 #include "tensorflow/core/framework/function.pb.h"
32 #include "tensorflow/core/framework/node_def_util.h"
33 #include "tensorflow/core/framework/op.h"
34 #include "tensorflow/core/framework/op_kernel.h"
35 #include "tensorflow/core/framework/registration/registration.h"
36 #include "tensorflow/core/framework/types.h"
37 #include "tensorflow/core/lib/gtl/flatmap.h"
38 #include "tensorflow/core/lib/hash/hash.h"
39 #include "tensorflow/core/lib/random/random.h"
40 #include "tensorflow/core/platform/env.h"
41 #include "tensorflow/core/platform/macros.h"
42 #include "tensorflow/core/platform/mutex.h"
43 #include "tensorflow/core/platform/protobuf.h"
44 #include "tensorflow/core/platform/threadpool_interface.h"
45 #include "tensorflow/core/protobuf/config.pb.h"
46 #if !defined(IS_MOBILE_PLATFORM)
47 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
48 #endif  // IS_MOBILE_PLATFORM
49 
50 namespace tensorflow {
51 
52 class CancellationManager;
53 class CollectiveExecutor;
54 class DeviceSet;
55 class Graph;
56 class GraphDef;
57 class OpKernel;
58 class ProcessFunctionLibraryRuntime;
59 class ResourceMgr;
60 class Rendezvous;
61 class ScopedStepContainer;
62 class StepStatsCollectorInterface;
63 class Node;
64 
65 // FunctionDefHelper::Create is a convenient helper to construct a
66 // FunctionDef proto.
67 // E.g.,
68 //   FunctionDef my_func = FunctionDefHelper::Create(
69 //     "my_func_name",
70 //     {"x:T", "y:T" /* one string per argument */},
71 //     {"z:T" /* one string per return value */},
72 //     {"T: {float, double}" /* one string per attribute  */},
73 //     {
74 //        {{"o"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
75 //        /* one entry per function node */
76 //     },
77 //     /* Mapping between function returns and function node outputs. */
78 //     {{"z", "o:z"}});
79 //
80 // For the old Function::Node approach, use FunctionDefHelper::Define()
81 // E.g.,
82 //   FunctionDef my_func = FunctionDefHelper::Define(
83 //     "my_func_name",
84 //     {"x:T", "y:T" /* one string per argument */},
85 //     {"z:T" /* one string per return value */},
86 //     {"T: {float, double}" /* one string per attribute  */},
87 //     {
88 //        {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
89 //        /* one entry per function node */
90 //     });
91 class FunctionDefHelper {
92  public:
93   // AttrValueWrapper has copy constructors for the type T so that
94   // it's easy to construct a simple AttrValue proto.
95   //
96   // If T is a string type (const char*, string, or StringPiece), and
97   // it starts with "$", we construct an AttrValue of "placeholder".
98   //
99   // E.g.,
100   //   std::pair<string, AttrValueWrapper> x = {"T", "$T"}
101   // is a named attr value placeholder.
102   struct AttrValueWrapper {
103     AttrValue proto;
104 
105     AttrValueWrapper() {}
106 
107     template <typename T>
108     AttrValueWrapper(T val) {  // NOLINT(runtime/explicit)
109       SetAttrValue(val, &proto);
110     }
111 
112    private:
113     void InitFromString(StringPiece val);
114   };
115 
116   // Constructs an AttrValue.func given the "name" and "attrs".
117   static AttrValueWrapper FunctionRef(
118       const std::string& name,
119       gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs);
120   static AttrValueWrapper FunctionRef(const std::string& name) {
121     return FunctionRef(name, {});
122   }
123 
124   // Node is used to construct FunctionDef.Node using initialization
125   // lists. E.g.,
126   //  Node n = {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}};  // z = x * y
127   //
128   // If the op has no inputs, then the name must be specified.
129   //  Node n = {{}, "AssignVariable", {"resource", "val"}, {{"dtype",
130   //  "DT_FLOAT"}},
131   //            {"update0"}, "CPU:0", "update1"};
132   struct Node {
133     // When constructing a NodeDef, the first entry in ret is used as
134     // the node name, the remaining values are ignored.
135     std::vector<string> ret;
136     std::string op;
137     std::vector<string> arg;
138     std::vector<std::pair<string, AttrValueWrapper>> attr;
139     std::vector<string> dep;
140     std::string device;
141 
142     // Required if the op has zero outputs. Otherwise, ret[0] is used as the
143     // name if name is left empty.
144     std::string name;
145 
146     std::string GetName() const {
147       if (!name.empty()) return name;
148       CHECK(!ret.empty());
149       return ret[0];
150     }
151     std::vector<string> original_node_names;
152     std::vector<string> original_func_names;
153 
154     NodeDef ToNodeDef() const;
155   };
156 
157   // Creates a FunctionDef from the given parameters. Node inputs must use
158   // function encoding (node_name:output_name[:output_index]).
159   // - `ret_def` holds a mapping from the function output names from `out_def`
160   //   to the node outputs from `node_def`.
161   // - `control_ret_def` holds a mapping from the function control
162   //   output names to the nodes from `node_def`.
163   static FunctionDef Create(
164       const std::string& function_name, gtl::ArraySlice<string> in_def,
165       gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
166       gtl::ArraySlice<Node> node_def,
167       gtl::ArraySlice<std::pair<string, string>> ret_def,
168       gtl::ArraySlice<std::pair<string, string>> control_ret_def);
169 
170   // Creates a FunctionDef from the given parameters. Node inputs must use
171   // function encoding (node_name:output_name[:output_index]).
172   // - `ret_def` holds a mapping from the function output names from `out_def`
173   //   to the node outputs from `node_def`.
174   static FunctionDef Create(const std::string& function_name,
175                             gtl::ArraySlice<string> in_def,
176                             gtl::ArraySlice<string> out_def,
177                             gtl::ArraySlice<string> attr_def,
178                             gtl::ArraySlice<Node> node_def,
179                             gtl::ArraySlice<std::pair<string, string>> ret_def);
180 
181   // TODO(josh11b): Get rid of these and transition to the one above.
182   static FunctionDef Define(const std::string& function_name,
183                             gtl::ArraySlice<string> arg_def,
184                             gtl::ArraySlice<string> ret_def,
185                             gtl::ArraySlice<string> attr_def,
186                             gtl::ArraySlice<Node> node_def);
187 
188   // Defines an anonymous function. I.e., its name is not relevant.
189   static FunctionDef Define(gtl::ArraySlice<string> arg_def,
190                             gtl::ArraySlice<string> ret_def,
191                             gtl::ArraySlice<string> attr_def,
192                             gtl::ArraySlice<Node> node_def);
193 
194   // Helpers to construct a constant scalar or a 1-D constant.
195   template <typename T>
196   static Node Const(const std::string& name, const T& val) {
197     Node n = {{name}, "Const"};
198     const DataType dtype = DataTypeToEnum<T>::value;
199     n.attr.push_back({"dtype", dtype});
200     Tensor t(dtype, TensorShape({}));
201     t.scalar<T>()() = val;
202     n.attr.push_back({"value", t});
203     return n;
204   }
205 
206   template <typename T>
207   static Node Const(const std::string& name, gtl::ArraySlice<T> vals) {
208     Node n = {{name}, "Const"};
209     const DataType dtype = DataTypeToEnum<T>::value;
210     n.attr.push_back({"dtype", dtype});
211     int64_t num = vals.size();
212     Tensor t(dtype, TensorShape({num}));
213     for (size_t i = 0; i < vals.size(); ++i) {
214       t.flat<T>()(i) = vals[i];
215     }
216     n.attr.push_back({"value", t});
217     return n;
218   }
219 };
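
// A small usage sketch of the helpers above (names such as "two" and "MyFunc"
// are placeholders, not part of this API):
//   FunctionDefHelper::Node two = FunctionDefHelper::Const<float>("two", 2.f);
//   FunctionDefHelper::AttrValueWrapper fn =
//       FunctionDefHelper::FunctionRef("MyFunc", {{"T", DT_FLOAT}});
// `two` can be listed in the node_def argument of Create()/Define(), and `fn`
// can be used as the value of a function-valued attr.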
220 
221 template <>
222 inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(const char* val) {
223   InitFromString(val);
224 }
225 
226 template <>
227 inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(
228     const std::string& val) {
229   InitFromString(val);
230 }
231 
232 template <>
233 inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) {
234   InitFromString(val);
235 }
236 
237 // Instantiate a function.
238 //
239 // "fdef" encodes a TF function with some attrs in fdef.signature.attr
240 // containing placeholders.  InstantiateFunction binds these
241 // placeholders and produces an instantiated function encoded in
242 // "result.gdef". The value to substitute a placeholder is given by
243 // "attr_values", which is a map from a placeholder name to an attr
244 // value.
245 //
246 // InstantiateFunction calls "get_function" to find signatures of other
247 // functions and primitive ops.
248 
249 // GetFunctionSignature(func name, opdef) returns OK if the func name is found
250 // and opdef is filled with a pointer to the corresponding signature
251 // (an OpDef proto). Otherwise, returns an error.
252 typedef std::function<Status(const string&, const OpDef**)>
253     GetFunctionSignature;
254 
255 struct InstantiationResult {
256   DataTypeVector arg_types;
257   DataTypeVector ret_types;
258   std::vector<NodeDef> nodes;
259 };
260 Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
261                            GetFunctionSignature get_function,
262                            InstantiationResult* result);
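
// A minimal sketch of instantiating a FunctionDef `fdef` with T = float,
// assuming every callee op is registered in the global op registry:
//   InstantiationResult result;
//   auto get_op = [](const string& op, const OpDef** sig) {
//     return OpRegistry::Global()->LookUpOpDef(op, sig);
//   };
//   AttrValueMap attr_values;
//   SetAttrValue(DT_FLOAT, &attr_values["T"]);
//   TF_RETURN_IF_ERROR(
//       InstantiateFunction(fdef, AttrSlice(&attr_values), get_op, &result));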
263 
264 // Returns a debug string for a function definition.
265 //
266 // The returned text is multi-line. It is intended to be
267 // human-readable rather than being friendly to parsers. It is _NOT_
268 // intended to be the canonical string representation of "func_def".
269 // Particularly, it may not include all information presented in
270 // "func_def" (e.g., comments, description of the function arguments,
271 // etc.)
272 std::string DebugString(const FunctionDef& func_def);
273 std::string DebugString(const GraphDef& instantiated_func_def);
274 std::string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes);
275 
276 // Returns a debug string for a top level graph (the main program and
277 // its supporting functions defined in its library).
278 std::string DebugStringWhole(const GraphDef& gdef);
279 
280 // Returns true if f1 == f2. Compares all fields, including descriptions. Order
281 // of NodeDefs doesn't matter.
282 bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2);
283 
284 // Return a hash of `fdef` that is consistent with FunctionDefsEqual method.
285 // In other words, if two fdefs compare equal, their hash values will be the
286 // same.
287 uint64 FunctionDefHash(const FunctionDef& fdef);
288 
289 class CallFrameInterface {
290  public:
291   virtual ~CallFrameInterface() {}
292 
293   virtual size_t num_args() const = 0;
294   virtual size_t num_retvals() const = 0;
295 
296   virtual Status GetArg(int index, const Tensor** val) = 0;
297 
298   // Optimized implementation of `GetArg()` that allows the caller to take
299   // ownership of the tensor. This method may only be called once per
300   // value of `index` and `CallFrameInterface` instance.
301   //
302   // REQUIRES: `this->CanConsumeArg(index) == true`.
303   virtual void ConsumeArg(int index, Tensor* val) {
304     LOG(ERROR) << "This `CallFrameInterface` implementation does not support "
305                   "consuming arguments.";
306   }
307   virtual bool CanConsumeArg(int index) const { return false; }
308 
309   virtual Status SetRetval(int index, const Tensor& val) = 0;
310 };
311 
312 // Represents a function call frame. I.e., the data structure used to
313 // pass arguments to a function and retrieve its results.
314 //
315 // Runtime must arrange accesses to one FunctionCallFrame s.t.
316 //   1. SetArgs() happens before any GetArg();
317 //   2. GetRetvals happens after all SetRetval();
318 class FunctionCallFrame : public CallFrameInterface {
319  public:
320   FunctionCallFrame(DataTypeSlice arg_types, DataTypeSlice ret_types);
321   ~FunctionCallFrame() override;
322 
323   // Caller methods.
324   Status SetArgs(gtl::ArraySlice<Tensor> args);
325   Status GetRetvals(std::vector<Tensor>* rets) const;
326 
327   // Moves the return values from the frame to rets. If allow_dead_tensors is
328   // false it will fail if any of the retvals do not have a value.
329   Status ConsumeRetvals(std::vector<Tensor>* rets, bool allow_dead_tensors);
330 
331   size_t num_args() const override { return arg_types_.size(); }
332   size_t num_retvals() const override { return ret_types_.size(); }
333 
334   // Callee methods.
335   Status GetArg(int index, const Tensor** val) override;
336   Status SetRetval(int index, const Tensor& val) override;
337 
338  private:
339   DataTypeVector arg_types_;
340   DataTypeVector ret_types_;
341   gtl::InlinedVector<Tensor, 4> args_;
342   struct Retval {
343     bool has_val = false;
344     Tensor val;
345   };
346   gtl::InlinedVector<Retval, 4> rets_;
347 
348   TF_DISALLOW_COPY_AND_ASSIGN(FunctionCallFrame);
349 };
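
// A minimal caller-side sketch, assuming a function that takes one float and
// returns one float:
//   FunctionCallFrame frame({DT_FLOAT}, {DT_FLOAT});
//   TF_RETURN_IF_ERROR(frame.SetArgs({Tensor(3.14f)}));
//   // ... run the instantiated function with &frame as its call frame ...
//   std::vector<Tensor> rets;
//   TF_RETURN_IF_ERROR(frame.GetRetvals(&rets));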
350 
351 // Language agnostic stack traces.
352 class AbstractStackTrace {
353  public:
354   struct TracePrintingOptions {
355     // Show inline the contents of each stack line.
356     bool show_line_contents = false;
357 
358     // Drop the largest common prefix of all filenames in stack frames.
359     bool filter_common_prefix = false;
360 
361     // Do not show internal frames.
362     bool drop_internal_frames = false;
363   };
364 
365   virtual ~AbstractStackTrace() {}
366 
367   // The returned span is alive as long as the AbstractStackTrace is alive.
368   virtual absl::Span<StackFrame const> ToFrames() const = 0;
369 
370   // Returns the last stack frame from user code, attempting to ignore the
371   // framework code. Returns an empty frame if no such stack frame was found.
372   virtual StackFrame LastUserFrame() const = 0;
373   virtual std::string ToString(const TracePrintingOptions& opts) const = 0;
374 };
375 
376 using StackTracesMap =
377     std::unordered_map<std::string,
378                        std::shared_ptr<tensorflow::AbstractStackTrace>>;
379 
380 // Helper to maintain a map between function names in a given
381 // FunctionDefLibrary and function definitions.
382 //
383 // This class is thread-safe.
384 class FunctionLibraryDefinition : public OpRegistryInterface {
385  public:
386   // Ops created for function arguments bear the name given by `kArgOp`; those
387   // created for return values bear the name given by `kRetOp`.
388   static constexpr const char* const kArgOp = "_Arg";
389   static constexpr const char* const kDeviceArgOp = "_DeviceArg";
390   static constexpr const char* const kRetOp = "_Retval";
391   static constexpr const char* const kDeviceRetOp = "_DeviceRetval";
392   static constexpr const char* const kIntsOnDeviceAttr =
393       "experimental_ints_on_device";
394   static constexpr const char* const kSharedRendezvousAttr =
395       "shared_rendezvous";
396 
397   static constexpr const char* const kGradientOp = "SymbolicGradient";
398   static constexpr const char* const kFuncAttr = "f";
399 
400   // Note: This constructor grabs `lib_def`'s lock in shared mode.
401   FunctionLibraryDefinition(const FunctionLibraryDefinition& lib_def);
402   FunctionLibraryDefinition(const OpRegistryInterface* default_registry,
403                             const FunctionDefLibrary& lib_def = {});
404   ~FunctionLibraryDefinition() override;
405 
406   FunctionLibraryDefinition& operator=(const FunctionLibraryDefinition&) =
407       delete;
408 
409   // Returns True if the library contains `func`, False otherwise.
410   bool Contains(const std::string& func) const;
411 
412   // Returns nullptr if "func" is not defined in "lib_def". Otherwise,
413   // returns its definition proto.
414   //
415   // NB: This function returns a borrowed pointer, which can be invalidated by a
416   // subsequent call to `ReplaceFunction()` with the given name.
417   const FunctionDef* Find(const std::string& func) const TF_LOCKS_EXCLUDED(mu_);
418 
419   // Adds function definition 'fdef' to this function library.
420   // Returns status 'ok' on success, or error otherwise. This is a no-op if
421   // 'fdef' already exists in this function library.
422   // If 'fdef' is successfully added to the library, it will be accessible
423   // from 'LookUp' and included in the proto returned by 'ToProto'.
424   // This operation is atomic.
425   //
426   // Associates `graph` with a function `func_name`. Lifetime assumption:
427   // `graph` has to outlive all instantiated graphs.
428   Status AddFunctionDef(const FunctionDef& fdef,
429                         const StackTracesMap& stack_traces = {})
430       TF_LOCKS_EXCLUDED(mu_);
431 
432   // Adds gradient definition 'grad' to this function library.
433   // This is a no-op if 'grad' already exists in this function library.
434   // If 'grad' is successfully added, it will be accessible via 'FindGradient'
435   // and included in the proto returned by 'ToProto'.
436   // This operation is atomic.
437   Status AddGradientDef(const GradientDef& grad) TF_LOCKS_EXCLUDED(mu_);
438 
439   // Replaces the function corresponding to `func` with `fdef`. Returns
440   // a non-OK status if "func" was not found in the library, OK otherwise.
441   // Please be careful when replacing a function: make sure all previous
442   // pointers returned by `Find()` are no longer in use.
443   Status ReplaceFunction(const std::string& func, const FunctionDef& fdef,
444                          const StackTracesMap& stack_traces = {})
445       TF_LOCKS_EXCLUDED(mu_);
446 
447   // Replaces the gradient corresponding to `grad.function_name()`. Returns
448   // a non-OK status if "grad.function_name()" was not found in the library, OK
449   // otherwise.
450   Status ReplaceGradient(const GradientDef& grad) TF_LOCKS_EXCLUDED(mu_);
451 
452   // Removes the function corresponding to 'func'. Returns a non-OK status if
453   // 'func' was not found in the library, OK otherwise.
454   // Please be careful when removing a function: make sure there are no other
455   // nodes using the function, and all previous pointers returned by `Find()`
456   // are no longer in use.
457   Status RemoveFunction(const std::string& func) TF_LOCKS_EXCLUDED(mu_);
458 
459   // Removes all the functions and gradient functions.
460   void Clear() TF_LOCKS_EXCLUDED(mu_);
461 
462   // Adds the functions and gradients in 'other' to this function library.
463   // Duplicate functions and gradients are ignored.
464   // This operation is atomic.
465   Status AddLibrary(const FunctionLibraryDefinition& other)
466       TF_LOCKS_EXCLUDED(mu_);
467 
468   // Adds the functions and gradients in 'lib_def' to this function library.
469   // Duplicate functions and gradients are ignored.
470   // This operation is atomic.
471   Status AddLibrary(const FunctionDefLibrary& lib_def) TF_LOCKS_EXCLUDED(mu_);
472 
473   // If the gradient function for 'func' is specified explicitly in
474   // the library, returns the gradient function name.  Otherwise,
475   // returns an empty string.
476   std::string FindGradient(const std::string& func) const
477       TF_LOCKS_EXCLUDED(mu_);
478 
479   // OpRegistryInterface method. Useful for constructing a Graph.
480   //
481   // If "op" is defined in the library, returns its signature.
482   // Otherwise, assume "op" is a primitive op and returns its op
483   // signature and shape inference function.
484   //
485   // NB: This function outputs a borrowed pointer, which can be invalidated by a
486   // subsequent call to `ReplaceFunction()` with the given name.
487   Status LookUp(const std::string& op_type_name,
488                 const OpRegistrationData** op_reg_data) const override
489       TF_LOCKS_EXCLUDED(mu_);
490 
491   // Generates new function name with the specified prefix that is unique
492   // across this library.
493   std::string UniqueFunctionName(StringPiece prefix) const
494       TF_LOCKS_EXCLUDED(mu_);
495 
496   // Given a node def 'ndef', inspects attributes of the callee
497   // function to derive the attribute 'value' for 'attr'. Returns OK
498   // iff the attribute is given by the function's definition.
499   // TODO(irving): Remove; keep only the const Node& version.
500   template <typename T>
501   Status GetAttr(const NodeDef& ndef, const std::string& attr, T* value) const;
502 
503   // Given a node, inspects attributes of the callee function to derive the
504   // attribute 'value' for 'attr'. Returns OK iff the attribute is given by the
505   // function's definition.
506   template <typename T>
507   Status GetAttr(const Node& node, const std::string& attr, T* value) const;
508 
509   // Returns a proto representation of the state of this function library.
510   FunctionDefLibrary ToProto() const TF_LOCKS_EXCLUDED(mu_);
511 
512   size_t num_functions() const {
513     tf_shared_lock l(mu_);
514     return function_defs_.size();
515   }
516 
517   // Returns all the function names in the FunctionLibraryDefinition.
518   std::vector<string> ListFunctionNames() const TF_LOCKS_EXCLUDED(mu_);
519 
520   const OpRegistryInterface* default_registry() const {
521     return default_registry_;
522   }
523   void set_default_registry(const OpRegistryInterface* registry) {
524     default_registry_ = registry;
525   }
526 
527   // Returns a copy of `*this` with only the subset of functions that are
528   // reachable from the nodes of `graph` or `func`.
529   FunctionLibraryDefinition ReachableDefinitions(const GraphDef& graph) const;
530   FunctionLibraryDefinition ReachableDefinitions(const FunctionDef& func) const;
531 
532   // Copies the function named `func` from `other` to this
533   // FunctionLibraryDefinition.
534   // REQUIRES: `this->default_registry() == other.default_registry()`.
535   // Returns OK on success, or error otherwise. This is a no-op if a function
536   // name `func` already exists in this function library, and has the same
537   // implementation as in `other`. If the implementations conflict, an invalid
538   // argument error is returned.
539   Status CopyFunctionDefFrom(const std::string& func,
540                              const FunctionLibraryDefinition& other)
541       TF_LOCKS_EXCLUDED(mu_);
542 
543   // Returns the stack traces for the given function, or an empty map if none
544   // are found.
545   const StackTracesMap& GetStackTraces(const std::string& func_name) const {
546     tf_shared_lock l(mu_);
547     std::shared_ptr<FunctionDefAndOpRegistration> entry = FindHelper(func_name);
548     if (entry) {
549       return entry->stack_traces;
550     }
551     static const auto* empty_map = new StackTracesMap;
552     return *empty_map;
553   }
554 
555  private:
556   // Shape inference for functions is handled separately by ShapeRefiner.
557 
558   struct FunctionDefAndOpRegistration {
559     explicit FunctionDefAndOpRegistration(
560         const FunctionDef& fdef_in, const StackTracesMap& stack_traces = {});
561 
562     const FunctionDef fdef;
563     const OpRegistrationData op_registration_data;
564     const StackTracesMap stack_traces;
565   };
566 
567   std::shared_ptr<FunctionDefAndOpRegistration> FindHelper(
568       const string& func) const TF_SHARED_LOCKS_REQUIRED(mu_);
569   std::string FindGradientHelper(const std::string& func) const
570       TF_SHARED_LOCKS_REQUIRED(mu_);
571 
572   Status AddHelper(std::shared_ptr<FunctionDefAndOpRegistration> registration,
573                    bool* added) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
574 
575   // Same as AddFunctionDef/AddGradientDef except these methods set
576   // `added` to true if the `fdef`/`grad` were actually added to this.
577   Status AddFunctionDefHelper(const FunctionDef& fdef,
578                               const StackTracesMap& stack_traces, bool* added)
579       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
580   Status AddGradientDefHelper(const GradientDef& grad, bool* added)
581       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
582 
583   // Helper function for GetAttr. Returns the FunctionDef* to get the
584   // attr from.
585   const FunctionDef* GetAttrImpl(const NodeDef& ndef) const
586       TF_LOCKS_EXCLUDED(mu_);
587 
588   // Remove all functions in `funcs` and all gradients of functions in
589   // `funcs_with_grads` from this library.
590   Status Remove(const std::vector<string>& funcs,
591                 const std::vector<string>& funcs_with_grads)
592       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
593 
594   // Remove `func` from the library. Returns non-OK Status unless `func` is in
595   // the library. This should only be called when there is a guarantee that the
596   // function being removed hasn't been retrieved with `Find`.
597   Status RemoveFunctionHelper(const std::string& func)
598       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
599 
600   // Remove gradient of function `func` from the library. Returns non-OK Status
601   // unless `func` has a gradient.
602   Status RemoveGradient(const std::string& func)
603       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
604 
605   mutable mutex mu_;
606   const OpRegistryInterface* default_registry_;
607   gtl::FlatMap<string, std::shared_ptr<FunctionDefAndOpRegistration>>
608       function_defs_ TF_GUARDED_BY(mu_);
609   gtl::FlatMap<string, string> func_grad_ TF_GUARDED_BY(mu_);
610 };
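
// A minimal usage sketch ("XTimesTwo" and `x_times_two_fdef` are
// placeholders):
//   FunctionLibraryDefinition lib_def(OpRegistry::Global(),
//                                     FunctionDefLibrary());
//   TF_RETURN_IF_ERROR(lib_def.AddFunctionDef(x_times_two_fdef));
//   const FunctionDef* found = lib_def.Find("XTimesTwo");  // borrowed pointer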
611 
612 // Forward declare. Defined in common_runtime/function.h
613 struct FunctionBody;
614 
615 // Forward declare. Defined in common_runtime/device.h
616 class Device;
617 // Forward declare. Defined in common_runtime/device_mgr.h
618 class DeviceMgr;
619 
620 // Index of an _Arg node.
621 struct FunctionArgIndex {
622   explicit FunctionArgIndex(const int index) : index(index) {}
623   FunctionArgIndex(const int index, const int sub_index)
624       : index(index), sub_index(sub_index) {}
625 
626   // The value of the attribute "Index" of the _Arg node.
627   int index;
628   // Set only when the _Arg node represents multiple arguments (e.g. an _Arg
629   // node is replicated to multiple devices/subgraphs). Use sub-index to
630   // distinguish arguments with the same index.
631   int sub_index = -1;
632 };
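
// For example (a sketch): if the _Arg node whose "index" attr is 2 is
// replicated onto two devices, its per-device instances can be addressed as
// FunctionArgIndex(2, 0) and FunctionArgIndex(2, 1).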
633 
634 class FunctionLibraryRuntime {
635  public:
636   virtual ~FunctionLibraryRuntime() {}
637 
638   // Instantiate a function with the given "attrs".
639   //
640   // Returns OK and fills in "handle" if the instantiation succeeds.
641   // Otherwise returns an error and "handle" is undefined.
642   struct InstantiateOptions {
643     // The canonical device name of the device on which the function
644     // should be instantiated. If empty, the function will be
645     // instantiated on the local device.
646     std::string target;
647 
648     // Should the function be instantiated as a multi-device function?
649     bool is_multi_device_function = false;
650 
651     // If true, graph passes will be skipped when instantiating the function
652     // since they have already run on the main function side.
653     bool is_component_function = false;
654 
655     // For multi-device functions, a vector of canonical device names for
656     // function's inputs. The device of resource inputs must be the device
657     // backing the resource, not the CPU device backing the resource handle.
658     // Must have the same length as number of inputs to the function.
659     std::vector<string> input_devices;
660 
661     // For multi-device functions, a vector of canonical device names for
662     // function's outputs.
663     //
664     // (a) If specified (must have the same length as number of outputs):
665     //
666     // Specified devices will be assigned to Retval nodes inserted into the
667     // function body graph in place of function outputs. An output device may
668     // be specified as an empty string, in which case the Retval device
669     // assignment will be inferred later when the function graph is placed
670     // before partitioning (this is required for resource outputs). The Placer
671     // will respect colocation constraints.
672     //
673     // (b) If not specified:
674     //
675     // Function runtime will infer the Retval device by following input edges,
676     // until it reaches a node with a device specification. This device
677     // specification must identify a unique device, i.e. a general specification
678     // like "job:foo" matching multiple devices will result in an error.
679     //
680     // IMPORTANT: Resource outputs
681     //
682     // Multi-device functions might return resources on a device different from
683     // the function call device. If the output device is not specified for a
684     // resource output, and the node producing that resource is a function call,
685     // the runtime will leave the device specification empty and rely on the
686     // Placer to infer the correct device.
687     std::vector<string> output_devices;
688 
689     // If set, it indicates the original output indices of a component function.
690     absl::optional<std::vector<int>> ret_indices = absl::nullopt;
691 
692     // Maps from a CompositeDevice name to a list of underlying physical
693     // devices.
694     absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
695 
696     // This interface is EXPERIMENTAL and subject to change.
697     //
698     // For multi-device functions, a mapping from _Arg node index to type and
699     // shape for input resources.
700     // REQUIRES: if input_resource_dtypes_and_shapes.count(i) > 0 then i-th
701     // argument type must be DT_RESOURCE.
702     std::unordered_map<int, DtypeAndPartialTensorShape>
703         input_resource_dtypes_and_shapes;
704 
705     // This interface is EXPERIMENTAL and subject to change.
706     //
707     // If non-null, the runtime will use `lib_def` to resolve function(s) named
708     // in `function_name` and `attrs`. Otherwise, the runtime will use its
709     // internal library.
710     //
711     // NOTE(mrry): If provided, all functions defined in `lib_def` must be
712     // self-contained, and cannot refer to functions defined in other libraries.
713     const FunctionLibraryDefinition* lib_def = nullptr;
714 
715     // This interface is EXPERIMENTAL and subject to change.
716     //
717     // If non-empty, the runtime will use `state_handle` to identify
718     // cached state related to the instantiated function. Two functions
719     // of the same name and attrs, instantiated with the same
720     // `state_handle` will have the same handle and share the same
721     // state (in stateful kernels); and two functions with different
722     // values for `state_handle` will have independent state.
723     std::string state_handle;
724 
725     // This interface is EXPERIMENTAL and subject to change.
726     //
727     // Instantiates the function using an executor of the given type. If empty,
728     // the default TensorFlow executor will be used.
729     std::string executor_type;
730 
731     // If true, the runtime will attempt to create kernels for the function at
732     // instantiation time, rather than on the first run. This can be used to
733     // surface errors earlier.
734     bool create_kernels_eagerly = false;
735 
736     // This interface is EXPERIMENTAL and subject to change.
737     //
738     // Instantiates the function with the provided config_proto.
739     ConfigProto config_proto;
740 
741     // If provided, this optimization function will be invoked before
742     // the placer for multi-device functions.
743     std::function<Status(std::vector<string> /*ret_node_names*/,
744                          std::vector<string> /*keep_node_names*/,
745                          FunctionLibraryDefinition*, const DeviceSet&,
746                          Device* /*cpu_device*/, std::unique_ptr<Graph>*)>
747         optimize_graph_fn;
748 
749     // If set, partitioned functions will be added to `graph_collector`.
750     // `graph_collector` must be alive during the call to Instantiate.
751     GraphCollector* graph_collector = nullptr;
752 
753     // Indicates whether the multi-device function backend should default the
754     // placement of ops without a requested device to `target`.
755     bool default_device_to_target = true;
756 
757     // If true, the optimized Graph will be stored so that
758     // `FunctionLibraryRuntime::DebugString(handle)` contains the optimized
759     // Graph. Otherwise, the unoptimized function Graph will be returned.
760     bool include_optimized_graph_in_debug_string = false;
761 
762     // If true, the function library runtime caches the function instantiation.
763     bool use_function_cache = false;
764 
765     // This interface is EXPERIMENTAL and subject to change.
766     //
767     // If True, allow optimizations which should be targeted at a limited
768     // set of small functions.  For example, running kernels synchronously can
769     // be faster under some conditions.
770     bool allow_small_function_optimizations = false;
771 
772     // This interface is EXPERIMENTAL and subject to change.
773     //
774     // If True, allow graphs containing control flow nodes to be run on the
775     // single threaded executor.
776     bool allow_control_flow_sync_execution = false;
777 
778     // TODO(b/176491312): Remove this if shape inference on import flag is
779     // removed. If True, allows mlir roundtrip to run shape inference on import.
780     bool shape_inference_on_tfe_dialect_import = true;
781 
782     // Force int32 _Arg and _Retvals nodes to be left on device instead of
783     // pinning to host.
784     //
785     // Note that we do not pin int32 nodes to host for subgraphs running in
786     // TPU/XLA devices. So this is mainly used to handle the case of multi-CPU
787     // and GPU (non-XLA) graphs.
788     bool int_args_and_retvals_on_device = false;
789 
790     // This interface is EXPERIMENTAL and subject to change.
791     //
792     // Instantiates the function for XLA compilation on device_type. If empty,
793     // function is not compiled.
794     std::string xla_compile_device_type;
795   };
796   typedef uint64 Handle;
797   virtual Status Instantiate(const std::string& function_name, AttrSlice attrs,
798                              const InstantiateOptions& options,
799                              Handle* handle) = 0;
800   Status Instantiate(const std::string& function_name, AttrSlice attrs,
801                      Handle* handle) {
802     auto opts = absl::make_unique<InstantiateOptions>();
803     return Instantiate(function_name, attrs, *opts, handle);
804   }
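
  // A minimal sketch of instantiation ("MyFunc", the target device string, and
  // the `flr` FunctionLibraryRuntime* are placeholders):
  //   FunctionLibraryRuntime::InstantiateOptions i_opts;
  //   i_opts.target = "/job:localhost/replica:0/task:0/device:CPU:0";
  //   FunctionLibraryRuntime::Handle handle;
  //   TF_RETURN_IF_ERROR(
  //       flr->Instantiate("MyFunc", AttrSlice(), i_opts, &handle));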
805 
806   // Releases state associated with the handle.
807   virtual Status ReleaseHandle(Handle handle) = 0;
808 
809   // Returns the function body for the instantiated function given its
810   // handle 'h'. Returns nullptr if "h" is not found.
811   //
812   // *this keeps the ownership of the returned object, which remains alive
813   // as long as *this.
814   virtual const FunctionBody* GetFunctionBody(Handle h) = 0;
815 
816   // Returns the return types for the function identified by handle `h`.
817   virtual Status GetRetTypes(Handle h, DataTypeVector* ret_types) = 0;
818 
819   // Asynchronously invokes the instantiated function identified by
820   // "handle".
821   //
822   // If function execution succeeds, "done" is called with OK and
823   // "*rets" is filled with the function's return values. Otherwise,
824   // "done" is called with an error status.
825   //
826   // Does not take ownership of "rets".
827   // In the cross-process scenario, runner isn't used for making the Async
828   // RPC calls.
829   struct Options {
830     Options() {}
831     explicit Options(const int64_t step_id) : step_id(step_id) {}
832     // Choose a step ID that is guaranteed not to clash with any
833     // Session-generated step ID. DirectSession only generates
834     // non-negative step IDs (contiguous, starting from 0), and
835     // MasterSession generates 56-bit random step IDs whose MSB is
836     // always 0, so a negative random step ID should suffice.
837     const int64_t step_id = -std::abs(static_cast<int64_t>(random::New64()));
838 
839     // op_id of the function running in eager mode. Set when we want to copy
840     // remote outputs lazily. All components of a remote multi-device function
841     // should use the same op_id, in order to correctly map remote output
842     // tensors to the remote TensorHandles in the default device.
843     absl::optional<int64_t> op_id = absl::nullopt;
844 
845     RendezvousInterface* rendezvous = nullptr;
846     CancellationManager* cancellation_manager = nullptr;
847     CollectiveExecutor* collective_executor = nullptr;
848     ScopedStepContainer* step_container = nullptr;
849     StepStatsCollectorInterface* stats_collector = nullptr;
850     CoordinationServiceAgent* coordination_service_agent = nullptr;
851 
852     absl::optional<ManagedStackTrace> stack_trace = absl::nullopt;
853 
854     std::function<void(std::function<void()>)>* runner = nullptr;
855 
856     // Parameters for remote function execution.
857     bool remote_execution = false;
858     std::string source_device = "";  // Fully specified device name.
859 
860     // Allocator attributes specifying where the args are / rets should be put.
861     // These should either be {} or match the length of args / retvals. If {},
862     // the default allocator attributes will be assumed for all args / retvals.
863     std::vector<AllocatorAttributes> args_alloc_attrs;
864     std::vector<AllocatorAttributes> rets_alloc_attrs;
865 
866     // If true, we create a new IntraProcessRendezvous, else use the existing
867     // one.
868     bool create_rendezvous = false;
869 
870     // If True, allow returning dead tensors.
871     bool allow_dead_tensors = false;
872 
873     // If True, hint that all kernels should be treated as "inexpensive", and
874     // hence executed on the scheduling thread.
875     bool run_all_kernels_inline = false;
876 
877     // If not null, use this thread pool for intra op scheduling.
878     thread::ThreadPoolInterface* user_intra_op_threadpool = nullptr;
879 
880     // Returns a human readable representation of this.
881     std::string DebugString() const;
882   };
883   typedef std::function<void(const Status&)> DoneCallback;
884   virtual void Run(const Options& opts, Handle handle,
885                    gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
886                    DoneCallback done) = 0;
887   virtual void Run(const Options& opts, Handle handle,
888                    CallFrameInterface* call_frame, DoneCallback done) = 0;
889 
890   virtual Status RunSync(Options opts, Handle handle,
891                          gtl::ArraySlice<Tensor> args,
892                          std::vector<Tensor>* rets) = 0;
893   virtual Status RunSync(Options opts, Handle handle,
894                          CallFrameInterface* call_frame) = 0;
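
  // A minimal sketch of running an instantiated handle (continuing the
  // Instantiate() sketch above; `arg` is a placeholder input Tensor):
  //   FunctionLibraryRuntime::Options run_opts;
  //   std::vector<Tensor> rets;
  //   flr->Run(run_opts, handle, {arg}, &rets,
  //            [](const Status& s) { /* rets are valid iff s is OK */ });
  // Once the handle is no longer needed, call ReleaseHandle(handle).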
895 
896   // Creates a "kernel" for the given NodeProperties "props".
897   //
898   // If it succeeds, returns OK and the caller takes ownership of the
899   // returned "*kernel". Otherwise, returns an error.
900   virtual Status CreateKernel(
901       const std::shared_ptr<const NodeProperties>& props,
902       OpKernel** kernel) = 0;
903 
904   // Returns true iff the function named `function_name` is stateful.
905   //
906   // NOTE(mrry): This method assumes that the runtime is associated with a
907   // default function library, and looks up `function_name` in that library.
908   // It does not support overriding the function library.
909   virtual bool IsStateful(const std::string& function_name) const = 0;
910 
911   // Returns the device on which the function executes.
912   virtual Device* device() = 0;
913   virtual const Device* device() const = 0;
914 
915   // Returns the default runner in which the ops should be launched. If the
916   // device on which the function executes has a private thread pool, return
917   // runner on the device local thread pool.
918   virtual std::function<void(std::function<void()>)>* runner() = 0;
919 
920   // Get the DeviceMgr from which the device was obtained.
921   virtual const DeviceMgr* device_mgr() const = 0;
922 
923   // Returns the function library definition that backs this runtime.
924   //
925   // NOTE(mrry): The returned library definition is the default function library
926   // for this runtime. The caller may override the function library used by the
927   // runtime to instantiate functions, which will not be reflected in the return
928   // value of this function.
929   virtual const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
930       const = 0;
931 
932   // Returns the environment on which the function executes.
933   virtual Env* env() = 0;
934 
935   // Returns the ConfigProto passed to the session used to create the function.
936   virtual const ConfigProto* const config_proto() = 0;
937 
938   // Returns a debug string showing the definition of the function of
939   // 'handle'.
940   virtual std::string DebugString(Handle handle) = 0;
941 
942   // Returns the graph version number.
943   virtual int graph_def_version() const = 0;
944 
945   typedef uint64 LocalHandle;
946 
947   // Creates a copy of ProcessFunctionLibraryRuntime (transferring ownership to
948   // the caller), FunctionLibraryRuntime (owned by the returned
949   // ProcessFunctionLibraryRuntime), FunctionLibraryDefinition (transferring
950   // ownership to the caller). Note that both the ProcessFunctionLibraryRuntime
951   // and FunctionLibraryRuntime borrow a pointer to the
952   // FunctionLibraryDefinition and so the FunctionLibraryDefinition should
953   // outlive both.
954   //
955   // The `skip_flib_def` argument controls whether the method should clone the
956   // FunctionLibraryDefinition (default behavior) or return an empty function
957   // library. The latter is used by tf.data, which manages
958   // FunctionLibraryDefinitions for its functions independently (and passes
959   // these into the FunctionLibraryRuntime through an overlay), to avoid linear
960   // runtime w.r.t. the number of functions in the current function library.
961   virtual Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
962                        std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
963                        FunctionLibraryRuntime** out_flr,
964                        bool skip_flib_def = false) = 0;
965 
966   // Returns the name of the executor class (in the sense of
967   // `ExecutorFactory::GetFactory()`) that will be used based on the given
968   // dynamic `options` and static `attrs`. If none is specified, this method
969   // will return an empty string, which leaves the decision up to the runtime.
970   static std::string ExecutorType(const InstantiateOptions& options,
971                                   AttrSlice attrs);
972 };
973 
974 // Returns the device of the `arg_index`-th function input. Update
975 // `composite_devices` if the input device is a composite device.
976 std::string GetFunctionResourceInputDevice(
977     const Tensor& input, const int arg_index, const FunctionDef& function_def,
978     absl::flat_hash_map<string, std::vector<string>>* composite_devices);
979 
980 // Returns a canonicalized string for the instantiation of the function of the
981 // given "name", attributes "attrs", and "options".
982 //
983 // The returned string is guaranteed to be stable within one address space. But
984 // it may change as the implementation evolves. Therefore, it should not be
985 // persisted or compared across address spaces.
986 std::string Canonicalize(
987     const std::string& funcname, AttrSlice attrs,
988     const FunctionLibraryRuntime::InstantiateOptions& options);
989 std::string Canonicalize(const std::string& funcname, AttrSlice attrs);
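
// For example (a sketch): the canonical string can serve as a per-process
// cache key for instantiations of "MyFunc" (a placeholder name):
//   const string key = Canonicalize("MyFunc", AttrSlice(), options);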
990 
991 const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
992 const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1;
993 
994 class CustomKernelCreator {
995  public:
996   virtual ~CustomKernelCreator() {}
997 
998   // Given a NodeDef 'node_def' and the function library runtime 'flr',
999   // validate if the class supports creating such a kernel.
1000   virtual bool CanCreateKernel(
1001       const FunctionLibraryRuntime& flr,
1002       const std::shared_ptr<const NodeProperties>& props) const = 0;
1003 
1004   // Given a supported NodeDef, returns a kernel that computes the node.
1005   virtual Status CreateKernel(
1006       FunctionLibraryRuntime* flr,
1007       const std::shared_ptr<const NodeProperties>& props,
1008       std::unique_ptr<OpKernel>* kernel) const = 0;
1009 };
1010 
1011 typedef
1012 #if !defined(IS_MOBILE_PLATFORM)
1013     absl::variant<Tensor, eager::RemoteTensorHandle*>
1014         FunctionArg;
1015 #else
1016     absl::variant<Tensor>
1017         FunctionArg;
1018 #endif
1019 
1020 // Either a local tensor or the shape of a remote tensor.
1021 typedef absl::variant<Tensor, TensorShape> FunctionRet;
1022 
1023 // Used to instantiate and run functions in a distributed system.
1024 class DistributedFunctionLibraryRuntime {
1025  public:
1026   virtual ~DistributedFunctionLibraryRuntime() {}
1027 
1028   // Instantiate a function on a remote target specified in `options.target`, by
1029   // sending the name and definition of the function to the remote worker. The
1030   // local `handle` is filled for the instantiated function data and can be used
1031   // for subsequent run function calls on the remote target.
1032   virtual void Instantiate(
1033       const std::string& function_name,
1034       const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
1035       const FunctionLibraryRuntime::InstantiateOptions& options,
1036       FunctionLibraryRuntime::LocalHandle* handle,
1037       FunctionLibraryRuntime::DoneCallback done) = 0;
1038 
1039   // Run an instantiated remote function (specified by `handle`) with a list of
1040   // input Tensors in `args` and get its output Tensors in `rets`. The input
1041   // tensor data will be sent with the function execution request, and must be
1042   // available on the current caller side.
1043   // opts.runner isn't used for execution.
1044   virtual void Run(const FunctionLibraryRuntime::Options& opts,
1045                    FunctionLibraryRuntime::LocalHandle handle,
1046                    gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
1047                    FunctionLibraryRuntime::DoneCallback done) = 0;
1048 
1049   // Run an instantiated remote function (specified by `handle`) with a list of
1050   // input Tensors or RemoteTensorHandles as `args` and get its output Tensors
1051   // or TensorShapes in `rets`. When using RemoteTensorHandles as function
1052   // inputs or TensorShapes as outputs, the corresponding tensor data will be
1053   // resolved on the remote worker, so it is not required to be locally
1054   // available on the caller side. Using RemoteTensorHandle inputs is not
1055   // supported in TensorFlow v1 runtime.
1056   virtual void Run(const FunctionLibraryRuntime::Options& opts,
1057                    FunctionLibraryRuntime::LocalHandle handle,
1058                    gtl::ArraySlice<FunctionArg> args,
1059                    std::vector<FunctionRet>* rets,
1060                    FunctionLibraryRuntime::DoneCallback done) = 0;
1061 
1062   // Clean up a previously instantiated function on remote worker.
1063   virtual void CleanUp(uint64 step_id,
1064                        FunctionLibraryRuntime::LocalHandle handle,
1065                        FunctionLibraryRuntime::DoneCallback done) = 0;
1066 
1067   // DeviceMgr with *all* available devices (i.e., local and remote).
1068   virtual DeviceMgr* remote_device_mgr() const = 0;
1069 };
1070 
1071 // Extracts the actual type from "attr_values" based on its definition
1072 // "arg_def".
1073 //
1074 // If "arg_def" is a N*T type, *is_type_list is set to false, and
1075 // *dtypes is set to be a vector of size N and each element is T.
1076 //
1077 // If "arg_def" is a list(type), *is_type_list is set to true, and
1078 // *dtypes is set to be a vector of types specified in attrs for
1079 // arg_def.
1080 //
1081 // Otherwise (arg_def is a simple type T), *is_type_list is set to
1082 // false, and *dtypes is set to a single element vector, whose only
1083 // element is T.
1084 Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
1085                   bool* is_type_list, DataTypeVector* dtypes);
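
// For example (a sketch): for an arg_def "x: N*T" with attrs {N=3, T=float},
// ArgNumType sets *is_type_list to false and *dtypes to {DT_FLOAT, DT_FLOAT,
// DT_FLOAT}; for an arg_def "x: Tlist" with attr Tlist = [int32, float], it
// sets *is_type_list to true and *dtypes to {DT_INT32, DT_FLOAT}.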
1086 
1087 // To register a gradient function for a builtin op, one should use
1088 //   REGISTER_OP_GRADIENT(<op_name>, <c++ grad factory>);
1089 //
1090 // Typically, the C++ grad factory is a plain function that can be
1091 // converted into ::tensorflow::gradient::Creator, which is
1092 //   std::function<Status(const AttrSlice&, FunctionDef*)>.
1093 //
1094 // A ::tensorflow::gradient::Creator should populate the FunctionDef* with a
1095 // definition of a function which computes the gradient for the
1096 // <op_name> when the <op_name> is instantiated with the given attrs.
1097 //
1098 // E.g.,
1099 //
1100 // Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
1101 //   bool transpose_a;
1102 //   TF_RETURN_IF_ERROR(attrs.Get("transpose_a", &transpose_a));
1103 //   bool transpose_b;
1104 //   TF_RETURN_IF_ERROR(attrs.Get("transpose_b", &transpose_b));
1105 //   DataType dtype;
1106 //   TF_RETURN_IF_ERROR(attrs.Get("dtype", &dtype));
1107 //   if (!transpose_a && !transpose_b) {
1108 //     *g = FunctionDefHelper::Define(
1109 //       "MatMulGrad",
1110 //       {"x:T ", "y:T", "dz:T"},    // Inputs to this function
1111 //       {"dx:T", "dy:T"},           // Outputs from this function
1112 //       {"T: {float, double}"},     // Attributes needed by this function
1113 //       {
1114 //         {{"x_t"}, "Transpose", {"x"}, {{"T", "$T"}}},
1115 //         {{"y_t"}, "Transpose", {"y"}, {{"T", "$T"}}},
1116 //         {{"dx"}, "MatMul", {"dz", "y_t"}, {{"T", "$T"}}},
1117 //         {{"dy"}, "MatMul", {"x_", "dz"}, {{"T", "$T"}}},
1118 //       });
1119 //   } else {
1120 //     ... ...
1121 //   }
1122 //   return Status::OK();
1123 // }
1124 //
1125 // NOTE: $T is substituted with the type variable "T" when the
1126 // gradient function for MatMul is instantiated.
1127 //
1128 // TODO(zhifengc): Better documentation somewhere.
1129 
1130 // Macros to define a gradient function factory for a primitive
1131 // operation.
1132 #define REGISTER_OP_GRADIENT(name, fn) \
1133   REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, fn)
1134 
1135 #define REGISTER_OP_NO_GRADIENT(name) \
1136   REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, nullptr)
1137 
1138 #define REGISTER_OP_GRADIENT_UNIQ_HELPER(ctr, name, fn) \
1139   REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)
1140 
1141 #define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)      \
1142   static bool unused_grad_##ctr TF_ATTRIBUTE_UNUSED = \
1143       SHOULD_REGISTER_OP_GRADIENT &&                  \
1144       ::tensorflow::gradient::RegisterOp(name, fn)
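
// For example, the MatMulGrad factory sketched above would typically be
// registered as:
//   REGISTER_OP_GRADIENT("MatMul", MatMulGrad);
// and an op can be explicitly marked as having no gradient, e.g.:
//   REGISTER_OP_NO_GRADIENT("Shape");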
1145 
1146 namespace gradient {
1147 // Register a gradient creator for the "op".
1148 typedef std::function<Status(const AttrSlice& attrs, FunctionDef*)> Creator;
1149 bool RegisterOp(const std::string& op, Creator func);
1150 
1151 // Returns OK if the gradient creator for the "op" is found (may be
1152 // nullptr if REGISTER_OP_NO_GRADIENT is used).
1153 Status GetOpGradientCreator(const std::string& op, Creator* creator);
1154 }  // end namespace gradient
1155 
1156 // Declare explicit instantiations of GetAttr
1157 #define GET_ATTR(T)                                          \
1158   extern template Status FunctionLibraryDefinition::GetAttr( \
1159       const Node&, const string&, T*) const;                 \
1160   extern template Status FunctionLibraryDefinition::GetAttr( \
1161       const NodeDef&, const string&, T*) const;
1162 GET_ATTR(string)
1163 GET_ATTR(bool)
1164 #undef GET_ATTR
1165 
1166 }  // end namespace tensorflow
1167 
1168 #endif  // TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_
1169