• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_FRAMEWORK_FUNCTION_H_
17 #define TENSORFLOW_FRAMEWORK_FUNCTION_H_
18 
19 #include <vector>
20 #include "tensorflow/core/framework/attr_value.pb.h"
21 #include "tensorflow/core/framework/attr_value_util.h"
22 #include "tensorflow/core/framework/function.pb.h"
23 #include "tensorflow/core/framework/node_def_util.h"
24 #include "tensorflow/core/framework/op.h"
25 #include "tensorflow/core/framework/selective_registration.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/lib/gtl/flatmap.h"
28 #include "tensorflow/core/lib/hash/hash.h"
29 #include "tensorflow/core/platform/env.h"
30 #include "tensorflow/core/platform/macros.h"
31 #include "tensorflow/core/platform/protobuf.h"
32 
33 namespace tensorflow {
34 
// Forward declarations. This header only refers to these types through
// pointers or references, so their full definitions are not needed here.
class CancellationManager;
class GraphDef;
class Node;
class OpKernel;
class ProcessFunctionLibraryRuntime;
class Rendezvous;
class ResourceMgr;
class ScopedStepContainer;
class StepStatsCollector;
44 
45 // FunctionDefHelper::Create is a convenient helper to construct a
46 // FunctionDef proto.
47 // E.g.,
48 //   FunctionDef my_func = FunctionDefHelper::Create(
49 //     "my_func_name",
50 //     {"x:T", "y:T" /* one string per argument */},
51 //     {"z:T" /* one string per return value */},
52 //     {"T: {float, double}" /* one string per attribute  */},
53 //     {
54 //        {{"o"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
55 //        /* one entry per function node */
56 //     },
57 //     /* Mapping between function returns and function node outputs. */
58 //     {{"z", "o:z"}});
59 //
60 // For the old Function::Node approach, use FunctionDefHelper::Define()
61 // E.g.,
62 //   FunctionDef my_func = FunctionDefHelper::Define(
63 //     "my_func_name",
64 //     {"x:T", "y:T" /* one string per argument */},
65 //     {"z:T" /* one string per return value */},
66 //     {"T: {float, double}" /* one string per attribute  */},
67 //     {
68 //        {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
69 //        /* one entry per function node */
70 //     });
71 class FunctionDefHelper {
72  public:
73   // AttrValueWrapper has copy constructors for the type T so that
74   // it's easy to construct a simple AttrValue proto.
75   //
76   // If T is a string type (const char*, string, or StringPiece), and
77   // it starts with "$", we construct a AttrValue of "placeholder".
78   //
79   // E.g.,
80   //   std::<string, AttrValueWrapper> x = {"T", "$T"}
81   // is a named attr value placeholder.
82   struct AttrValueWrapper {
83     AttrValue proto;
84 
AttrValueWrapperAttrValueWrapper85     AttrValueWrapper() {}
86 
87     template <typename T>
AttrValueWrapperAttrValueWrapper88     AttrValueWrapper(T val) {  // NOLINT(runtime/explicit)
89       SetAttrValue(val, &proto);
90     }
91 
92    private:
93     void InitFromString(StringPiece val);
94   };
95 
96   // Constructs an AttrValue.func given the "name" and "attrs".
97   static AttrValueWrapper FunctionRef(
98       const string& name,
99       gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs);
FunctionRef(const string & name)100   static AttrValueWrapper FunctionRef(const string& name) {
101     return FunctionRef(name, {});
102   }
103 
104   // Node is used to construct FunctionDef.Node using initialization
105   // lists. E.g.,
106   //  Node n = {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}};  // z = x * y
107   struct Node {
108     // When constructing a NodeDef, the first entry in ret is used as
109     // the node name, the remaining values are ignored.
110     std::vector<string> ret;
111     string op;
112     std::vector<string> arg;
113     std::vector<std::pair<string, AttrValueWrapper>> attr;
114     std::vector<string> dep;
115 
116     NodeDef ToNodeDef() const;
117   };
118 
119   // The Create() function uses the new NodeDef field.  `ret_def`
120   // holds a mapping from the function output names from `out_def` to
121   // the node outputs from `node_def`.
122   static FunctionDef Create(const string& function_name,
123                             gtl::ArraySlice<string> in_def,
124                             gtl::ArraySlice<string> out_def,
125                             gtl::ArraySlice<string> attr_def,
126                             gtl::ArraySlice<Node> node_def,
127                             gtl::ArraySlice<std::pair<string, string>> ret_def);
128 
129   // The two Define() functions use the old FunctionDef::Node field.
130   // TODO(josh11b): Get rid of these and transition to the one above.
131   static FunctionDef Define(const string& function_name,
132                             gtl::ArraySlice<string> arg_def,
133                             gtl::ArraySlice<string> ret_def,
134                             gtl::ArraySlice<string> attr_def,
135                             gtl::ArraySlice<Node> node_def);
136 
137   // Defines an anonymous function. I.e., its name is not relevant.
138   static FunctionDef Define(gtl::ArraySlice<string> arg_def,
139                             gtl::ArraySlice<string> ret_def,
140                             gtl::ArraySlice<string> attr_def,
141                             gtl::ArraySlice<Node> node_def);
142 
143   // Helpers to construct a constant scalar.
144   template <typename T>
Const(const string & name,const T & val)145   static Node Const(const string& name, const T& val) {
146     Node n = {{name}, "Const"};
147     const DataType dtype = DataTypeToEnum<T>::value;
148     n.attr.push_back({"dtype", dtype});
149     Tensor t(dtype, TensorShape({}));
150     t.scalar<T>()() = val;
151     n.attr.push_back({"value", t});
152     return n;
153   }
154 
155   template <typename T>
Const(const string & name,gtl::ArraySlice<T> vals)156   static Node Const(const string& name, gtl::ArraySlice<T> vals) {
157     Node n = {{name}, "Const"};
158     const DataType dtype = DataTypeToEnum<T>::value;
159     n.attr.push_back({"dtype", dtype});
160     int64 num = vals.size();
161     Tensor t(dtype, TensorShape({num}));
162     for (size_t i = 0; i < vals.size(); ++i) {
163       t.flat<T>()(i) = vals[i];
164     }
165     n.attr.push_back({"value", t});
166     return n;
167   }
168 };
169 
170 template <>
AttrValueWrapper(const char * val)171 inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(const char* val) {
172   InitFromString(val);
173 }
174 
175 template <>
AttrValueWrapper(const string & val)176 inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(
177     const string& val) {
178   InitFromString(val);
179 }
180 
181 template <>
AttrValueWrapper(StringPiece val)182 inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) {
183   InitFromString(val);
184 }
185 
186 // Instantiate a function.
187 //
188 // "fdef" encodes a TF function with some attrs in fdef.signature.attr
189 // containing placeholders.  InstantiateFunction binds these
190 // placeholders and produces an instantiated function encoded in
191 // "result.gdef". The value to substitute a placeholder is given by
192 // "attr_values", which is a map from a placeholder name to an attr
193 // value.
194 //
195 // InstantiateFunction calls "get_function" to find signatures of other
196 // functions and primitive ops.
197 
198 // GetFunctionSignature(func name, opdef) returns OK if the func name is found
199 // and opdef is filled with a pointer to the corresponding signature
200 // (a OpDef proto). Otherwise, returns an error.
201 typedef std::function<Status(const string&, const OpDef**)>
202     GetFunctionSignature;
203 
204 struct InstantiationResult {
205   DataTypeVector arg_types;
206   DataTypeVector ret_types;
207   std::vector<NodeDef> nodes;
208 };
209 Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
210                            GetFunctionSignature get_function,
211                            InstantiationResult* result);
212 
213 // Returns a debug string for a function definition.
214 //
215 // The returned text is multiple-line. It is intended to be
216 // human-readable rather than being friendly to parsers. It is _NOT_
217 // intended to be the canonical string representation of "func_def".
218 // Particularly, it may not include all information presented in
219 // "func_def" (e.g., comments, description of the function arguments,
220 // etc.)
221 string DebugString(const FunctionDef& func_def);
222 string DebugString(const GraphDef& instantiated_func_def);
223 string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes);
224 
225 // Returns a debug string for a top level graph (the main program and
226 // its supporting functions defined in its library).
227 string DebugStringWhole(const GraphDef& gdef);
228 
229 // Returns true if f1 == f2. Compares all fields, including descriptions. Order
230 // of NodeDefs doesn't matter.
231 bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2);
232 
233 // Return a hash of `fdef` that is consistent with FunctionDefsEqual method.
234 // In other words, if two fdefs compare equal, their hash values will be the
235 // same.
236 uint64 FunctionDefHash(const FunctionDef& fdef);
237 
238 class CallFrameInterface {
239  public:
~CallFrameInterface()240   virtual ~CallFrameInterface() {}
241 
242   virtual size_t num_args() const = 0;
243   virtual size_t num_retvals() const = 0;
244 
245   virtual Status GetArg(int index, Tensor* val) const = 0;
246   virtual Status SetRetval(int index, const Tensor& val) = 0;
247 };
248 
249 // Represents a function call frame. I.e., the data structure used to
250 // pass arguments to a function and retrieve its results.
251 //
252 // Runtime must arrange accesses to one FunctionCallFrame s.t.
253 //   1. SetArgs() happens before any GetArg();
254 //   2. GetRetvals happens after all SetRetval();
255 class FunctionCallFrame : public CallFrameInterface {
256  public:
257   FunctionCallFrame(DataTypeSlice arg_types, DataTypeSlice ret_types);
258   ~FunctionCallFrame();
259 
260   // Caller methods.
261   Status SetArgs(gtl::ArraySlice<Tensor> args);
262   Status GetRetvals(std::vector<Tensor>* rets) const;
263   Status ConsumeRetvals(std::vector<Tensor>* rets);
264 
num_args()265   size_t num_args() const override { return arg_types_.size(); }
num_retvals()266   size_t num_retvals() const override { return ret_types_.size(); }
267 
268   // Callee methods.
269   Status GetArg(int index, Tensor* val) const override;
270   Status SetRetval(int index, const Tensor& val) override;
271 
272  private:
273   DataTypeVector arg_types_;
274   DataTypeVector ret_types_;
275   gtl::InlinedVector<Tensor, 4> args_;
276   struct Retval {
277     bool has_val = false;
278     Tensor val;
279   };
280   gtl::InlinedVector<Retval, 4> rets_;
281 
282   TF_DISALLOW_COPY_AND_ASSIGN(FunctionCallFrame);
283 };
284 
285 // Helper to maintain a map between function names in a given
286 // FunctionDefLibrary and function definitions.
287 class FunctionLibraryDefinition : public OpRegistryInterface {
288  public:
289   explicit FunctionLibraryDefinition(const FunctionLibraryDefinition& lib_def);
290   FunctionLibraryDefinition(const OpRegistryInterface* default_registry,
291                             const FunctionDefLibrary& lib_def);
292   ~FunctionLibraryDefinition() override;
293 
294   FunctionLibraryDefinition& operator=(const FunctionLibraryDefinition&) =
295       delete;
296 
297   // Returns nullptr if "func" is not defined in "lib_def". Otherwise,
298   // returns its definition proto.
299   const FunctionDef* Find(const string& func) const;
300 
301   // Adds function definition 'fdef' to this function library.
302   // Returns status 'ok' on success, or error otherwise. This is a no-op if
303   // 'fdef' already exists in this function library.
304   // If 'fdef' is successfully added to the library, it will be accessible
305   // from 'LookUp' and included in the proto returned by 'ToProto'.
306   // This operation is atomic.
307   Status AddFunctionDef(const FunctionDef& fdef);
308 
309   // Adds gradient definition 'grad' to this function library.
310   // This is a no-op if 'grad' already exists in this function library.
311   // If 'grad' is successfully added, it will be accessible via 'FindGradient'
312   // and included in the proto returned by 'ToProto'.
313   // This operation is atomic.
314   Status AddGradientDef(const GradientDef& grad);
315 
316   // Remove function `func` from the library. Returns non-OK Status unless
317   // `func` is in the library.
318   Status RemoveFunction(const string& func);
319 
320   // Remove gradient of function `func` from the library. Returns non-OK Status
321   // unless `func` has a gradient.
322   Status RemoveGradient(const string& func);
323 
324   // Adds the functions and gradients in 'other' to this function library.
325   // Duplicate functions and gradients are ignored.
326   // This operation is atomic.
327   Status AddLibrary(const FunctionLibraryDefinition& other);
328 
329   // Adds the functions and gradients in 'lib_def' to this function library.
330   // Duplicate functions and gradients are ignored.
331   // This operation is atomic.
332   Status AddLibrary(const FunctionDefLibrary& lib_def);
333 
334   // If the gradient function for 'func' is specified explicitly in
335   // the library, returns the gradient function name.  Otherwise,
336   // returns an empty string.
337   string FindGradient(const string& func) const;
338 
339   // OpRegistryInterface method. Useful for constructing a Graph.
340   //
341   // If "op" is defined in the library, returns its signature.
342   // Otherwise, assume "op" is a primitive op and returns its op
343   // signature and shape inference function.
344   Status LookUp(const string& op_type_name,
345                 const OpRegistrationData** op_reg_data) const override;
346 
347   static constexpr const char* const kGradientOp = "SymbolicGradient";
348   static constexpr const char* const kFuncAttr = "f";
349 
350   // Given a node def 'ndef', inspects attributes of the callee
351   // function to derive the attribute 'value' for 'attr'. Returns OK
352   // iff the attribute is given by the function's definition.
353   // TODO(irving): Remove; keep only the const Node& version.
354   template <typename T>
355   Status GetAttr(const NodeDef& ndef, const string& attr, T* value) const;
356 
357   // Given a node, inspects attributes of the callee function to derive the
358   // attribute 'value' for 'attr'. Returns OK iff the attribute is given by the
359   // function's definition.
360   template <typename T>
361   Status GetAttr(const Node& node, const string& attr, T* value) const;
362 
363   // Returns a proto representation of the state of this function library.
364   FunctionDefLibrary ToProto() const;
365 
num_functions()366   size_t num_functions() const { return function_defs_.size(); }
367 
default_registry()368   const OpRegistryInterface* default_registry() const {
369     return default_registry_;
370   }
371 
372  private:
373   // Shape inference for functions is handled separately by ShapeRefiner.
374 
375   struct FunctionDefAndOpRegistration {
376     FunctionDefAndOpRegistration(const FunctionDef& fdef_in);
377 
378     FunctionDef fdef;
379     OpRegistrationData op_registration_data;
380   };
381 
382   // Same as AddFunctionDef/AddGradientDef except these methods set
383   // `added` to true if the `fdef`/`grad` were actually added to this.
384   Status AddFunctionDefHelper(const FunctionDef& fdef, bool* added);
385   Status AddGradientDefHelper(const GradientDef& grad, bool* added);
386 
387   const OpRegistryInterface* const default_registry_;
388   gtl::FlatMap<string, std::unique_ptr<FunctionDefAndOpRegistration>>
389       function_defs_;
390   gtl::FlatMap<string, string> func_grad_;
391 
392   // Helper function for GetAttr. Returns the FunctionDef* to get the
393   // attr from.
394   const FunctionDef* GetAttrImpl(const NodeDef& ndef) const;
395 
396   // Remove all functions in `funcs` and all gradients of
397   // functions in `funcs_with_grads` from this library.
398   void Remove(const std::vector<string>& funcs,
399               const std::vector<string>& funcs_with_grads);
400 };
401 
// Forward declarations of types owned by common_runtime.
class Device;         // Defined in common_runtime/device.h.
struct FunctionBody;  // Defined in common_runtime/function.h.
407 
408 class FunctionLibraryRuntime {
409  public:
~FunctionLibraryRuntime()410   virtual ~FunctionLibraryRuntime() {}
411 
412   // Instantiate a function with the given "attrs".
413   //
414   // Returns OK and fills in "handle" if the instantiation succeeds.
415   // Otherwise returns an error and "handle" is undefined.
416   struct InstantiateOptions {
417     // The canonical device name of the device on which the function
418     // should be instantiated. If empty, the function will be
419     // instantiated on the local device.
420     string target;
421 
422     // This interface is EXPERIMENTAL and subject to change.
423     //
424     // If non-null, the runtime will use `overlay_lib` to resolve
425     // function(s) named in `function_name` and `attrs`. Otherwise,
426     // the runtime will use its internal library.
427     // NOTE(mrry): If provided, all functions defined in `overlay_lib`
428     // must be self-contained, and cannot refer to functions defined
429     // in other libraries.
430     // TODO(mrry): Provide a mechanism for sharing core functions
431     // between a set of libraries (e.g. by allowing a
432     // `FunctionLibraryDefinition` to store an `outer_scope` pointer
433     // and implementing name resolution across libraries).
434     const FunctionLibraryDefinition* overlay_lib = nullptr;
435 
436     // This interface is EXPERIMENTAL and subject to change.
437     //
438     // If non-empty, the runtime will use `state_handle` to identify
439     // cached state related the instantiated function. Two functions
440     // of the same name and attrs, instantiated with the same
441     // `state_handle` will have the same handle and share the same
442     // state (in stateful kernels); and two functions with different
443     // values for `state_handle` will have independent state.
444     string state_handle;
445   };
446   typedef uint64 Handle;
447   virtual Status Instantiate(const string& function_name, AttrSlice attrs,
448                              const InstantiateOptions& options,
449                              Handle* handle) = 0;
Instantiate(const string & function_name,AttrSlice attrs,Handle * handle)450   Status Instantiate(const string& function_name, AttrSlice attrs,
451                      Handle* handle) {
452     return Instantiate(function_name, attrs, {}, handle);
453   }
454 
455   // Releases state associated with the handle.
456   virtual Status ReleaseHandle(Handle handle) = 0;
457 
458   // Returns the function body for the instantiated function given its
459   // handle 'h'. Returns nullptr if "h" is not found.
460   //
461   // *this keeps the ownership of the returned object, which remains alive
462   // as long as *this.
463   virtual const FunctionBody* GetFunctionBody(Handle h) = 0;
464 
465   // Asynchronously invokes the instantiated function identified by
466   // "handle".
467   //
468   // If function execution succeeds, "done" is called with OK and
469   // "*rets" is filled with the function's return values. Otheriwse,
470   // "done" is called with an error status.
471   //
472   // Does not take ownership of "rets".
473   // In the cross-process scenario, runner isn't used for making the Async
474   // RPC calls.
475   struct Options {
476     // The id of the step that is calling this function.
477     int64 step_id = 0;
478     Rendezvous* rendezvous = nullptr;
479     CancellationManager* cancellation_manager = nullptr;
480     ScopedStepContainer* step_container = nullptr;
481     StepStatsCollector* stats_collector = nullptr;
482 
483     std::function<void(std::function<void()>)>* runner = nullptr;
484 
485     // Parameters for remote function execution.
486     bool remote_execution = false;
487     string source_device = "";  // Fully specified device name.
488 
489     // Allocator attributes specifying where the args are / rets should be put.
490     // These should either be {} or match the length of args / retvals. If {},
491     // the default allocator attributes will be assumed for all args / retvals.
492     std::vector<AllocatorAttributes> args_alloc_attrs;
493     std::vector<AllocatorAttributes> rets_alloc_attrs;
494 
495     // If true, we create a new IntraProcessRendezvous, else use the existing
496     // one.
497     bool create_rendezvous = false;
498   };
499   typedef std::function<void(const Status&)> DoneCallback;
500   virtual void Run(const Options& opts, Handle handle,
501                    gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
502                    DoneCallback done) = 0;
503   virtual void Run(const Options& opts, Handle handle,
504                    CallFrameInterface* call_frame, DoneCallback done) = 0;
505 
506   // Creates a "kernel" for the given node def "ndef".
507   //
508   // If succeeds, returns OK and the caller takes the ownership of the
509   // returned "*kernel". Otherwise, returns an error.
510   virtual Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) = 0;
511 
512   // Returns true iff the function named `function_name` is stateful.
513   // NOTE(mrry): This method assumes that the runtime is associated with a
514   // default function library, and looks up `function_name` in that library.
515   // It does not support overlay libraries.
516   virtual bool IsStateful(const string& function_name) = 0;
517 
518   // Returns the device on which the function executes.
519   virtual Device* device() = 0;
520 
521   // Returns the function library definition that backs this runtime.
522   // NOTE(mrry): The returned library definition is the default function library
523   // for this runtime. The runtime may instantiate functions from separate
524   // overlay libraries, which are not returned by this function.
525   virtual const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
526       const = 0;
527 
528   // Returns the environment on which the function executes.
529   virtual Env* env() = 0;
530 
531   // Returns a debug string showing the definition of the function of
532   // 'handle'.
533   virtual string DebugString(Handle handle) = 0;
534 
535   // Returns the graph version number.
536   virtual int graph_def_version() = 0;
537 
538   typedef uint64 LocalHandle;
539 
540   virtual Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
541                        std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
542                        FunctionLibraryRuntime** out_flr) = 0;
543 };
544 
// Returns a canonical string describing the instantiation of the function
// `funcname` with attributes `attrs` and instantiation `options`.
//
// The string is only guaranteed to be stable within one address space and may
// change as the implementation evolves, so it must not be persisted or
// compared across address spaces.
string Canonicalize(const string& funcname, AttrSlice attrs,
                    const FunctionLibraryRuntime::InstantiateOptions& options);

