1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_FRAMEWORK_FUNCTION_H_
17 #define TENSORFLOW_FRAMEWORK_FUNCTION_H_
18
19 #include <vector>
20 #include "tensorflow/core/framework/attr_value.pb.h"
21 #include "tensorflow/core/framework/attr_value_util.h"
22 #include "tensorflow/core/framework/function.pb.h"
23 #include "tensorflow/core/framework/node_def_util.h"
24 #include "tensorflow/core/framework/op.h"
25 #include "tensorflow/core/framework/selective_registration.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/lib/gtl/flatmap.h"
28 #include "tensorflow/core/lib/hash/hash.h"
29 #include "tensorflow/core/platform/env.h"
30 #include "tensorflow/core/platform/macros.h"
31 #include "tensorflow/core/platform/protobuf.h"
32
33 namespace tensorflow {
34
// Types referenced only by pointer or reference in this header; their full
// definitions live elsewhere.
class CancellationManager;
class GraphDef;
class Node;
class OpKernel;
class ProcessFunctionLibraryRuntime;
class Rendezvous;
class ResourceMgr;
class ScopedStepContainer;
class StepStatsCollector;
44
45 // FunctionDefHelper::Create is a convenient helper to construct a
46 // FunctionDef proto.
47 // E.g.,
48 // FunctionDef my_func = FunctionDefHelper::Create(
49 // "my_func_name",
50 // {"x:T", "y:T" /* one string per argument */},
51 // {"z:T" /* one string per return value */},
52 // {"T: {float, double}" /* one string per attribute */},
53 // {
54 // {{"o"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
55 // /* one entry per function node */
56 // },
57 // /* Mapping between function returns and function node outputs. */
58 // {{"z", "o:z"}});
59 //
60 // For the old Function::Node approach, use FunctionDefHelper::Define()
61 // E.g.,
62 // FunctionDef my_func = FunctionDefHelper::Define(
63 // "my_func_name",
64 // {"x:T", "y:T" /* one string per argument */},
65 // {"z:T" /* one string per return value */},
66 // {"T: {float, double}" /* one string per attribute */},
67 // {
68 // {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
69 // /* one entry per function node */
70 // });
71 class FunctionDefHelper {
72 public:
73 // AttrValueWrapper has copy constructors for the type T so that
74 // it's easy to construct a simple AttrValue proto.
75 //
76 // If T is a string type (const char*, string, or StringPiece), and
77 // it starts with "$", we construct a AttrValue of "placeholder".
78 //
79 // E.g.,
80 // std::<string, AttrValueWrapper> x = {"T", "$T"}
81 // is a named attr value placeholder.
82 struct AttrValueWrapper {
83 AttrValue proto;
84
AttrValueWrapperAttrValueWrapper85 AttrValueWrapper() {}
86
87 template <typename T>
AttrValueWrapperAttrValueWrapper88 AttrValueWrapper(T val) { // NOLINT(runtime/explicit)
89 SetAttrValue(val, &proto);
90 }
91
92 private:
93 void InitFromString(StringPiece val);
94 };
95
96 // Constructs an AttrValue.func given the "name" and "attrs".
97 static AttrValueWrapper FunctionRef(
98 const string& name,
99 gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs);
FunctionRef(const string & name)100 static AttrValueWrapper FunctionRef(const string& name) {
101 return FunctionRef(name, {});
102 }
103
104 // Node is used to construct FunctionDef.Node using initialization
105 // lists. E.g.,
106 // Node n = {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}; // z = x * y
107 struct Node {
108 // When constructing a NodeDef, the first entry in ret is used as
109 // the node name, the remaining values are ignored.
110 std::vector<string> ret;
111 string op;
112 std::vector<string> arg;
113 std::vector<std::pair<string, AttrValueWrapper>> attr;
114 std::vector<string> dep;
115
116 NodeDef ToNodeDef() const;
117 };
118
119 // The Create() function uses the new NodeDef field. `ret_def`
120 // holds a mapping from the function output names from `out_def` to
121 // the node outputs from `node_def`.
122 static FunctionDef Create(const string& function_name,
123 gtl::ArraySlice<string> in_def,
124 gtl::ArraySlice<string> out_def,
125 gtl::ArraySlice<string> attr_def,
126 gtl::ArraySlice<Node> node_def,
127 gtl::ArraySlice<std::pair<string, string>> ret_def);
128
129 // The two Define() functions use the old FunctionDef::Node field.
130 // TODO(josh11b): Get rid of these and transition to the one above.
131 static FunctionDef Define(const string& function_name,
132 gtl::ArraySlice<string> arg_def,
133 gtl::ArraySlice<string> ret_def,
134 gtl::ArraySlice<string> attr_def,
135 gtl::ArraySlice<Node> node_def);
136
137 // Defines an anonymous function. I.e., its name is not relevant.
138 static FunctionDef Define(gtl::ArraySlice<string> arg_def,
139 gtl::ArraySlice<string> ret_def,
140 gtl::ArraySlice<string> attr_def,
141 gtl::ArraySlice<Node> node_def);
142
143 // Helpers to construct a constant scalar.
144 template <typename T>
Const(const string & name,const T & val)145 static Node Const(const string& name, const T& val) {
146 Node n = {{name}, "Const"};
147 const DataType dtype = DataTypeToEnum<T>::value;
148 n.attr.push_back({"dtype", dtype});
149 Tensor t(dtype, TensorShape({}));
150 t.scalar<T>()() = val;
151 n.attr.push_back({"value", t});
152 return n;
153 }
154
155 template <typename T>
Const(const string & name,gtl::ArraySlice<T> vals)156 static Node Const(const string& name, gtl::ArraySlice<T> vals) {
157 Node n = {{name}, "Const"};
158 const DataType dtype = DataTypeToEnum<T>::value;
159 n.attr.push_back({"dtype", dtype});
160 int64 num = vals.size();
161 Tensor t(dtype, TensorShape({num}));
162 for (size_t i = 0; i < vals.size(); ++i) {
163 t.flat<T>()(i) = vals[i];
164 }
165 n.attr.push_back({"value", t});
166 return n;
167 }
168 };
169
170 template <>
AttrValueWrapper(const char * val)171 inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(const char* val) {
172 InitFromString(val);
173 }
174
175 template <>
AttrValueWrapper(const string & val)176 inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(
177 const string& val) {
178 InitFromString(val);
179 }
180
181 template <>
AttrValueWrapper(StringPiece val)182 inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) {
183 InitFromString(val);
184 }
185
186 // Instantiate a function.
187 //
188 // "fdef" encodes a TF function with some attrs in fdef.signature.attr
189 // containing placeholders. InstantiateFunction binds these
190 // placeholders and produces an instantiated function encoded in
191 // "result.gdef". The value to substitute a placeholder is given by
192 // "attr_values", which is a map from a placeholder name to an attr
193 // value.
194 //
195 // InstantiateFunction calls "get_function" to find signatures of other
196 // functions and primitive ops.
197
198 // GetFunctionSignature(func name, opdef) returns OK if the func name is found
199 // and opdef is filled with a pointer to the corresponding signature
200 // (a OpDef proto). Otherwise, returns an error.
201 typedef std::function<Status(const string&, const OpDef**)>
202 GetFunctionSignature;
203
204 struct InstantiationResult {
205 DataTypeVector arg_types;
206 DataTypeVector ret_types;
207 std::vector<NodeDef> nodes;
208 };
209 Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
210 GetFunctionSignature get_function,
211 InstantiationResult* result);
212
213 // Returns a debug string for a function definition.
214 //
215 // The returned text is multiple-line. It is intended to be
216 // human-readable rather than being friendly to parsers. It is _NOT_
217 // intended to be the canonical string representation of "func_def".
218 // Particularly, it may not include all information presented in
219 // "func_def" (e.g., comments, description of the function arguments,
220 // etc.)
221 string DebugString(const FunctionDef& func_def);
222 string DebugString(const GraphDef& instantiated_func_def);
223 string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes);
224
225 // Returns a debug string for a top level graph (the main program and
226 // its supporting functions defined in its library).
227 string DebugStringWhole(const GraphDef& gdef);
228
229 // Returns true if f1 == f2. Compares all fields, including descriptions. Order
230 // of NodeDefs doesn't matter.
231 bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2);
232
233 // Return a hash of `fdef` that is consistent with FunctionDefsEqual method.
234 // In other words, if two fdefs compare equal, their hash values will be the
235 // same.
236 uint64 FunctionDefHash(const FunctionDef& fdef);
237
238 class CallFrameInterface {
239 public:
~CallFrameInterface()240 virtual ~CallFrameInterface() {}
241
242 virtual size_t num_args() const = 0;
243 virtual size_t num_retvals() const = 0;
244
245 virtual Status GetArg(int index, Tensor* val) const = 0;
246 virtual Status SetRetval(int index, const Tensor& val) = 0;
247 };
248
249 // Represents a function call frame. I.e., the data structure used to
250 // pass arguments to a function and retrieve its results.
251 //
252 // Runtime must arrange accesses to one FunctionCallFrame s.t.
253 // 1. SetArgs() happens before any GetArg();
254 // 2. GetRetvals happens after all SetRetval();
255 class FunctionCallFrame : public CallFrameInterface {
256 public:
257 FunctionCallFrame(DataTypeSlice arg_types, DataTypeSlice ret_types);
258 ~FunctionCallFrame();
259
260 // Caller methods.
261 Status SetArgs(gtl::ArraySlice<Tensor> args);
262 Status GetRetvals(std::vector<Tensor>* rets) const;
263 Status ConsumeRetvals(std::vector<Tensor>* rets);
264
num_args()265 size_t num_args() const override { return arg_types_.size(); }
num_retvals()266 size_t num_retvals() const override { return ret_types_.size(); }
267
268 // Callee methods.
269 Status GetArg(int index, Tensor* val) const override;
270 Status SetRetval(int index, const Tensor& val) override;
271
272 private:
273 DataTypeVector arg_types_;
274 DataTypeVector ret_types_;
275 gtl::InlinedVector<Tensor, 4> args_;
276 struct Retval {
277 bool has_val = false;
278 Tensor val;
279 };
280 gtl::InlinedVector<Retval, 4> rets_;
281
282 TF_DISALLOW_COPY_AND_ASSIGN(FunctionCallFrame);
283 };
284
285 // Helper to maintain a map between function names in a given
286 // FunctionDefLibrary and function definitions.
287 class FunctionLibraryDefinition : public OpRegistryInterface {
288 public:
289 explicit FunctionLibraryDefinition(const FunctionLibraryDefinition& lib_def);
290 FunctionLibraryDefinition(const OpRegistryInterface* default_registry,
291 const FunctionDefLibrary& lib_def);
292 ~FunctionLibraryDefinition() override;
293
294 FunctionLibraryDefinition& operator=(const FunctionLibraryDefinition&) =
295 delete;
296
297 // Returns nullptr if "func" is not defined in "lib_def". Otherwise,
298 // returns its definition proto.
299 const FunctionDef* Find(const string& func) const;
300
301 // Adds function definition 'fdef' to this function library.
302 // Returns status 'ok' on success, or error otherwise. This is a no-op if
303 // 'fdef' already exists in this function library.
304 // If 'fdef' is successfully added to the library, it will be accessible
305 // from 'LookUp' and included in the proto returned by 'ToProto'.
306 // This operation is atomic.
307 Status AddFunctionDef(const FunctionDef& fdef);
308
309 // Adds gradient definition 'grad' to this function library.
310 // This is a no-op if 'grad' already exists in this function library.
311 // If 'grad' is successfully added, it will be accessible via 'FindGradient'
312 // and included in the proto returned by 'ToProto'.
313 // This operation is atomic.
314 Status AddGradientDef(const GradientDef& grad);
315
316 // Remove function `func` from the library. Returns non-OK Status unless
317 // `func` is in the library.
318 Status RemoveFunction(const string& func);
319
320 // Remove gradient of function `func` from the library. Returns non-OK Status
321 // unless `func` has a gradient.
322 Status RemoveGradient(const string& func);
323
324 // Adds the functions and gradients in 'other' to this function library.
325 // Duplicate functions and gradients are ignored.
326 // This operation is atomic.
327 Status AddLibrary(const FunctionLibraryDefinition& other);
328
329 // Adds the functions and gradients in 'lib_def' to this function library.
330 // Duplicate functions and gradients are ignored.
331 // This operation is atomic.
332 Status AddLibrary(const FunctionDefLibrary& lib_def);
333
334 // If the gradient function for 'func' is specified explicitly in
335 // the library, returns the gradient function name. Otherwise,
336 // returns an empty string.
337 string FindGradient(const string& func) const;
338
339 // OpRegistryInterface method. Useful for constructing a Graph.
340 //
341 // If "op" is defined in the library, returns its signature.
342 // Otherwise, assume "op" is a primitive op and returns its op
343 // signature and shape inference function.
344 Status LookUp(const string& op_type_name,
345 const OpRegistrationData** op_reg_data) const override;
346
347 static constexpr const char* const kGradientOp = "SymbolicGradient";
348 static constexpr const char* const kFuncAttr = "f";
349
350 // Given a node def 'ndef', inspects attributes of the callee
351 // function to derive the attribute 'value' for 'attr'. Returns OK
352 // iff the attribute is given by the function's definition.
353 // TODO(irving): Remove; keep only the const Node& version.
354 template <typename T>
355 Status GetAttr(const NodeDef& ndef, const string& attr, T* value) const;
356
357 // Given a node, inspects attributes of the callee function to derive the
358 // attribute 'value' for 'attr'. Returns OK iff the attribute is given by the
359 // function's definition.
360 template <typename T>
361 Status GetAttr(const Node& node, const string& attr, T* value) const;
362
363 // Returns a proto representation of the state of this function library.
364 FunctionDefLibrary ToProto() const;
365
num_functions()366 size_t num_functions() const { return function_defs_.size(); }
367
default_registry()368 const OpRegistryInterface* default_registry() const {
369 return default_registry_;
370 }
371
372 private:
373 // Shape inference for functions is handled separately by ShapeRefiner.
374
375 struct FunctionDefAndOpRegistration {
376 FunctionDefAndOpRegistration(const FunctionDef& fdef_in);
377
378 FunctionDef fdef;
379 OpRegistrationData op_registration_data;
380 };
381
382 // Same as AddFunctionDef/AddGradientDef except these methods set
383 // `added` to true if the `fdef`/`grad` were actually added to this.
384 Status AddFunctionDefHelper(const FunctionDef& fdef, bool* added);
385 Status AddGradientDefHelper(const GradientDef& grad, bool* added);
386
387 const OpRegistryInterface* const default_registry_;
388 gtl::FlatMap<string, std::unique_ptr<FunctionDefAndOpRegistration>>
389 function_defs_;
390 gtl::FlatMap<string, string> func_grad_;
391
392 // Helper function for GetAttr. Returns the FunctionDef* to get the
393 // attr from.
394 const FunctionDef* GetAttrImpl(const NodeDef& ndef) const;
395
396 // Remove all functions in `funcs` and all gradients of
397 // functions in `funcs_with_grads` from this library.
398 void Remove(const std::vector<string>& funcs,
399 const std::vector<string>& funcs_with_grads);
400 };
401
// Forward declarations of types used only through pointers below.
struct FunctionBody;  // Defined in common_runtime/function.h.
class Device;         // Defined in common_runtime/device.h.
407
408 class FunctionLibraryRuntime {
409 public:
~FunctionLibraryRuntime()410 virtual ~FunctionLibraryRuntime() {}
411
412 // Instantiate a function with the given "attrs".
413 //
414 // Returns OK and fills in "handle" if the instantiation succeeds.
415 // Otherwise returns an error and "handle" is undefined.
416 struct InstantiateOptions {
417 // The canonical device name of the device on which the function
418 // should be instantiated. If empty, the function will be
419 // instantiated on the local device.
420 string target;
421
422 // This interface is EXPERIMENTAL and subject to change.
423 //
424 // If non-null, the runtime will use `overlay_lib` to resolve
425 // function(s) named in `function_name` and `attrs`. Otherwise,
426 // the runtime will use its internal library.
427 // NOTE(mrry): If provided, all functions defined in `overlay_lib`
428 // must be self-contained, and cannot refer to functions defined
429 // in other libraries.
430 // TODO(mrry): Provide a mechanism for sharing core functions
431 // between a set of libraries (e.g. by allowing a
432 // `FunctionLibraryDefinition` to store an `outer_scope` pointer
433 // and implementing name resolution across libraries).
434 const FunctionLibraryDefinition* overlay_lib = nullptr;
435
436 // This interface is EXPERIMENTAL and subject to change.
437 //
438 // If non-empty, the runtime will use `state_handle` to identify
439 // cached state related the instantiated function. Two functions
440 // of the same name and attrs, instantiated with the same
441 // `state_handle` will have the same handle and share the same
442 // state (in stateful kernels); and two functions with different
443 // values for `state_handle` will have independent state.
444 string state_handle;
445 };
446 typedef uint64 Handle;
447 virtual Status Instantiate(const string& function_name, AttrSlice attrs,
448 const InstantiateOptions& options,
449 Handle* handle) = 0;
Instantiate(const string & function_name,AttrSlice attrs,Handle * handle)450 Status Instantiate(const string& function_name, AttrSlice attrs,
451 Handle* handle) {
452 return Instantiate(function_name, attrs, {}, handle);
453 }
454
455 // Releases state associated with the handle.
456 virtual Status ReleaseHandle(Handle handle) = 0;
457
458 // Returns the function body for the instantiated function given its
459 // handle 'h'. Returns nullptr if "h" is not found.
460 //
461 // *this keeps the ownership of the returned object, which remains alive
462 // as long as *this.
463 virtual const FunctionBody* GetFunctionBody(Handle h) = 0;
464
465 // Asynchronously invokes the instantiated function identified by
466 // "handle".
467 //
468 // If function execution succeeds, "done" is called with OK and
469 // "*rets" is filled with the function's return values. Otheriwse,
470 // "done" is called with an error status.
471 //
472 // Does not take ownership of "rets".
473 // In the cross-process scenario, runner isn't used for making the Async
474 // RPC calls.
475 struct Options {
476 // The id of the step that is calling this function.
477 int64 step_id = 0;
478 Rendezvous* rendezvous = nullptr;
479 CancellationManager* cancellation_manager = nullptr;
480 ScopedStepContainer* step_container = nullptr;
481 StepStatsCollector* stats_collector = nullptr;
482
483 std::function<void(std::function<void()>)>* runner = nullptr;
484
485 // Parameters for remote function execution.
486 bool remote_execution = false;
487 string source_device = ""; // Fully specified device name.
488
489 // Allocator attributes specifying where the args are / rets should be put.
490 // These should either be {} or match the length of args / retvals. If {},
491 // the default allocator attributes will be assumed for all args / retvals.
492 std::vector<AllocatorAttributes> args_alloc_attrs;
493 std::vector<AllocatorAttributes> rets_alloc_attrs;
494
495 // If true, we create a new IntraProcessRendezvous, else use the existing
496 // one.
497 bool create_rendezvous = false;
498 };
499 typedef std::function<void(const Status&)> DoneCallback;
500 virtual void Run(const Options& opts, Handle handle,
501 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
502 DoneCallback done) = 0;
503 virtual void Run(const Options& opts, Handle handle,
504 CallFrameInterface* call_frame, DoneCallback done) = 0;
505
506 // Creates a "kernel" for the given node def "ndef".
507 //
508 // If succeeds, returns OK and the caller takes the ownership of the
509 // returned "*kernel". Otherwise, returns an error.
510 virtual Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) = 0;
511
512 // Returns true iff the function named `function_name` is stateful.
513 // NOTE(mrry): This method assumes that the runtime is associated with a
514 // default function library, and looks up `function_name` in that library.
515 // It does not support overlay libraries.
516 virtual bool IsStateful(const string& function_name) = 0;
517
518 // Returns the device on which the function executes.
519 virtual Device* device() = 0;
520
521 // Returns the function library definition that backs this runtime.
522 // NOTE(mrry): The returned library definition is the default function library
523 // for this runtime. The runtime may instantiate functions from separate
524 // overlay libraries, which are not returned by this function.
525 virtual const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
526 const = 0;
527
528 // Returns the environment on which the function executes.
529 virtual Env* env() = 0;
530
531 // Returns a debug string showing the definition of the function of
532 // 'handle'.
533 virtual string DebugString(Handle handle) = 0;
534
535 // Returns the graph version number.
536 virtual int graph_def_version() = 0;
537
538 typedef uint64 LocalHandle;
539
540 virtual Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
541 std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
542 FunctionLibraryRuntime** out_flr) = 0;
543 };
544
// Returns a canonicalized string for the instantiation of the
// function named "funcname" with attributes "attrs" and "options".
//
// The returned string is guaranteed to be stable within one address
// space. But it may change as the implementation evolves. Therefore, it
// should not be persisted or compared across address spaces.
string Canonicalize(const string& funcname, AttrSlice attrs,
                    const FunctionLibraryRuntime::InstantiateOptions& options);
