/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_
#define TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_

#include <vector>

// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/registration/registration.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/protobuf/config.pb.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
#endif  // IS_MOBILE_PLATFORM

namespace tensorflow {

class CancellationManager;
class CollectiveExecutor;
class DeviceSet;
class Graph;
class GraphDef;
class OpKernel;
class ProcessFunctionLibraryRuntime;
class ResourceMgr;
class Rendezvous;
class ScopedStepContainer;
class StepStatsCollectorInterface;
class Node;

// FunctionDefHelper::Create is a convenient helper to construct a
// FunctionDef proto.
// E.g.,
//   FunctionDef my_func = FunctionDefHelper::Create(
//     "my_func_name",
//     {"x:T", "y:T" /* one string per argument */},
//     {"z:T" /* one string per return value */},
//     {"T: {float, double}" /* one string per attribute */},
//     {
//        {{"o"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
//        /* one entry per function node */
//     },
//     /* Mapping between function returns and function node outputs. */
//     {{"z", "o:z"}});
//
// For the old Function::Node approach, use FunctionDefHelper::Define()
// E.g.,
//   FunctionDef my_func = FunctionDefHelper::Define(
//     "my_func_name",
//     {"x:T", "y:T" /* one string per argument */},
//     {"z:T" /* one string per return value */},
//     {"T: {float, double}" /* one string per attribute */},
//     {
//        {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
//        /* one entry per function node */
//     });
class FunctionDefHelper {
 public:
  // AttrValueWrapper has copy constructors for the type T so that
  // it's easy to construct a simple AttrValue proto.
  //
  // If T is a string type (const char*, string, or StringPiece), and
  // it starts with "$", we construct an AttrValue of "placeholder".
  //
  // E.g.,
  //   std::pair<string, AttrValueWrapper> x = {"T", "$T"}
  // is a named attr value placeholder.
  struct AttrValueWrapper {
    AttrValue proto;

    AttrValueWrapper() {}

    template <typename T>
    AttrValueWrapper(T val) {  // NOLINT(runtime/explicit)
      SetAttrValue(val, &proto);
    }

   private:
    void InitFromString(StringPiece val);
  };

  // Constructs an AttrValue.func given the "name" and "attrs".
  static AttrValueWrapper FunctionRef(
      const std::string& name,
      gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs);
  static AttrValueWrapper FunctionRef(const std::string& name) {
    return FunctionRef(name, {});
  }
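
  // For instance, a node attr value that refers to another function by name
  // could be written as follows (a usage sketch; "MyBody" is a hypothetical
  // function name):
  //   {"body", FunctionRef("MyBody", {{"T", "$T"}})}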

  // Node is used to construct FunctionDef.Node using initialization
  // lists. E.g.,
  //  Node n = {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}};  // z = x * y
  //
  // If the op has no outputs, then the name must be specified. E.g.,
  //  Node n = {{}, "AssignVariable", {"resource", "val"},
  //            {{"dtype", "DT_FLOAT"}}, {"update0"}, "CPU:0", "update1"};
  struct Node {
    // When constructing a NodeDef, the first entry in ret is used as
    // the node name, the remaining values are ignored.
    std::vector<string> ret;
    std::string op;
    std::vector<string> arg;
    std::vector<std::pair<string, AttrValueWrapper>> attr;
    std::vector<string> dep;
    std::string device;

    // Required if the op has zero outputs. Otherwise, ret[0] is used as the
    // node name if name is left empty.
    std::string name;

    std::string GetName() const {
      if (!name.empty()) return name;
      CHECK(!ret.empty());
      return ret[0];
    }
    std::vector<string> original_node_names;
    std::vector<string> original_func_names;

    NodeDef ToNodeDef() const;
  };

  // Creates a FunctionDef from the given parameters. Node inputs must use
  // function encoding (node_name:output_name[:output_index]).
  // - `ret_def` holds a mapping from the function output names from `out_def`
  //   to the node outputs from `node_def`.
  // - `control_ret_def` holds a mapping from the function control
  //   output names to the nodes from `node_def`.
  static FunctionDef Create(
      const std::string& function_name, gtl::ArraySlice<string> in_def,
      gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
      gtl::ArraySlice<Node> node_def,
      gtl::ArraySlice<std::pair<string, string>> ret_def,
      gtl::ArraySlice<std::pair<string, string>> control_ret_def);

  // Creates a FunctionDef from the given parameters. Node inputs must use
  // function encoding (node_name:output_name[:output_index]).
  // - `ret_def` holds a mapping from the function output names from `out_def`
  //   to the node outputs from `node_def`.
  static FunctionDef Create(const std::string& function_name,
                            gtl::ArraySlice<string> in_def,
                            gtl::ArraySlice<string> out_def,
                            gtl::ArraySlice<string> attr_def,
                            gtl::ArraySlice<Node> node_def,
                            gtl::ArraySlice<std::pair<string, string>> ret_def);

  // TODO(josh11b): Get rid of these and transition to the one above.
  static FunctionDef Define(const std::string& function_name,
                            gtl::ArraySlice<string> arg_def,
                            gtl::ArraySlice<string> ret_def,
                            gtl::ArraySlice<string> attr_def,
                            gtl::ArraySlice<Node> node_def);

  // Defines an anonymous function. I.e., its name is not relevant.
  static FunctionDef Define(gtl::ArraySlice<string> arg_def,
                            gtl::ArraySlice<string> ret_def,
                            gtl::ArraySlice<string> attr_def,
                            gtl::ArraySlice<Node> node_def);

  // Helpers to construct a constant scalar.
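  // E.g., the following produces a Const node named "two" holding the
  // DT_FLOAT scalar 2 (a usage sketch):
  //   Node two = Const<float>("two", 2.0f);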
  template <typename T>
  static Node Const(const std::string& name, const T& val) {
    Node n = {{name}, "Const"};
    const DataType dtype = DataTypeToEnum<T>::value;
    n.attr.push_back({"dtype", dtype});
    Tensor t(dtype, TensorShape({}));
    t.scalar<T>()() = val;
    n.attr.push_back({"value", t});
    return n;
  }

  template <typename T>
  static Node Const(const std::string& name, gtl::ArraySlice<T> vals) {
    Node n = {{name}, "Const"};
    const DataType dtype = DataTypeToEnum<T>::value;
    n.attr.push_back({"dtype", dtype});
    int64_t num = vals.size();
    Tensor t(dtype, TensorShape({num}));
    for (size_t i = 0; i < vals.size(); ++i) {
      t.flat<T>()(i) = vals[i];
    }
    n.attr.push_back({"value", t});
    return n;
  }
};

template <>
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(const char* val) {
  InitFromString(val);
}

template <>
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(
    const std::string& val) {
  InitFromString(val);
}

template <>
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) {
  InitFromString(val);
}

// Instantiate a function.
//
// "fdef" encodes a TF function with some attrs in fdef.signature.attr
// containing placeholders. InstantiateFunction binds these
// placeholders and produces an instantiated function encoded in
// "result". The value to substitute a placeholder is given by
// "attr_values", which is a map from a placeholder name to an attr
// value.
//
// InstantiateFunction calls "get_function" to find signatures of other
// functions and primitive ops.

// GetFunctionSignature(func name, opdef) returns OK if the func name is found
// and opdef is filled with a pointer to the corresponding signature
// (an OpDef proto). Otherwise, returns an error.
typedef std::function<Status(const string&, const OpDef**)>
    GetFunctionSignature;

struct InstantiationResult {
  DataTypeVector arg_types;
  DataTypeVector ret_types;
  std::vector<NodeDef> nodes;
};
Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
                           GetFunctionSignature get_function,
                           InstantiationResult* result);
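
// A minimal calling sketch (assuming `fdef` is a FunctionDef, `attr_values`
// is an AttrValueMap binding its placeholders, and `lib_def` is a
// FunctionLibraryDefinition that can resolve op and function signatures):
//
//   InstantiationResult result;
//   TF_RETURN_IF_ERROR(InstantiateFunction(
//       fdef, AttrSlice(&attr_values),
//       [&lib_def](const string& name, const OpDef** sig) {
//         return lib_def.LookUpOpDef(name, sig);
//       },
//       &result));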

// Returns a debug string for a function definition.
//
// The returned text is multiple-line. It is intended to be
// human-readable rather than being friendly to parsers. It is _NOT_
// intended to be the canonical string representation of "func_def".
// In particular, it may not include all information presented in
// "func_def" (e.g., comments, description of the function arguments,
// etc.)
std::string DebugString(const FunctionDef& func_def);
std::string DebugString(const GraphDef& instantiated_func_def);
std::string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes);

// Returns a debug string for a top level graph (the main program and
// its supporting functions defined in its library).
std::string DebugStringWhole(const GraphDef& gdef);

// Returns true if f1 == f2. Compares all fields, including descriptions. Order
// of NodeDefs doesn't matter.
bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2);

// Returns a hash of `fdef` that is consistent with the FunctionDefsEqual
// method. In other words, if two fdefs compare equal, their hash values will
// be the same.
uint64 FunctionDefHash(const FunctionDef& fdef);

class CallFrameInterface {
 public:
  virtual ~CallFrameInterface() {}

  virtual size_t num_args() const = 0;
  virtual size_t num_retvals() const = 0;

  virtual Status GetArg(int index, const Tensor** val) = 0;

  // Optimized implementation of `GetArg()` that allows the caller to take
  // ownership of the tensor. This method may only be called once per
  // value of `index` and `CallFrameInterface` instance.
  //
  // REQUIRES: `this->CanConsumeArg(index) == true`.
  virtual void ConsumeArg(int index, Tensor* val) {
    LOG(ERROR) << "This `CallFrameInterface` implementation does not support "
                  "consuming arguments.";
  }
  virtual bool CanConsumeArg(int index) const { return false; }
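  // A typical caller-side pattern (a sketch; `frame` and `i` are hypothetical)
  // is to consume the argument when the frame supports it and to borrow it
  // otherwise:
  //
  //   Tensor owned;
  //   const Tensor* borrowed = nullptr;
  //   if (frame->CanConsumeArg(i)) {
  //     frame->ConsumeArg(i, &owned);
  //   } else {
  //     TF_RETURN_IF_ERROR(frame->GetArg(i, &borrowed));
  //   }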

  virtual Status SetRetval(int index, const Tensor& val) = 0;
};

// Represents a function call frame. I.e., the data structure used to
// pass arguments to a function and retrieve its results.
//
// Runtime must arrange accesses to one FunctionCallFrame s.t.
//   1. SetArgs() happens before any GetArg();
//   2. GetRetvals happens after all SetRetval();
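//
// A minimal caller-side sketch (assuming one float argument and one float
// return value; executing the instantiated function against the frame is up
// to the runtime):
//
//   FunctionCallFrame frame({DT_FLOAT}, {DT_FLOAT});
//   Tensor arg(DT_FLOAT, TensorShape({}));
//   arg.scalar<float>()() = 2.0f;
//   TF_RETURN_IF_ERROR(frame.SetArgs({arg}));
//   /* ... run the function using `frame` as its call frame ... */
//   std::vector<Tensor> rets;
//   TF_RETURN_IF_ERROR(frame.GetRetvals(&rets));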
class FunctionCallFrame : public CallFrameInterface {
 public:
  FunctionCallFrame(DataTypeSlice arg_types, DataTypeSlice ret_types);
  ~FunctionCallFrame() override;

  // Caller methods.
  Status SetArgs(gtl::ArraySlice<Tensor> args);
  Status GetRetvals(std::vector<Tensor>* rets) const;

  // Moves the return values from the frame to rets. If allow_dead_tensors is
  // false, it will fail if any of the retvals do not have a value.
  Status ConsumeRetvals(std::vector<Tensor>* rets, bool allow_dead_tensors);

  size_t num_args() const override { return arg_types_.size(); }
  size_t num_retvals() const override { return ret_types_.size(); }

  // Callee methods.
  Status GetArg(int index, const Tensor** val) override;
  Status SetRetval(int index, const Tensor& val) override;

 private:
  DataTypeVector arg_types_;
  DataTypeVector ret_types_;
  gtl::InlinedVector<Tensor, 4> args_;
  struct Retval {
    bool has_val = false;
    Tensor val;
  };
  gtl::InlinedVector<Retval, 4> rets_;

  TF_DISALLOW_COPY_AND_ASSIGN(FunctionCallFrame);
};

// Language agnostic stack traces.
class AbstractStackTrace {
 public:
  struct TracePrintingOptions {
    // Show inline the contents of each stack line.
    bool show_line_contents = false;

    // Drop the largest common prefix of all filenames in stack frames.
    bool filter_common_prefix = false;

    // Do not show internal frames.
    bool drop_internal_frames = false;
  };

  virtual ~AbstractStackTrace() {}

  // The returned span is alive as long as the AbstractStackTrace is alive.
  virtual absl::Span<StackFrame const> ToFrames() const = 0;

  // Returns the last stack frame from user code, attempting to ignore the
  // framework code. Returns an empty frame if no such stack frame was found.
  virtual StackFrame LastUserFrame() const = 0;
  virtual std::string ToString(const TracePrintingOptions& opts) const = 0;
};

using StackTracesMap =
    std::unordered_map<std::string,
                       std::shared_ptr<tensorflow::AbstractStackTrace>>;

// Helper to maintain a map between function names in a given
// FunctionDefLibrary and function definitions.
//
// This class is thread-safe.
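//
// A minimal usage sketch (assuming `fdef` is a FunctionDef built elsewhere,
// e.g. with FunctionDefHelper::Create):
//
//   FunctionLibraryDefinition lib_def(OpRegistry::Global(),
//                                     FunctionDefLibrary());
//   TF_RETURN_IF_ERROR(lib_def.AddFunctionDef(fdef));
//   const FunctionDef* found = lib_def.Find(fdef.signature().name());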
class FunctionLibraryDefinition : public OpRegistryInterface {
 public:
  // Ops created for function arguments bear the name given by `kArgOp`; those
  // created for return values bear the name given by `kRetOp`.
  static constexpr const char* const kArgOp = "_Arg";
  static constexpr const char* const kDeviceArgOp = "_DeviceArg";
  static constexpr const char* const kRetOp = "_Retval";
  static constexpr const char* const kDeviceRetOp = "_DeviceRetval";
  static constexpr const char* const kIntsOnDeviceAttr =
      "experimental_ints_on_device";
  static constexpr const char* const kSharedRendezvousAttr =
      "shared_rendezvous";

  static constexpr const char* const kGradientOp = "SymbolicGradient";
  static constexpr const char* const kFuncAttr = "f";

  // Note: This constructor grabs `lib_def`'s lock in shared mode.
  FunctionLibraryDefinition(const FunctionLibraryDefinition& lib_def);
  FunctionLibraryDefinition(const OpRegistryInterface* default_registry,
                            const FunctionDefLibrary& lib_def = {});
  ~FunctionLibraryDefinition() override;

  FunctionLibraryDefinition& operator=(const FunctionLibraryDefinition&) =
      delete;

  // Returns true if the library contains `func`, false otherwise.
  bool Contains(const std::string& func) const;

  // Returns nullptr if "func" is not defined in "lib_def". Otherwise,
  // returns its definition proto.
  //
  // NB: This function returns a borrowed pointer, which can be invalidated by
  // a subsequent call to `ReplaceFunction()` with the given name.
  const FunctionDef* Find(const std::string& func) const TF_LOCKS_EXCLUDED(mu_);

  // Adds function definition 'fdef' to this function library.
  // Returns status 'ok' on success, or error otherwise. This is a no-op if
  // 'fdef' already exists in this function library.
  // If 'fdef' is successfully added to the library, it will be accessible
  // from 'LookUp' and included in the proto returned by 'ToProto'.
  // This operation is atomic.
  //
  // Associates `graph` with a function `func_name`. Lifetime assumption:
  // `graph` has to outlive all instantiated graphs.
  Status AddFunctionDef(const FunctionDef& fdef,
                        const StackTracesMap& stack_traces = {})
      TF_LOCKS_EXCLUDED(mu_);

  // Adds gradient definition 'grad' to this function library.
  // This is a no-op if 'grad' already exists in this function library.
  // If 'grad' is successfully added, it will be accessible via 'FindGradient'
  // and included in the proto returned by 'ToProto'.
  // This operation is atomic.
  Status AddGradientDef(const GradientDef& grad) TF_LOCKS_EXCLUDED(mu_);

  // Replaces the function corresponding to `func` with `fdef`. Returns
  // a non-OK status if "func" was not found in the library, OK otherwise.
  // Please be careful when replacing a function: make sure all previous
  // pointers returned by `Find()` are no longer in use.
  Status ReplaceFunction(const std::string& func, const FunctionDef& fdef,
                         const StackTracesMap& stack_traces = {})
      TF_LOCKS_EXCLUDED(mu_);

  // Replaces the gradient corresponding to `grad.function_name()`. Returns
  // a non-OK status if "grad.function_name()" was not found in the library, OK
  // otherwise.
  Status ReplaceGradient(const GradientDef& grad) TF_LOCKS_EXCLUDED(mu_);

  // Removes the function corresponding to 'func'. Returns a non-OK status if
  // 'func' was not found in the library, OK otherwise.
  // Please be careful when removing a function: make sure there are no other
  // nodes using the function, and all previous pointers returned by `Find()`
  // are no longer in use.
  Status RemoveFunction(const std::string& func) TF_LOCKS_EXCLUDED(mu_);

  // Removes all the functions and gradient functions.
  void Clear() TF_LOCKS_EXCLUDED(mu_);

  // Adds the functions and gradients in 'other' to this function library.
  // Duplicate functions and gradients are ignored.
  // This operation is atomic.
  Status AddLibrary(const FunctionLibraryDefinition& other)
      TF_LOCKS_EXCLUDED(mu_);

  // Adds the functions and gradients in 'lib_def' to this function library.
  // Duplicate functions and gradients are ignored.
  // This operation is atomic.
  Status AddLibrary(const FunctionDefLibrary& lib_def) TF_LOCKS_EXCLUDED(mu_);

  // If the gradient function for 'func' is specified explicitly in
  // the library, returns the gradient function name. Otherwise,
  // returns an empty string.
  std::string FindGradient(const std::string& func) const
      TF_LOCKS_EXCLUDED(mu_);

  // OpRegistryInterface method. Useful for constructing a Graph.
  //
  // If "op" is defined in the library, returns its signature.
  // Otherwise, assume "op" is a primitive op and returns its op
  // signature and shape inference function.
  //
  // NB: This function outputs a borrowed pointer, which can be invalidated by
  // a subsequent call to `ReplaceFunction()` with the given name.
  Status LookUp(const std::string& op_type_name,
                const OpRegistrationData** op_reg_data) const override
      TF_LOCKS_EXCLUDED(mu_);

  // Generates a new function name with the specified prefix that is unique
  // across this library.
  std::string UniqueFunctionName(StringPiece prefix) const
      TF_LOCKS_EXCLUDED(mu_);

  // Given a node def 'ndef', inspects attributes of the callee
  // function to derive the attribute 'value' for 'attr'. Returns OK
  // iff the attribute is given by the function's definition.
  // TODO(irving): Remove; keep only the const Node& version.
  template <typename T>
  Status GetAttr(const NodeDef& ndef, const std::string& attr, T* value) const;

  // Given a node, inspects attributes of the callee function to derive the
  // attribute 'value' for 'attr'. Returns OK iff the attribute is given by the
  // function's definition.
  template <typename T>
  Status GetAttr(const Node& node, const std::string& attr, T* value) const;

  // Returns a proto representation of the state of this function library.
  FunctionDefLibrary ToProto() const TF_LOCKS_EXCLUDED(mu_);

  size_t num_functions() const {
    tf_shared_lock l(mu_);
    return function_defs_.size();
  }

  // Returns all the function names in the FunctionLibraryDefinition.
  std::vector<string> ListFunctionNames() const TF_LOCKS_EXCLUDED(mu_);

  const OpRegistryInterface* default_registry() const {
    return default_registry_;
  }
  void set_default_registry(const OpRegistryInterface* registry) {
    default_registry_ = registry;
  }

  // Returns a copy of `*this` with only the subset of functions that are
  // reachable from the nodes of `graph` or `func`.
  FunctionLibraryDefinition ReachableDefinitions(const GraphDef& graph) const;
  FunctionLibraryDefinition ReachableDefinitions(const FunctionDef& func) const;

  // Copies the function named `func` from `other` to this
  // FunctionLibraryDefinition.
  // REQUIRES: `this->default_registry() == other.default_registry()`.
  // Returns OK on success, or error otherwise. This is a no-op if a function
  // named `func` already exists in this function library and has the same
  // implementation as in `other`. If the implementations conflict, an invalid
  // argument error is returned.
  Status CopyFunctionDefFrom(const std::string& func,
                             const FunctionLibraryDefinition& other)
      TF_LOCKS_EXCLUDED(mu_);

  // Returns the debug stack traces for the given function, or an empty map if
  // none are found.
  const StackTracesMap& GetStackTraces(const std::string& func_name) const {
    tf_shared_lock l(mu_);
    std::shared_ptr<FunctionDefAndOpRegistration> entry = FindHelper(func_name);
    if (entry) {
      return entry->stack_traces;
    }
    static const auto* empty_map = new StackTracesMap;
    return *empty_map;
  }

 private:
  // Shape inference for functions is handled separately by ShapeRefiner.

  struct FunctionDefAndOpRegistration {
    explicit FunctionDefAndOpRegistration(
        const FunctionDef& fdef_in, const StackTracesMap& stack_traces = {});

    const FunctionDef fdef;
    const OpRegistrationData op_registration_data;
    const StackTracesMap stack_traces;
  };

  std::shared_ptr<FunctionDefAndOpRegistration> FindHelper(
      const string& func) const TF_SHARED_LOCKS_REQUIRED(mu_);
  std::string FindGradientHelper(const std::string& func) const
      TF_SHARED_LOCKS_REQUIRED(mu_);

  Status AddHelper(std::shared_ptr<FunctionDefAndOpRegistration> registration,
                   bool* added) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Same as AddFunctionDef/AddGradientDef except these methods set
  // `added` to true if the `fdef`/`grad` were actually added to this.
  Status AddFunctionDefHelper(const FunctionDef& fdef,
                              const StackTracesMap& stack_traces, bool* added)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  Status AddGradientDefHelper(const GradientDef& grad, bool* added)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Helper function for GetAttr. Returns the FunctionDef* to get the
  // attr from.
  const FunctionDef* GetAttrImpl(const NodeDef& ndef) const
      TF_LOCKS_EXCLUDED(mu_);

  // Remove all functions in `funcs` and all gradients of functions in
  // `funcs_with_grads` from this library.
  Status Remove(const std::vector<string>& funcs,
                const std::vector<string>& funcs_with_grads)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Remove `func` from the library. Returns a non-OK Status unless `func` is
  // in the library. This should only be called when there is a guarantee that
  // the function being removed hasn't been retrieved with `Find`.
  Status RemoveFunctionHelper(const std::string& func)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Remove the gradient of function `func` from the library. Returns a non-OK
  // Status unless `func` has a gradient.
  Status RemoveGradient(const std::string& func)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  mutable mutex mu_;
  const OpRegistryInterface* default_registry_;
  gtl::FlatMap<string, std::shared_ptr<FunctionDefAndOpRegistration>>
      function_defs_ TF_GUARDED_BY(mu_);
  gtl::FlatMap<string, string> func_grad_ TF_GUARDED_BY(mu_);
};

// Forward declare. Defined in common_runtime/function.h
struct FunctionBody;

// Forward declare. Defined in common_runtime/device.h
class Device;
// Forward declare. Defined in common_runtime/device_mgr.h
class DeviceMgr;

// Index of an _Arg node.
struct FunctionArgIndex {
  explicit FunctionArgIndex(const int index) : index(index) {}
  FunctionArgIndex(const int index, const int sub_index)
      : index(index), sub_index(sub_index) {}

  // The value of the attribute "Index" of the _Arg node.
  int index;
  // Set only when the _Arg node represents multiple arguments (e.g. an _Arg
  // node is replicated to multiple devices/subgraphs). Use sub-index to
  // distinguish arguments with the same index.
  int sub_index = -1;
};

class FunctionLibraryRuntime {
 public:
  virtual ~FunctionLibraryRuntime() {}

  // Instantiate a function with the given "attrs".
  //
  // Returns OK and fills in "handle" if the instantiation succeeds.
  // Otherwise returns an error and "handle" is undefined.
  struct InstantiateOptions {
    // The canonical device name of the device on which the function
    // should be instantiated. If empty, the function will be
    // instantiated on the local device.
    std::string target;

    // Should the function be instantiated as a multi-device function?
    bool is_multi_device_function = false;

    // If true, graph passes will be skipped when instantiating the function
    // since they have already run on the main function side.
    bool is_component_function = false;

    // For multi-device functions, a vector of canonical device names for the
    // function's inputs. The device of a resource input must be the device
    // backing the resource, not the CPU device backing the resource handle.
    // Must have the same length as the number of inputs to the function.
    std::vector<string> input_devices;

    // For multi-device functions, a vector of canonical device names for the
    // function's outputs.
    //
    // (a) If specified (must have the same length as the number of outputs):
    //
    // Specified devices will be assigned to Retval nodes inserted into the
    // function body graph in place of function outputs. An output device may
    // be specified as an empty string, in which case the Retval device
    // assignment will be inferred later, when the function graph is placed
    // before partitioning (this is required for resource outputs). Placer
    // will respect colocation constraints.
    //
    // (b) If not specified:
    //
    // The function runtime will infer the Retval device by following input
    // edges until it reaches a node with a device specification. This device
    // specification must identify a unique device; i.e., a general
    // specification like "job:foo" matching multiple devices will result in
    // an error.
    //
    // IMPORTANT: Resource outputs
    //
    // Multi-device functions might return resources on devices different from
    // the function call device. If an output device is not specified for a
    // resource output, and the node producing that resource is a function
    // call, the runtime will leave the device specification empty and will
    // rely on Placer to infer the correct device.
    std::vector<string> output_devices;

    // If set, it indicates the original output indices of a component
    // function.
    absl::optional<std::vector<int>> ret_indices = absl::nullopt;

    // Maps from a CompositeDevice name to a list of underlying physical
    // devices.
    absl::flat_hash_map<string, const std::vector<string>*> composite_devices;

    // This interface is EXPERIMENTAL and subject to change.
    //
    // For multi-device functions, a mapping from _Arg node index to type and
    // shape for input resources.
    // REQUIRES: if input_resource_dtypes_and_shapes.count(i) > 0 then the i-th
    // argument type must be DT_RESOURCE.
    std::unordered_map<int, DtypeAndPartialTensorShape>
        input_resource_dtypes_and_shapes;

    // This interface is EXPERIMENTAL and subject to change.
    //
    // If non-null, the runtime will use `lib_def` to resolve function(s) named
    // in `function_name` and `attrs`. Otherwise, the runtime will use its
    // internal library.
    //
    // NOTE(mrry): If provided, all functions defined in `lib_def` must be
    // self-contained, and cannot refer to functions defined in other
    // libraries.
    const FunctionLibraryDefinition* lib_def = nullptr;

    // This interface is EXPERIMENTAL and subject to change.
    //
    // If non-empty, the runtime will use `state_handle` to identify
    // cached state related to the instantiated function. Two functions
    // of the same name and attrs, instantiated with the same
    // `state_handle`, will have the same handle and share the same
    // state (in stateful kernels); and two functions with different
    // values for `state_handle` will have independent state.
    std::string state_handle;

    // This interface is EXPERIMENTAL and subject to change.
    //
    // Instantiates the function using an executor of the given type. If empty,
    // the default TensorFlow executor will be used.
    std::string executor_type;

    // If true, the runtime will attempt to create kernels for the function at
    // instantiation time, rather than on the first run. This can be used to
    // surface errors earlier.
    bool create_kernels_eagerly = false;

    // This interface is EXPERIMENTAL and subject to change.
    //
    // Instantiates the function with the provided config_proto.
    ConfigProto config_proto;

    // If provided, this optimization function will be invoked before
    // the placer for multi-device functions.
    std::function<Status(std::vector<string> /*ret_node_names*/,
                         std::vector<string> /*keep_node_names*/,
                         FunctionLibraryDefinition*, const DeviceSet&,
                         Device* /*cpu_device*/, std::unique_ptr<Graph>*)>
        optimize_graph_fn;

    // If set, partitioned functions will be added to `graph_collector`.
    // `graph_collector` must be alive during the call to Instantiate.
    GraphCollector* graph_collector = nullptr;

    // Indicates whether the multi-device function backend should default the
    // placement of ops without a requested device to `target`.
    bool default_device_to_target = true;

    // If true, the optimized Graph will be stored so that
    // `FunctionLibraryRuntime::DebugString(handle)` contains the optimized
    // Graph. Otherwise, the unoptimized function Graph will be returned.
    bool include_optimized_graph_in_debug_string = false;

    // If true, the function library runtime caches the function instantiation.
    bool use_function_cache = false;

    // This interface is EXPERIMENTAL and subject to change.
    //
    // If true, allow optimizations which should be targeted at a limited
    // set of small functions. For example, running kernels synchronously can
    // be faster under some conditions.
    bool allow_small_function_optimizations = false;

    // This interface is EXPERIMENTAL and subject to change.
    //
    // If true, allow graphs containing control flow nodes to be run on the
    // single threaded executor.
    bool allow_control_flow_sync_execution = false;

    // TODO(b/176491312): Remove this if the shape-inference-on-import flag is
    // removed. If true, allows the MLIR roundtrip to run shape inference on
    // import.
    bool shape_inference_on_tfe_dialect_import = true;

    // Force int32 _Arg and _Retval nodes to be left on device instead of
    // pinning them to the host.
    //
    // Note that we do not pin int32 nodes to host for subgraphs running in
    // TPU/XLA devices. So this is mainly used to handle the case of multi-CPU
    // and GPU (non-XLA) graphs.
    bool int_args_and_retvals_on_device = false;

    // This interface is EXPERIMENTAL and subject to change.
    //
    // Instantiates the function for XLA compilation on device_type. If empty,
    // the function is not compiled.
    std::string xla_compile_device_type;
  };
  typedef uint64 Handle;
  virtual Status Instantiate(const std::string& function_name, AttrSlice attrs,
                             const InstantiateOptions& options,
                             Handle* handle) = 0;
  Status Instantiate(const std::string& function_name, AttrSlice attrs,
                     Handle* handle) {
    auto opts = absl::make_unique<InstantiateOptions>();
    return Instantiate(function_name, attrs, *opts, handle);
  }

  // Releases state associated with the handle.
  virtual Status ReleaseHandle(Handle handle) = 0;

  // Returns the function body for the instantiated function given its
  // handle 'h'. Returns nullptr if "h" is not found.
  //
  // *this keeps the ownership of the returned object, which remains alive
  // as long as *this.
  virtual const FunctionBody* GetFunctionBody(Handle h) = 0;

  // Returns the return types for the function identified by handle `h`.
  virtual Status GetRetTypes(Handle h, DataTypeVector* ret_types) = 0;

  // Asynchronously invokes the instantiated function identified by
  // "handle".
  //
  // If function execution succeeds, "done" is called with OK and
  // "*rets" is filled with the function's return values. Otherwise,
  // "done" is called with an error status.
  //
  // Does not take ownership of "rets".
  // In the cross-process scenario, the runner isn't used for making the async
  // RPC calls.
  struct Options {
    Options() {}
    explicit Options(const int64_t step_id) : step_id(step_id) {}
    // Choose a step ID that is guaranteed not to clash with any
    // Session-generated step ID. DirectSession only generates
    // non-negative step IDs (contiguous, starting from 0), and
    // MasterSession generates 56-bit random step IDs whose MSB is
    // always 0, so a negative random step ID should suffice.
    const int64_t step_id = -std::abs(static_cast<int64_t>(random::New64()));

    // op_id of the function running in eager mode. Set when we want to copy
    // remote outputs lazily. All components of a remote multi-device function
    // should use the same op_id, in order to correctly map remote output
    // tensors to the remote TensorHandles in the default device.
    absl::optional<int64_t> op_id = absl::nullopt;

    RendezvousInterface* rendezvous = nullptr;
    CancellationManager* cancellation_manager = nullptr;
    CollectiveExecutor* collective_executor = nullptr;
    ScopedStepContainer* step_container = nullptr;
    StepStatsCollectorInterface* stats_collector = nullptr;
    CoordinationServiceAgent* coordination_service_agent = nullptr;

    absl::optional<ManagedStackTrace> stack_trace = absl::nullopt;

    std::function<void(std::function<void()>)>* runner = nullptr;

    // Parameters for remote function execution.
    bool remote_execution = false;
    std::string source_device = "";  // Fully specified device name.

    // Allocator attributes specifying where the args are / rets should be put.
    // These should either be {} or match the length of args / retvals. If {},
    // the default allocator attributes will be assumed for all args / retvals.
    std::vector<AllocatorAttributes> args_alloc_attrs;
    std::vector<AllocatorAttributes> rets_alloc_attrs;

    // If true, we create a new IntraProcessRendezvous, else use the existing
    // one.
    bool create_rendezvous = false;

    // If true, allow returning dead tensors.
    bool allow_dead_tensors = false;

    // If true, hint that all kernels should be treated as "inexpensive", and
    // hence executed on the scheduling thread.
    bool run_all_kernels_inline = false;

    // If not null, use this thread pool for intra op scheduling.
    thread::ThreadPoolInterface* user_intra_op_threadpool = nullptr;

    // Returns a human readable representation of this.
    std::string DebugString() const;
  };
  typedef std::function<void(const Status&)> DoneCallback;
  virtual void Run(const Options& opts, Handle handle,
                   gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                   DoneCallback done) = 0;
  virtual void Run(const Options& opts, Handle handle,
                   CallFrameInterface* call_frame, DoneCallback done) = 0;

  virtual Status RunSync(Options opts, Handle handle,
                         gtl::ArraySlice<Tensor> args,
                         std::vector<Tensor>* rets) = 0;
  virtual Status RunSync(Options opts, Handle handle,
                         CallFrameInterface* call_frame) = 0;
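
  // A minimal sketch of instantiating and running a function by name
  // (assuming `flr` points at a FunctionLibraryRuntime, "my_func" is defined
  // in its library, and `arg` is an input Tensor of the expected type):
  //
  //   FunctionLibraryRuntime::Handle handle;
  //   TF_RETURN_IF_ERROR(flr->Instantiate("my_func", AttrSlice(), &handle));
  //   std::vector<Tensor> rets;
  //   TF_RETURN_IF_ERROR(flr->RunSync(FunctionLibraryRuntime::Options(),
  //                                   handle, {arg}, &rets));
  //   TF_RETURN_IF_ERROR(flr->ReleaseHandle(handle));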

  // Creates a "kernel" for the given NodeProperties "props".
  //
  // If it succeeds, returns OK and the caller takes ownership of the
  // returned "*kernel". Otherwise, returns an error.
  virtual Status CreateKernel(
      const std::shared_ptr<const NodeProperties>& props,
      OpKernel** kernel) = 0;

  // Returns true iff the function named `function_name` is stateful.
  //
  // NOTE(mrry): This method assumes that the runtime is associated with a
  // default function library, and looks up `function_name` in that library.
  // It does not support overriding the function library.
  virtual bool IsStateful(const std::string& function_name) const = 0;

  // Returns the device on which the function executes.
  virtual Device* device() = 0;
  virtual const Device* device() const = 0;

  // Returns the default runner in which the ops should be launched. If the
  // device on which the function executes has a private thread pool, returns
  // a runner on the device-local thread pool.
  virtual std::function<void(std::function<void()>)>* runner() = 0;

  // Gets the DeviceMgr from which the device was obtained.
  virtual const DeviceMgr* device_mgr() const = 0;

  // Returns the function library definition that backs this runtime.
  //
  // NOTE(mrry): The returned library definition is the default function
  // library for this runtime. The caller may override the function library
  // used by the runtime to instantiate functions, which will not be reflected
  // in the return value of this function.
  virtual const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const = 0;

  // Returns the environment on which the function executes.
  virtual Env* env() = 0;

  // Returns the ConfigProto passed to the session used to create the function.
  virtual const ConfigProto* const config_proto() = 0;

  // Returns a debug string showing the definition of the function of
  // 'handle'.
  virtual std::string DebugString(Handle handle) = 0;

  // Returns the graph version number.
  virtual int graph_def_version() const = 0;

  typedef uint64 LocalHandle;

  // Creates a copy of ProcessFunctionLibraryRuntime (transferring ownership to
  // the caller), FunctionLibraryRuntime (owned by the returned
  // ProcessFunctionLibraryRuntime), and FunctionLibraryDefinition
  // (transferring ownership to the caller). Note that both the
  // ProcessFunctionLibraryRuntime and the FunctionLibraryRuntime borrow a
  // pointer to the FunctionLibraryDefinition, so the
  // FunctionLibraryDefinition should outlive both.
  //
  // The `skip_flib_def` argument controls whether the method should clone the
  // FunctionLibraryDefinition (default behavior) or return an empty function
  // library. The latter is used by tf.data, which manages
  // FunctionLibraryDefinitions for its functions independently (and passes
  // these into the FunctionLibraryRuntime through an overlay), to avoid linear
  // runtime w.r.t. the number of functions in the current function library.
  virtual Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
                       std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
                       FunctionLibraryRuntime** out_flr,
                       bool skip_flib_def = false) = 0;

  // Returns the name of the executor class (in the sense of
  // `ExecutorFactory::GetFactory()`) that will be used based on the given
  // dynamic `options` and static `attrs`. If none is specified, this method
  // will return an empty string, which leaves the decision up to the runtime.
  static std::string ExecutorType(const InstantiateOptions& options,
                                  AttrSlice attrs);
};

// Returns the device of the `arg_index`-th function input. Updates
// `composite_devices` if the input device is a composite device.
std::string GetFunctionResourceInputDevice(
    const Tensor& input, const int arg_index, const FunctionDef& function_def,
    absl::flat_hash_map<string, std::vector<string>>* composite_devices);

// Returns a canonicalized string for the instantiation of the function of the
// given "name", attributes "attrs", and "options".
//
// The returned string is guaranteed to be stable within one address space. But
// it may change as the implementation evolves. Therefore, it should not be
// persisted or compared across address spaces.
std::string Canonicalize(
    const std::string& funcname, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options);
std::string Canonicalize(const std::string& funcname, AttrSlice attrs);

const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1;

class CustomKernelCreator {
 public:
  virtual ~CustomKernelCreator() {}

  // Given a NodeDef 'node_def' and the function library runtime 'flr',
  // validates whether the class supports creating such a kernel.
  virtual bool CanCreateKernel(
      const FunctionLibraryRuntime& flr,
      const std::shared_ptr<const NodeProperties>& props) const = 0;

  // Given a supported NodeDef, returns a kernel that computes the node.
  virtual Status CreateKernel(
      FunctionLibraryRuntime* flr,
      const std::shared_ptr<const NodeProperties>& props,
      std::unique_ptr<OpKernel>* kernel) const = 0;
};

typedef
#if !defined(IS_MOBILE_PLATFORM)
    absl::variant<Tensor, eager::RemoteTensorHandle*>
        FunctionArg;
#else
    absl::variant<Tensor>
        FunctionArg;
#endif

// Either a local tensor or the shape of a remote tensor.
typedef absl::variant<Tensor, TensorShape> FunctionRet;

// Used to instantiate and run functions in a distributed system.
class DistributedFunctionLibraryRuntime {
 public:
  virtual ~DistributedFunctionLibraryRuntime() {}

  // Instantiate a function on a remote target specified in `options.target`,
  // by sending the name and definition of the function to the remote worker.
  // The local `handle` is filled for the instantiated function data and can
  // be used for subsequent run function calls on the remote target.
  virtual void Instantiate(
      const std::string& function_name,
      const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
      const FunctionLibraryRuntime::InstantiateOptions& options,
      FunctionLibraryRuntime::LocalHandle* handle,
      FunctionLibraryRuntime::DoneCallback done) = 0;

  // Run an instantiated remote function (specified by `handle`) with a list of
  // input Tensors in `args` and get its output Tensors in `rets`. The input
  // tensor data will be sent with the function execution request, and must be
  // available on the current caller side.
  // opts.runner isn't used for execution.
  virtual void Run(const FunctionLibraryRuntime::Options& opts,
                   FunctionLibraryRuntime::LocalHandle handle,
                   gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                   FunctionLibraryRuntime::DoneCallback done) = 0;

  // Run an instantiated remote function (specified by `handle`) with a list of
  // input Tensors or RemoteTensorHandles as `args` and get its output Tensors
  // or TensorShapes in `rets`. When using RemoteTensorHandles as function
  // inputs or TensorShapes as outputs, the corresponding tensor data will be
  // resolved on the remote worker, so it is not required to be locally
  // available on the caller side. Using RemoteTensorHandle inputs is not
  // supported in the TensorFlow v1 runtime.
  virtual void Run(const FunctionLibraryRuntime::Options& opts,
                   FunctionLibraryRuntime::LocalHandle handle,
                   gtl::ArraySlice<FunctionArg> args,
                   std::vector<FunctionRet>* rets,
                   FunctionLibraryRuntime::DoneCallback done) = 0;

  // Clean up a previously instantiated function on the remote worker.
  virtual void CleanUp(uint64 step_id,
                       FunctionLibraryRuntime::LocalHandle handle,
                       FunctionLibraryRuntime::DoneCallback done) = 0;

  // DeviceMgr with *all* available devices (i.e., local and remote).
  virtual DeviceMgr* remote_device_mgr() const = 0;
};

// Extracts the actual type from "attr_values" based on its definition
// "arg_def".
//
// If "arg_def" is a N*T type, *is_type_list is set to false, and
// *dtypes is set to be a vector of size N and each element is T.
//
// If "arg_def" is a list(type), *is_type_list is set to true, and
// *dtypes is set to be a vector of types specified in attrs for
// arg_def.
//
// Otherwise (arg_def is a simple type T), *is_type_list is set to
// false, and *dtypes is set to a single element vector, whose only
// element is T.
Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
                  bool* is_type_list, DataTypeVector* dtypes);
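
// For example, for an ArgDef of type "N*T" instantiated with attrs
// {N = 2, T = DT_FLOAT}, ArgNumType sets *is_type_list to false and fills
// *dtypes with {DT_FLOAT, DT_FLOAT}; for a plain type T it fills *dtypes with
// the single element {T}.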

// To register a gradient function for a builtin op, one should use
//   REGISTER_OP_GRADIENT(<op_name>, <c++ grad factory>);
//
// Typically, the c++ grad factory is a plain function that can be
// converted into ::tensorflow::gradient::Creator, which is
// std::function<Status(const AttrSlice&, FunctionDef*)>.
//
// A ::tensorflow::gradient::Creator should populate the FunctionDef* with a
// definition of a brain function which computes the gradient for the
// <op_name> when the <op_name> is instantiated with the given attrs.
//
// E.g.,
//
// Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
//   bool transpose_a;
//   TF_RETURN_IF_ERROR(attrs.Get("transpose_a", &transpose_a));
//   bool transpose_b;
//   TF_RETURN_IF_ERROR(attrs.Get("transpose_b", &transpose_b));
//   DataType dtype;
//   TF_RETURN_IF_ERROR(attrs.Get("dtype", &dtype));
//   if (!transpose_a && !transpose_b) {
//     *g = FunctionDefHelper::Define(
//         "MatMulGrad",
//         {"x:T", "y:T", "dz:T"},    // Inputs to this function
//         {"dx:T", "dy:T"},          // Outputs from this function
//         {"T: {float, double}"},    // Attributes needed by this function
//         {
//             {{"x_t"}, "Transpose", {"x"}, {{"T", "$T"}}},
//             {{"y_t"}, "Transpose", {"y"}, {{"T", "$T"}}},
//             {{"dx"}, "MatMul", {"dz", "y_t"}, {{"T", "$T"}}},
//             {{"dy"}, "MatMul", {"x_t", "dz"}, {{"T", "$T"}}},
//         });
//   } else {
//     ... ...
//   }
//   return Status::OK();
// }
//
// NOTE: $T is substituted with the type variable "T" when the
// gradient function for MatMul is instantiated.
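//
// The factory is then registered against the op, e.g. (a sketch):
//
//   REGISTER_OP_GRADIENT("MatMul", MatMulGrad);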
//
// TODO(zhifengc): Better documentation somewhere.

// Macros to define a gradient function factory for a primitive
// operation.
#define REGISTER_OP_GRADIENT(name, fn) \
  REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, fn)

#define REGISTER_OP_NO_GRADIENT(name) \
  REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, nullptr)

#define REGISTER_OP_GRADIENT_UNIQ_HELPER(ctr, name, fn) \
  REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)

#define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)      \
  static bool unused_grad_##ctr TF_ATTRIBUTE_UNUSED = \
      SHOULD_REGISTER_OP_GRADIENT &&                  \
      ::tensorflow::gradient::RegisterOp(name, fn)

namespace gradient {
// Register a gradient creator for the "op".
typedef std::function<Status(const AttrSlice& attrs, FunctionDef*)> Creator;
bool RegisterOp(const std::string& op, Creator func);

// Returns OK if the gradient creator for the "op" is found (which may be
// nullptr if REGISTER_OP_NO_GRADIENT is used).
Status GetOpGradientCreator(const std::string& op, Creator* creator);
};  // namespace gradient

// Declare explicit instantiations of GetAttr
#define GET_ATTR(T)                                          \
  extern template Status FunctionLibraryDefinition::GetAttr( \
      const Node&, const string&, T*) const;                 \
  extern template Status FunctionLibraryDefinition::GetAttr( \
      const NodeDef&, const string&, T*) const;
GET_ATTR(string)
GET_ATTR(bool)
#undef GET_ATTR

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_