// Convenience overload: canonicalizes using default InstantiateOptions.
inline string Canonicalize(const string& funcname, AttrSlice attrs) {
  return Canonicalize(funcname, attrs,
                      FunctionLibraryRuntime::InstantiateOptions());
}
557 
558 const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
559 const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1;
560 typedef std::function<Status(FunctionLibraryRuntime*, const NodeDef&,
561                              std::unique_ptr<OpKernel>*)>
562     CustomKernelCreator;
563 
564 // Used to instantiate and run functions in a distributed system.
565 class DistributedFunctionLibraryRuntime {
566  public:
~DistributedFunctionLibraryRuntime()567   virtual ~DistributedFunctionLibraryRuntime() {}
568 
569   // The _target attr in attrs determines where the function is instantiated.
570   virtual Status Instantiate(
571       const string& function_name, const FunctionLibraryDefinition& lib_def,
572       AttrSlice attrs,
573       const FunctionLibraryRuntime::InstantiateOptions& options,
574       FunctionLibraryRuntime::LocalHandle* handle) = 0;
575 
576   // opts.runner isn't used for execution.
577   virtual void Run(const FunctionLibraryRuntime::Options& opts,
578                    FunctionLibraryRuntime::LocalHandle handle,
579                    gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
580                    FunctionLibraryRuntime::DoneCallback done) = 0;
581 };
582 
583 // Extracts the actual type from "attr_values" based on its definition
584 // "arg_def".
585 //
586 // If "arg_def" is a N*T type, *is_type_list is set to false, and
587 // *dtypes is set to be a vector of size N and each element is T.
588 //
589 // If "arg_def" is a list(type), *is_type_list is set to true, and
590 // *dtypes is set to be a vector of types specified in attrs for
591 // arg_def.
592 //
593 // Otherwise (arg_def is a simple type T), *is_type_list is set to
594 // false, and *dtypes is set to a single element vector, whose only
595 // element is T.
596 Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
597                   bool* is_type_list, DataTypeVector* dtypes);
598 
// To register a gradient function for a builtin op, one should use
//   REGISTER_OP_GRADIENT(<op_name>, <c++ grad factory>);
//
// Typically, the c++ grad factory is a plain function that can be
// converted into ::tensorflow::gradient::Creator, which is
//   std::function<Status(const AttrSlice&, FunctionDef*)>.
//
// A ::tensorflow::gradient::Creator should populate the FunctionDef* with a
// definition of a function which computes the gradient for the
// <op_name> when the <op_name> is instantiated with the given attrs.
//
// E.g.,
//
// Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
//   bool transpose_a;
//   TF_RETURN_IF_ERROR(attrs.Get("transpose_a", &transpose_a));
//   bool transpose_b;
//   TF_RETURN_IF_ERROR(attrs.Get("transpose_b", &transpose_b));
//   DataType dtype;
//   TF_RETURN_IF_ERROR(attrs.Get("dtype", &dtype));
//   if (!transpose_a && !transpose_b) {
//     *g = FunctionDefHelper::Define(
//       "MatMulGrad",
//       {"x:T", "y:T", "dz:T"},     // Inputs to this function
//       {"dx:T", "dy:T"},           // Outputs from this function
//       {"T: {float, double}"},     // Attributes needed by this function
//       {
//         {{"x_t"}, "Transpose", {"x"}, {{"T", "$T"}}},
//         {{"y_t"}, "Transpose", {"y"}, {{"T", "$T"}}},
//         {{"dx"}, "MatMul", {"dz", "y_t"}, {{"T", "$T"}}},
//         {{"dy"}, "MatMul", {"x_t", "dz"}, {{"T", "$T"}}},
//       });
//   } else {
//     ... ...
//   }
//   return Status::OK();
// }
//
// NOTE: $T is substituted with the value of the attr "T" when the
// gradient function MatMulGrad is instantiated.
//
// TODO(zhifengc): Better documentation somewhere.