// Same as above, using default-constructed InstantiateOptions.
inline string Canonicalize(const string& funcname, AttrSlice attrs) {
  return Canonicalize(funcname, attrs, {});
}
557
558 const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
559 const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1;
560 typedef std::function<Status(FunctionLibraryRuntime*, const NodeDef&,
561 std::unique_ptr<OpKernel>*)>
562 CustomKernelCreator;
563
564 // Used to instantiate and run functions in a distributed system.
565 class DistributedFunctionLibraryRuntime {
566 public:
~DistributedFunctionLibraryRuntime()567 virtual ~DistributedFunctionLibraryRuntime() {}
568
569 // The _target attr in attrs determines where the function is instantiated.
570 virtual Status Instantiate(
571 const string& function_name, const FunctionLibraryDefinition& lib_def,
572 AttrSlice attrs,
573 const FunctionLibraryRuntime::InstantiateOptions& options,
574 FunctionLibraryRuntime::LocalHandle* handle) = 0;
575
576 // opts.runner isn't used for execution.
577 virtual void Run(const FunctionLibraryRuntime::Options& opts,
578 FunctionLibraryRuntime::LocalHandle handle,
579 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
580 FunctionLibraryRuntime::DoneCallback done) = 0;
581 };
582
583 // Extracts the actual type from "attr_values" based on its definition
584 // "arg_def".
585 //
586 // If "arg_def" is a N*T type, *is_type_list is set to false, and
587 // *dtypes is set to be a vector of size N and each element is T.
588 //
589 // If "arg_def" is a list(type), *is_type_list is set to true, and
590 // *dtypes is set to be a vector of types specified in attrs for
591 // arg_def.
592 //
593 // Otherwise (arg_def is a simple type T), *is_type_list is set to
594 // false, and *dtypes is set to a single element vector, whose only
595 // element is T.
596 Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
597 bool* is_type_list, DataTypeVector* dtypes);
598
// To register a gradient function for a builtin op, one should use
//   REGISTER_OP_GRADIENT(<op_name>, <c++ grad factory>);
//
// Typically, the c++ grad factory is a plain function that can be
// converted into ::tensorflow::gradient::Creator, which is
//   std::function<Status(const AttrSlice&, FunctionDef*)>.
//
// A ::tensorflow::gradient::Creator should populate the FunctionDef* with a
// definition of a function which computes the gradient for <op_name>
// when <op_name> is instantiated with the given attrs.
//
// E.g.,
//
// Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
//   bool transpose_a;
//   TF_RETURN_IF_ERROR(attrs.Get("transpose_a", &transpose_a));
//   bool transpose_b;
//   TF_RETURN_IF_ERROR(attrs.Get("transpose_b", &transpose_b));
//   DataType dtype;
//   TF_RETURN_IF_ERROR(attrs.Get("T", &dtype));
//   if (!transpose_a && !transpose_b) {
//     *g = FunctionDefHelper::Define(
//         "MatMulGrad",
//         {"x:T", "y:T", "dz:T"},    // Inputs to this function
//         {"dx:T", "dy:T"},          // Outputs from this function
//         {"T: {float, double}"},    // Attributes needed by this function
//         {
//             {{"x_t"}, "Transpose", {"x"}, {{"T", "$T"}}},
//             {{"y_t"}, "Transpose", {"y"}, {{"T", "$T"}}},
//             {{"dx"}, "MatMul", {"dz", "y_t"}, {{"T", "$T"}}},
//             {{"dy"}, "MatMul", {"x_t", "dz"}, {{"T", "$T"}}},
//         });
//   } else {
//     ... ...
//   }
//   return Status::OK();
// }
//
// NOTE: $T is substituted with the type variable "T" when the
// gradient function is instantiated.
//
// TODO(zhifengc): Better documentation somewhere.

// Macros to define a gradient function factory for a primitive
// operation. REGISTER_OP_NO_GRADIENT(name) registers a nullptr Creator.
//
// How the macros work:
// * __COUNTER__ gives every registration its own static flag name.
// * The *_UNIQ_HELPER indirection forces __COUNTER__ to expand before `##`
//   pasting.
// * The flag exists only for its initializer's side effect, so it is marked
//   TF_ATTRIBUTE_UNUSED to silence unused-variable warnings.
// * When SHOULD_REGISTER_OP_GRADIENT is false, `&&` short-circuits and
//   RegisterOp is never called.
#define REGISTER_OP_GRADIENT(name, fn) \
  REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, fn)

#define REGISTER_OP_NO_GRADIENT(name) \
  REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, nullptr)

#define REGISTER_OP_GRADIENT_UNIQ_HELPER(ctr, name, fn) \
  REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)

#define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)      \
  static bool unused_grad_##ctr TF_ATTRIBUTE_UNUSED = \
      SHOULD_REGISTER_OP_GRADIENT &&                  \
      ::tensorflow::gradient::RegisterOp(name, fn)
656
657 namespace gradient {
658 // Register a gradient creator for the "op".
659 typedef std::function<Status(const AttrSlice& attrs, FunctionDef*)> Creator;
660 bool RegisterOp(const string& op, Creator func);
661
662 // Returns OK the gradient creator for the "op" is found (may be
663 // nullptr if REGISTER_OP_NO_GRADIENT is used.
664 Status GetOpGradientCreator(const string& op, Creator* creator);
665 }; // namespace gradient
666
667 // Declare explicit instantiations of GetAttr
668 #define GET_ATTR(T) \
669 extern template Status FunctionLibraryDefinition::GetAttr( \
670 const Node&, const string&, T*) const; \
671 extern template Status FunctionLibraryDefinition::GetAttr( \
672 const NodeDef&, const string&, T*) const;
673 GET_ATTR(string)
674 GET_ATTR(bool)
675 #undef GET_ATTR
676
677 } // end namespace tensorflow
678
679 #endif // TENSORFLOW_FRAMEWORK_FUNCTION_H_
680