// Macros to define a gradient function factory for a primitive
// operation.
//
// REGISTER_OP_GRADIENT(name, fn) registers `fn` as the gradient creator of op
// `name`; REGISTER_OP_NO_GRADIENT(name) registers a nullptr creator. Both go
// through the _UNIQ_HELPER indirection so that __COUNTER__ is expanded before
// it is token-pasted into a unique variable name.
#define REGISTER_OP_GRADIENT(name, fn) \
  REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, fn)

#define REGISTER_OP_NO_GRADIENT(name) \
  REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, nullptr)

#define REGISTER_OP_GRADIENT_UNIQ_HELPER(ctr, name, fn) \
  REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)

// The static bool exists only so that its initializer calls RegisterOp()
// (skipped via && when SHOULD_REGISTER_OP_GRADIENT is false). It is never
// read, hence TF_ATTRIBUTE_UNUSED to silence unused-variable warnings.
#define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)      \
  static bool unused_grad_##ctr TF_ATTRIBUTE_UNUSED = \
      SHOULD_REGISTER_OP_GRADIENT &&                  \
      ::tensorflow::gradient::RegisterOp(name, fn)
656 
657 namespace gradient {
658 // Register a gradient creator for the "op".
659 typedef std::function<Status(const AttrSlice& attrs, FunctionDef*)> Creator;
660 bool RegisterOp(const string& op, Creator func);
661 
662 // Returns OK the gradient creator for the "op" is found (may be
663 // nullptr if REGISTER_OP_NO_GRADIENT is used.
664 Status GetOpGradientCreator(const string& op, Creator* creator);
665 };  // namespace gradient
666 
667 // Declare explicit instantiations of GetAttr
668 #define GET_ATTR(T)                                          \
669   extern template Status FunctionLibraryDefinition::GetAttr( \
670       const Node&, const string&, T*) const;                 \
671   extern template Status FunctionLibraryDefinition::GetAttr( \
672       const NodeDef&, const string&, T*) const;
673 GET_ATTR(string)
674 GET_ATTR(bool)
675 #undef GET_ATTR
676 
677 }  // end namespace tensorflow
678 
679 #endif  // TENSORFLOW_FRAMEWORK_FUNCTION_H_
680