1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_
18
19 #include <vector>
20
21 // clang-format off
22 // Required for IS_MOBILE_PLATFORM
23 #include "tensorflow/core/platform/platform.h"
24 // clang-format on
25
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/types/optional.h"
28 #include "absl/types/variant.h"
29 #include "tensorflow/core/framework/attr_value.pb.h"
30 #include "tensorflow/core/framework/attr_value_util.h"
31 #include "tensorflow/core/framework/function.pb.h"
32 #include "tensorflow/core/framework/node_def_util.h"
33 #include "tensorflow/core/framework/op.h"
34 #include "tensorflow/core/framework/op_kernel.h"
35 #include "tensorflow/core/framework/registration/registration.h"
36 #include "tensorflow/core/framework/types.h"
37 #include "tensorflow/core/lib/gtl/flatmap.h"
38 #include "tensorflow/core/lib/hash/hash.h"
39 #include "tensorflow/core/lib/random/random.h"
40 #include "tensorflow/core/platform/env.h"
41 #include "tensorflow/core/platform/macros.h"
42 #include "tensorflow/core/platform/mutex.h"
43 #include "tensorflow/core/platform/protobuf.h"
44 #include "tensorflow/core/platform/threadpool_interface.h"
45 #include "tensorflow/core/protobuf/config.pb.h"
46 #if !defined(IS_MOBILE_PLATFORM)
47 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
48 #endif // IS_MOBILE_PLATFORM
49
50 namespace tensorflow {
51
52 class CancellationManager;
53 class CollectiveExecutor;
54 class DeviceSet;
55 class Graph;
56 class GraphDef;
57 class OpKernel;
58 class ProcessFunctionLibraryRuntime;
59 class ResourceMgr;
60 class Rendezvous;
61 class ScopedStepContainer;
62 class StepStatsCollectorInterface;
63 class Node;
64
65 // FunctionDefHelper::Create is a convenient helper to construct a
66 // FunctionDef proto.
67 // E.g.,
68 // FunctionDef my_func = FunctionDefHelper::Create(
69 // "my_func_name",
70 // {"x:T", "y:T" /* one string per argument */},
71 // {"z:T" /* one string per return value */},
72 // {"T: {float, double}" /* one string per attribute */},
73 // {
74 // {{"o"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
75 // /* one entry per function node */
76 // },
77 // /* Mapping between function returns and function node outputs. */
78 // {{"z", "o:z"}});
79 //
80 // For the old Function::Node approach, use FunctionDefHelper::Define()
81 // E.g.,
82 // FunctionDef my_func = FunctionDefHelper::Define(
83 // "my_func_name",
84 // {"x:T", "y:T" /* one string per argument */},
85 // {"z:T" /* one string per return value */},
86 // {"T: {float, double}" /* one string per attribute */},
87 // {
88 // {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
89 // /* one entry per function node */
90 // });
91 class FunctionDefHelper {
92 public:
93 // AttrValueWrapper has copy constructors for the type T so that
94 // it's easy to construct a simple AttrValue proto.
95 //
96 // If T is a string type (const char*, string, or StringPiece), and
97 // it starts with "$", we construct a AttrValue of "placeholder".
98 //
99 // E.g.,
100 // std::<string, AttrValueWrapper> x = {"T", "$T"}
101 // is a named attr value placeholder.
102 struct AttrValueWrapper {
103 AttrValue proto;
104
AttrValueWrapperAttrValueWrapper105 AttrValueWrapper() {}
106
107 template <typename T>
AttrValueWrapperAttrValueWrapper108 AttrValueWrapper(T val) { // NOLINT(runtime/explicit)
109 SetAttrValue(val, &proto);
110 }
111
112 private:
113 void InitFromString(StringPiece val);
114 };
115
116 // Constructs an AttrValue.func given the "name" and "attrs".
117 static AttrValueWrapper FunctionRef(
118 const std::string& name,
119 gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs);
FunctionRef(const std::string & name)120 static AttrValueWrapper FunctionRef(const std::string& name) {
121 return FunctionRef(name, {});
122 }
123
124 // Node is used to construct FunctionDef.Node using initialization
125 // lists. E.g.,
126 // Node n = {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}; // z = x * y
127 //
128 // If the op has no inputs, then name is be specified.
129 // Node n = {{}, "AssignVariable", {"resource", "val"}, {{"dtype",
130 // "DT_FLOAT"},
131 // {"update0"}, "CPU:0", "update1"}}
132 struct Node {
133 // When constructing a NodeDef, the first entry in ret is used as
134 // the node name, the remaining values are ignored.
135 std::vector<string> ret;
136 std::string op;
137 std::vector<string> arg;
138 std::vector<std::pair<string, AttrValueWrapper>> attr;
139 std::vector<string> dep;
140 std::string device;
141
142 // Required if the op has zero outputs. Otherwise, ret[0] used as name if
143 // name is left empty.
144 std::string name;
145
GetNameNode146 std::string GetName() const {
147 if (!name.empty()) return name;
148 CHECK(!ret.empty());
149 return ret[0];
150 }
151
152 NodeDef ToNodeDef() const;
153 };
154
155 // Creates a FunctionDef from the given parameters. Node inputs must use
156 // function encoding (node_name:output_name[:output_index]).
157 // - `ret_def` holds a mapping from the function output names from `out_def`
158 // to the node outputs from `node_def`.
159 // - `control_ret_def` holds a mapping from the function control
160 // output names to the nodes from `node_def`.
161 static FunctionDef Create(
162 const std::string& function_name, gtl::ArraySlice<string> in_def,
163 gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
164 gtl::ArraySlice<Node> node_def,
165 gtl::ArraySlice<std::pair<string, string>> ret_def,
166 gtl::ArraySlice<std::pair<string, string>> control_ret_def);
167
168 // Creates a FunctionDef from the given parameters. Node inputs must use
169 // function encoding (node_name:output_name[:output_index]).
170 // - `ret_def` holds a mapping from the function output names from `out_def`
171 // to the node outputs from `node_def`.
172 static FunctionDef Create(const std::string& function_name,
173 gtl::ArraySlice<string> in_def,
174 gtl::ArraySlice<string> out_def,
175 gtl::ArraySlice<string> attr_def,
176 gtl::ArraySlice<Node> node_def,
177 gtl::ArraySlice<std::pair<string, string>> ret_def);
178
179 // TODO(josh11b): Get rid of these and transition to the one above.
180 static FunctionDef Define(const std::string& function_name,
181 gtl::ArraySlice<string> arg_def,
182 gtl::ArraySlice<string> ret_def,
183 gtl::ArraySlice<string> attr_def,
184 gtl::ArraySlice<Node> node_def);
185
186 // Defines an anonymous function. I.e., its name is not relevant.
187 static FunctionDef Define(gtl::ArraySlice<string> arg_def,
188 gtl::ArraySlice<string> ret_def,
189 gtl::ArraySlice<string> attr_def,
190 gtl::ArraySlice<Node> node_def);
191
192 // Helpers to construct a constant scalar.
193 template <typename T>
Const(const std::string & name,const T & val)194 static Node Const(const std::string& name, const T& val) {
195 Node n = {{name}, "Const"};
196 const DataType dtype = DataTypeToEnum<T>::value;
197 n.attr.push_back({"dtype", dtype});
198 Tensor t(dtype, TensorShape({}));
199 t.scalar<T>()() = val;
200 n.attr.push_back({"value", t});
201 return n;
202 }
203
204 template <typename T>
Const(const std::string & name,gtl::ArraySlice<T> vals)205 static Node Const(const std::string& name, gtl::ArraySlice<T> vals) {
206 Node n = {{name}, "Const"};
207 const DataType dtype = DataTypeToEnum<T>::value;
208 n.attr.push_back({"dtype", dtype});
209 int64_t num = vals.size();
210 Tensor t(dtype, TensorShape({num}));
211 for (size_t i = 0; i < vals.size(); ++i) {
212 t.flat<T>()(i) = vals[i];
213 }
214 n.attr.push_back({"value", t});
215 return n;
216 }
217 };
218
219 template <>
AttrValueWrapper(const char * val)220 inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(const char* val) {
221 InitFromString(val);
222 }
223
224 template <>
AttrValueWrapper(const std::string & val)225 inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(
226 const std::string& val) {
227 InitFromString(val);
228 }
229
230 template <>
AttrValueWrapper(StringPiece val)231 inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) {
232 InitFromString(val);
233 }
234
235 // Instantiate a function.
236 //
237 // "fdef" encodes a TF function with some attrs in fdef.signature.attr
// containing placeholders. InstantiateFunction binds these
// placeholders and produces an instantiated function encoded in
// "result" (its argument types, return types, and node definitions).
// The value to substitute a placeholder is given by "attr_values",
// which is a map from a placeholder name to an attr value.
243 //
244 // InstantiateFunction calls "get_function" to find signatures of other
245 // functions and primitive ops.
246
247 // GetFunctionSignature(func name, opdef) returns OK if the func name is found
248 // and opdef is filled with a pointer to the corresponding signature
249 // (a OpDef proto). Otherwise, returns an error.
250 typedef std::function<Status(const string&, const OpDef**)>
251 GetFunctionSignature;
252
253 struct InstantiationResult {
254 DataTypeVector arg_types;
255 DataTypeVector ret_types;
256 std::vector<NodeDef> nodes;
257 };
258 Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
259 GetFunctionSignature get_function,
260 InstantiationResult* result);
261
262 // Returns a debug string for a function definition.
263 //
264 // The returned text is multiple-line. It is intended to be
265 // human-readable rather than being friendly to parsers. It is _NOT_
266 // intended to be the canonical string representation of "func_def".
267 // Particularly, it may not include all information presented in
268 // "func_def" (e.g., comments, description of the function arguments,
269 // etc.)
270 std::string DebugString(const FunctionDef& func_def);
271 std::string DebugString(const GraphDef& instantiated_func_def);
272 std::string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes);
273
274 // Returns a debug string for a top level graph (the main program and
275 // its supporting functions defined in its library).
276 std::string DebugStringWhole(const GraphDef& gdef);
277
278 // Returns true if f1 == f2. Compares all fields, including descriptions. Order
279 // of NodeDefs doesn't matter.
280 bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2);
281
282 // Return a hash of `fdef` that is consistent with FunctionDefsEqual method.
283 // In other words, if two fdefs compare equal, their hash values will be the
284 // same.
285 uint64 FunctionDefHash(const FunctionDef& fdef);
286
287 class CallFrameInterface {
288 public:
~CallFrameInterface()289 virtual ~CallFrameInterface() {}
290
291 virtual size_t num_args() const = 0;
292 virtual size_t num_retvals() const = 0;
293
294 virtual Status GetArg(int index, const Tensor** val) = 0;
295
296 // Optimized implementation of `GetArg()` that allows the caller to take
297 // ownership of the tensor. This method may only be called once per
298 // value of `index` and `CallFrameInterface` instance.
299 //
300 // REQUIRES: `this->CanConsumeArg(index) == true`.
ConsumeArg(int index,Tensor * val)301 virtual void ConsumeArg(int index, Tensor* val) {
302 LOG(ERROR) << "This `CallFrameInterface` implementation does not support "
303 "consuming arguments.";
304 }
CanConsumeArg(int index)305 virtual bool CanConsumeArg(int index) const { return false; }
306
307 virtual Status SetRetval(int index, const Tensor& val) = 0;
308 };
309
310 // Represents a function call frame. I.e., the data structure used to
311 // pass arguments to a function and retrieve its results.
312 //
313 // Runtime must arrange accesses to one FunctionCallFrame s.t.
314 // 1. SetArgs() happens before any GetArg();
315 // 2. GetRetvals happens after all SetRetval();
316 class FunctionCallFrame : public CallFrameInterface {
317 public:
318 FunctionCallFrame(DataTypeSlice arg_types, DataTypeSlice ret_types);
319 ~FunctionCallFrame() override;
320
321 // Caller methods.
322 Status SetArgs(gtl::ArraySlice<Tensor> args);
323 Status GetRetvals(std::vector<Tensor>* rets) const;
324
325 // Moves the return values from the frame to rets. If allow_dead_tensors is
326 // false it will fail if any of the retvals do not have a value.
327 Status ConsumeRetvals(std::vector<Tensor>* rets, bool allow_dead_tensors);
328
num_args()329 size_t num_args() const override { return arg_types_.size(); }
num_retvals()330 size_t num_retvals() const override { return ret_types_.size(); }
331
332 // Callee methods.
333 Status GetArg(int index, const Tensor** val) override;
334 Status SetRetval(int index, const Tensor& val) override;
335
336 private:
337 DataTypeVector arg_types_;
338 DataTypeVector ret_types_;
339 gtl::InlinedVector<Tensor, 4> args_;
340 struct Retval {
341 bool has_val = false;
342 Tensor val;
343 };
344 gtl::InlinedVector<Retval, 4> rets_;
345
346 TF_DISALLOW_COPY_AND_ASSIGN(FunctionCallFrame);
347 };
348
349 // Language agnostic stack traces.
350 class AbstractStackTrace {
351 public:
352 struct TracePrintingOptions {
353 // Show inline the contents of each stack line.
354 bool show_line_contents = false;
355
356 // Drop the common largest prefix of all filenames in stack frames.
357 bool filter_common_prefix = false;
358
359 // Do not show internal frames.
360 bool drop_internal_frames = false;
361 };
362
~AbstractStackTrace()363 virtual ~AbstractStackTrace() {}
364
365 // The returned span is alive as long as the AbstractStackTrace is alive.
366 virtual absl::Span<StackFrame const> ToFrames() const = 0;
367
368 // Returns the last stack frame from user code, attempting to ignore the
369 // framework code. Returns an empty frame if no such stack frame was found.
370 virtual StackFrame LastUserFrame() const = 0;
371 virtual std::string ToString(const TracePrintingOptions& opts) const = 0;
372 };
373
374 using StackTracesMap =
375 std::unordered_map<std::string,
376 std::shared_ptr<tensorflow::AbstractStackTrace>>;
377
378 // Helper to maintain a map between function names in a given
379 // FunctionDefLibrary and function definitions.
380 //
381 // This class is thread-safe.
382 class FunctionLibraryDefinition : public OpRegistryInterface {
383 public:
384 // Ops created for function arguments bear the name given by `kArgOp`; those
385 // created for return values bear the name given by `kRetOp`.
386 static constexpr const char* const kArgOp = "_Arg";
387 static constexpr const char* const kDeviceArgOp = "_DeviceArg";
388 static constexpr const char* const kRetOp = "_Retval";
389 static constexpr const char* const kDeviceRetOp = "_DeviceRetval";
390 static constexpr const char* const kIntsOnDeviceAttr =
391 "experimental_ints_on_device";
392 static constexpr const char* const kSharedRendezvousAttr =
393 "shared_rendezvous";
394
395 static constexpr const char* const kGradientOp = "SymbolicGradient";
396 static constexpr const char* const kFuncAttr = "f";
397
398 // Note: This constructor grabs `lib_def`'s lock in shared mode.
399 FunctionLibraryDefinition(const FunctionLibraryDefinition& lib_def);
400 FunctionLibraryDefinition(const OpRegistryInterface* default_registry,
401 const FunctionDefLibrary& lib_def = {});
402 ~FunctionLibraryDefinition() override;
403
404 FunctionLibraryDefinition& operator=(const FunctionLibraryDefinition&) =
405 delete;
406
407 // Returns True if the library contains `func`, False otherwise.
408 bool Contains(const std::string& func) const;
409
410 // Returns nullptr if "func" is not defined in "lib_def". Otherwise,
411 // returns its definition proto.
412 //
413 // NB: This function returns a borrowed pointer, which can be invalidated by a
414 // subsequent call to `ReplaceFunction()` with the given name.
415 const FunctionDef* Find(const std::string& func) const TF_LOCKS_EXCLUDED(mu_);
416
417 // Adds function definition 'fdef' to this function library.
418 // Returns status 'ok' on success, or error otherwise. This is a no-op if
419 // 'fdef' already exists in this function library.
420 // If 'fdef' is successfully added to the library, it will be accessible
421 // from 'LookUp' and included in the proto returned by 'ToProto'.
422 // This operation is atomic.
423 //
424 // Associates `graph` with a function `func_name`. Lifetime assumption:
425 // `graph` has to outlive all instantiated graphs.
426 Status AddFunctionDef(const FunctionDef& fdef,
427 const StackTracesMap& stack_traces = {})
428 TF_LOCKS_EXCLUDED(mu_);
429
430 // Adds gradient definition 'grad' to this function library.
431 // This is a no-op if 'grad' already exists in this function library.
432 // If 'grad' is successfully added, it will be accessible via 'FindGradient'
433 // and included in the proto returned by 'ToProto'.
434 // This operation is atomic.
435 Status AddGradientDef(const GradientDef& grad) TF_LOCKS_EXCLUDED(mu_);
436
437 // Replaces the function corresponding to `func` with `fdef`. Returns
438 // a non-OK status if "func" was not found in the library, OK otherwise.
439 // Please be careful when replacing function: make sure all previous pointers
440 // returned by `Find()` are no longer in use.
441 Status ReplaceFunction(const std::string& func, const FunctionDef& fdef,
442 const StackTracesMap& stack_traces = {})
443 TF_LOCKS_EXCLUDED(mu_);
444
445 // Replaces the gradient corresponding to `grad.function_name()`. Returns
446 // a non-OK status if "grad.function_name()" was not found in the library, OK
447 // otherwise.
448 Status ReplaceGradient(const GradientDef& grad) TF_LOCKS_EXCLUDED(mu_);
449
450 // Removes the function corresponding to 'func'. Returns a non-OK status if
451 // 'func' was not found in the library, OK otherwise.
452 // Please be careful when removing function: make sure there are no other
453 // nodes using the function, and all previous pointers returned by `Find()`
454 // are no longer in use.
455 Status RemoveFunction(const std::string& func) TF_LOCKS_EXCLUDED(mu_);
456
457 // Removes all the functions and gradient functions.
458 void Clear() TF_LOCKS_EXCLUDED(mu_);
459
460 // Adds the functions and gradients in 'other' to this function library.
461 // Duplicate functions and gradients are ignored.
462 // This operation is atomic.
463 Status AddLibrary(const FunctionLibraryDefinition& other)
464 TF_LOCKS_EXCLUDED(mu_);
465
466 // Adds the functions and gradients in 'lib_def' to this function library.
467 // Duplicate functions and gradients are ignored.
468 // This operation is atomic.
469 Status AddLibrary(const FunctionDefLibrary& lib_def) TF_LOCKS_EXCLUDED(mu_);
470
471 // If the gradient function for 'func' is specified explicitly in
472 // the library, returns the gradient function name. Otherwise,
473 // returns an empty string.
474 std::string FindGradient(const std::string& func) const
475 TF_LOCKS_EXCLUDED(mu_);
476
477 // OpRegistryInterface method. Useful for constructing a Graph.
478 //
479 // If "op" is defined in the library, returns its signature.
480 // Otherwise, assume "op" is a primitive op and returns its op
481 // signature and shape inference function.
482 //
483 // NB: This function outputs a borrowed pointer, which can be invalidated by a
484 // subsequent call to `ReplaceFunction()` with the given name.
485 Status LookUp(const std::string& op_type_name,
486 const OpRegistrationData** op_reg_data) const override
487 TF_LOCKS_EXCLUDED(mu_);
488
489 // Generates new function name with the specified prefix that is unique
490 // across this library.
491 std::string UniqueFunctionName(StringPiece prefix) const
492 TF_LOCKS_EXCLUDED(mu_);
493
494 // Given a node def 'ndef', inspects attributes of the callee
495 // function to derive the attribute 'value' for 'attr'. Returns OK
496 // iff the attribute is given by the function's definition.
497 // TODO(irving): Remove; keep only the const Node& version.
498 template <typename T>
499 Status GetAttr(const NodeDef& ndef, const std::string& attr, T* value) const;
500
501 // Given a node, inspects attributes of the callee function to derive the
502 // attribute 'value' for 'attr'. Returns OK iff the attribute is given by the
503 // function's definition.
504 template <typename T>
505 Status GetAttr(const Node& node, const std::string& attr, T* value) const;
506
507 // Returns a proto representation of the state of this function library.
508 FunctionDefLibrary ToProto() const TF_LOCKS_EXCLUDED(mu_);
509
num_functions()510 size_t num_functions() const {
511 tf_shared_lock l(mu_);
512 return function_defs_.size();
513 }
514
515 // Returns all the function names in the FunctionLibraryDefinition.
516 std::vector<string> ListFunctionNames() const TF_LOCKS_EXCLUDED(mu_);
517
default_registry()518 const OpRegistryInterface* default_registry() const {
519 return default_registry_;
520 }
set_default_registry(const OpRegistryInterface * registry)521 void set_default_registry(const OpRegistryInterface* registry) {
522 default_registry_ = registry;
523 }
524
525 // Returns a copy of `*this` with only the subset of functions that are
526 // reachable from the nodes of `graph` or `func`.
527 FunctionLibraryDefinition ReachableDefinitions(const GraphDef& graph) const;
528 FunctionLibraryDefinition ReachableDefinitions(const FunctionDef& func) const;
529
530 // Copies the function named `func` from `other` to this
531 // FunctionLibraryDefinition.
532 // REQUIRES: `this->default_registry() == other.default_registry()`.
533 // Returns OK on success, or error otherwise. This is a no-op if a function
534 // name `func` already exists in this function library, and has the same
535 // implementation as in `other`. If the implementations conflict, an invalid
536 // argument error is returned.
537 Status CopyFunctionDefFrom(const std::string& func,
538 const FunctionLibraryDefinition& other)
539 TF_LOCKS_EXCLUDED(mu_);
540
541 // Returns graph with debug stack traces for the given function, or `nullptr`
542 // if none found.
GetStackTraces(const std::string & func_name)543 const StackTracesMap& GetStackTraces(const std::string& func_name) const {
544 tf_shared_lock l(mu_);
545 std::shared_ptr<FunctionDefAndOpRegistration> entry = FindHelper(func_name);
546 if (entry) {
547 return entry->stack_traces;
548 }
549 static const auto* empty_map = new StackTracesMap;
550 return *empty_map;
551 }
552
553 private:
554 // Shape inference for functions is handled separately by ShapeRefiner.
555
556 struct FunctionDefAndOpRegistration {
557 explicit FunctionDefAndOpRegistration(
558 const FunctionDef& fdef_in, const StackTracesMap& stack_traces = {});
559
560 const FunctionDef fdef;
561 const OpRegistrationData op_registration_data;
562 const StackTracesMap stack_traces;
563 };
564
565 std::shared_ptr<FunctionDefAndOpRegistration> FindHelper(
566 const string& func) const TF_SHARED_LOCKS_REQUIRED(mu_);
567 std::string FindGradientHelper(const std::string& func) const
568 TF_SHARED_LOCKS_REQUIRED(mu_);
569
570 Status AddHelper(std::shared_ptr<FunctionDefAndOpRegistration> registration,
571 bool* added) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
572
573 // Same as AddFunctionDef/AddGradientDef except these methods set
574 // `added` to true if the `fdef`/`grad` were actually added to this.
575 Status AddFunctionDefHelper(const FunctionDef& fdef,
576 const StackTracesMap& stack_traces, bool* added)
577 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
578 Status AddGradientDefHelper(const GradientDef& grad, bool* added)
579 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
580
581 // Helper function for GetAttr. Returns the FunctionDef* to get the
582 // attr from.
583 const FunctionDef* GetAttrImpl(const NodeDef& ndef) const
584 TF_LOCKS_EXCLUDED(mu_);
585
586 // Remove all functions in `funcs` and all gradients of functions in
587 // `funcs_with_grads` from this library.
588 Status Remove(const std::vector<string>& funcs,
589 const std::vector<string>& funcs_with_grads)
590 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
591
592 // Remove `func` from the library. Returns non-OK Status unless `func` is in
593 // the library. This should only be called when there is a guarantee that the
594 // function being removed hasn't been retrieved with `Find`.
595 Status RemoveFunctionHelper(const std::string& func)
596 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
597
598 // Remove gradient of function `func` from the library. Returns non-OK Status
599 // unless `func` has a gradient.
600 Status RemoveGradient(const std::string& func)
601 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
602
603 mutable mutex mu_;
604 const OpRegistryInterface* default_registry_;
605 gtl::FlatMap<string, std::shared_ptr<FunctionDefAndOpRegistration>>
606 function_defs_ TF_GUARDED_BY(mu_);
607 gtl::FlatMap<string, string> func_grad_ TF_GUARDED_BY(mu_);
608 };
609
610 // Forward declare. Defined in common_runtime/function.h
611 struct FunctionBody;
612
613 // Forward declare. Defined in common_runtime/device.h
614 class Device;
615 // Forward declare. Defined in common_runtime/device_mgr.h
616 class DeviceMgr;
617
// Identifies an _Arg node: the node's "Index" attribute, plus an optional
// sub-index that distinguishes arguments sharing the same index when one
// _Arg node stands for several arguments (e.g. after the node has been
// replicated to multiple devices/subgraphs).
struct FunctionArgIndex {
  explicit FunctionArgIndex(const int index) : FunctionArgIndex(index, -1) {}
  FunctionArgIndex(const int index, const int sub_index)
      : index(index), sub_index(sub_index) {}

  int index;           // Value of the _Arg node's "Index" attribute.
  int sub_index = -1;  // Stays -1 unless the node represents several args.
};
631
632 class FunctionLibraryRuntime {
633 public:
~FunctionLibraryRuntime()634 virtual ~FunctionLibraryRuntime() {}
635
636 // Instantiate a function with the given "attrs".
637 //
638 // Returns OK and fills in "handle" if the instantiation succeeds.
639 // Otherwise returns an error and "handle" is undefined.
640 struct InstantiateOptions {
641 // The canonical device name of the device on which the function
642 // should be instantiated. If empty, the function will be
643 // instantiated on the local device.
644 std::string target;
645
646 // Should the function be instantiated as a multi-device function?
647 bool is_multi_device_function = false;
648
649 // If true, graph passes will be skipped when instantiating the function
650 // since they have already run on the main function side.
651 bool is_component_function = false;
652
653 // For multi-device functions, a vector of canonical device names for
654 // function's inputs. The device of resource inputs must be the device
655 // backing the resource, not the CPU device backing the resource handle.
656 // Must have the same length as number of inputs to the function.
657 std::vector<string> input_devices;
658
659 // For multi-device functions, a vector of canonical device names for
660 // function's outputs.
661 //
662 // (a) If specified (must have the same length as number of outputs):
663 //
664 // Specified devices will be assigned to Retval nodes inserted into the
665 // function body graph in place of function outputs. It is allowed to
666 // specify output device as empty string, in this case Retval device
667 // assignment will be inferred later when function graph will be placed
668 // before partitioning (this is required for resource outputs). Placer will
669 // respect colocation constraints.
670 //
671 // (b) If not specified:
672 //
673 // Function runtime will infer Retval device by following input edges, until
674 // it will reach a node with a device specification. This device
675 // specification must identify a unique device, i.e. a general specification
676 // like "job:foo" matching multiple devices will result in an error.
677 //
678 // IMPORTANT: Resource outputs
679 //
680 // Multi device functions might return resources on a devices different from
681 // the function call device. If output device is not specified for the
682 // resource output, and node producing that resource is a function call,
683 // runtime will leave device specification empty and will rely on Placer to
684 // infer correct device.
685 std::vector<string> output_devices;
686
687 // If set, it indicates the original output indices of a component function.
688 absl::optional<std::vector<int>> ret_indices = absl::nullopt;
689
690 // Maps from a CompositeDevice name to a list of underlying physical
691 // devices.
692 absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
693
694 // This interface is EXPERIMENTAL and subject to change.
695 //
696 // For multi-device functions, a mapping from _Arg node index to type and
697 // shape for input resources.
698 // REQUIRES: if input_resource_dtypes_and_shapes.count(i) > 0 then i-th
699 // argument type must be DT_RESOURCE.
700 std::unordered_map<int, DtypeAndPartialTensorShape>
701 input_resource_dtypes_and_shapes;
702
703 // This interface is EXPERIMENTAL and subject to change.
704 //
705 // If non-null, the runtime will use `lib_def` to resolve function(s) named
706 // in `function_name` and `attrs`. Otherwise, the runtime will use its
707 // internal library.
708 //
709 // NOTE(mrry): If provided, all functions defined in `lib_def` must be
710 // self-contained, and cannot refer to functions defined in other libraries.
711 const FunctionLibraryDefinition* lib_def = nullptr;
712
713 // This interface is EXPERIMENTAL and subject to change.
714 //
715 // If non-empty, the runtime will use `state_handle` to identify
716 // cached state related the instantiated function. Two functions
717 // of the same name and attrs, instantiated with the same
718 // `state_handle` will have the same handle and share the same
719 // state (in stateful kernels); and two functions with different
720 // values for `state_handle` will have independent state.
721 std::string state_handle;
722
723 // This interface is EXPERIMENTAL and subject to change.
724 //
725 // Instantiates the function using an executor of the given type. If empty,
726 // the default TensorFlow executor will be used.
727 std::string executor_type;
728
729 // If true, the runtime will attempt to create kernels for the function at
730 // instantiation time, rather than on the first run. This can be used to
731 // surface errors earlier.
732 bool create_kernels_eagerly = false;
733
734 // This interface is EXPERIMENTAL and subject to change.
735 //
736 // Instantiates the function with the provided config_proto.
737 ConfigProto config_proto;
738
739 // If provided, this optimization function will be invoked before
740 // the placer for multi-device functions.
741 std::function<Status(std::vector<string> /*ret_node_names*/,
742 std::vector<string> /*keep_node_names*/,
743 FunctionLibraryDefinition*, const DeviceSet&,
744 Device* /*cpu_device*/, std::unique_ptr<Graph>*)>
745 optimize_graph_fn;
746
747 // If set, partitioned functions will be added to `graph_collector`.
748 // `graph_collector` must be alive during the call to Instantiate.
749 GraphCollector* graph_collector = nullptr;
750
751 // Indicates whether the multi-device function backend should default the
752 // placement of ops without request device to `target`.
753 bool default_device_to_target = true;
754
755 // If true, the optimized Graph will be stored so that
756 // `FunctionLibraryRuntime::DebugString(handle)` contains the optimized
757 // Graph. Otherwise, the unoptimized function Graph will be returned.
758 bool include_optimized_graph_in_debug_string = false;
759
// If true, the function library runtime caches the function instantiation.
761 bool use_function_cache = false;
762 };
// Opaque identifier for an instantiated function. It is produced by
// Instantiate() and consumed by ReleaseHandle(), GetFunctionBody(),
// GetRetTypes(), Run(), RunSync() and DebugString().
typedef uint64 Handle;

// Instantiates the function named `function_name` with attribute values
// `attrs`, customized by `options`, and stores the resulting handle in
// `*handle`. Instantiating the same name and attrs with the same
// `options.state_handle` yields the same handle (see InstantiateOptions).
virtual Status Instantiate(const std::string& function_name, AttrSlice attrs,
                           const InstantiateOptions& options,
                           Handle* handle) = 0;
Instantiate(const std::string & function_name,AttrSlice attrs,Handle * handle)767 Status Instantiate(const std::string& function_name, AttrSlice attrs,
768 Handle* handle) {
769 auto opts = absl::make_unique<InstantiateOptions>();
770 return Instantiate(function_name, attrs, *opts, handle);
771 }
772
// Releases state associated with `handle`, a value previously produced by
// Instantiate(). NOTE(review): whether `handle` may still be used after a
// successful release is not specified here -- presumably not.
virtual Status ReleaseHandle(Handle handle) = 0;

// Returns the function body for the instantiated function given its
// handle `h`. Returns nullptr if `h` is not found.
//
// *this keeps ownership of the returned object, which remains alive
// as long as *this.
virtual const FunctionBody* GetFunctionBody(Handle h) = 0;

// Returns the return types for the function identified by handle `h`,
// writing them to `*ret_types`.
virtual Status GetRetTypes(Handle h, DataTypeVector* ret_types) = 0;
785
786 // Asynchronously invokes the instantiated function identified by
787 // "handle".
788 //
789 // If function execution succeeds, "done" is called with OK and
790 // "*rets" is filled with the function's return values. Otherwise,
791 // "done" is called with an error status.
792 //
793 // Does not take ownership of "rets".
794 // In the cross-process scenario, runner isn't used for making the Async
795 // RPC calls.
796 struct Options {
OptionsOptions797 Options() {}
OptionsOptions798 explicit Options(const int64_t step_id) : step_id(step_id) {}
799 // Choose a step ID that is guaranteed not to clash with any
800 // Session-generated step ID. DirectSession only generates
801 // non-negative step IDs (contiguous, starting from 0), and
802 // MasterSession generates 56-bit random step IDs whose MSB is
803 // always 0, so a negative random step ID should suffice.
804 const int64 step_id = -std::abs(static_cast<int64>(random::New64()));
805
806 // op_id of the function running in eager mode. Set when we want to copy
807 // remote outputs lazily. All components of a remote multi-device function
808 // should use the same op_id, in order to correctly map remote output
809 // tensors to the remote TensorHandles in the default device.
810 absl::optional<int64> op_id = absl::nullopt;
811
812 RendezvousInterface* rendezvous = nullptr;
813 CancellationManager* cancellation_manager = nullptr;
814 CollectiveExecutor* collective_executor = nullptr;
815 ScopedStepContainer* step_container = nullptr;
816 StepStatsCollectorInterface* stats_collector = nullptr;
817 CoordinationServiceAgent* coordination_service_agent = nullptr;
818
819 std::function<void(std::function<void()>)>* runner = nullptr;
820
821 // Parameters for remote function execution.
822 bool remote_execution = false;
823 std::string source_device = ""; // Fully specified device name.
824
825 // Allocator attributes specifying where the args are / rets should be put.
826 // These should either be {} or match the length of args / retvals. If {},
827 // the default allocator attributes will be assumed for all args / retvals.
828 std::vector<AllocatorAttributes> args_alloc_attrs;
829 std::vector<AllocatorAttributes> rets_alloc_attrs;
830
831 // If true, we create a new IntraProcessRendezvous, else use the existing
832 // one.
833 bool create_rendezvous = false;
834
835 // If True, allow returning dead tensors.
836 bool allow_dead_tensors = false;
837
838 // If True, hint that all kernels should be treated as "inexpensive", and
839 // hence executed on the scheduling thread.
840 bool run_all_kernels_inline = false;
841
842 // If not null, use this thread pool for intra op scheduling.
843 thread::ThreadPoolInterface* user_intra_op_threadpool = nullptr;
844
845 // Returns a human readable representation of this.
846 std::string DebugString() const;
847 };
// Completion callback for the asynchronous Run() overloads; it receives the
// final status of the invocation.
typedef std::function<void(const Status&)> DoneCallback;
virtual void Run(const Options& opts, Handle handle,
                 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                 DoneCallback done) = 0;
// Variant of Run() that passes arguments and return values through
// `call_frame` instead of `args` / `rets` -- presumably; TODO confirm
// against CallFrameInterface.
virtual void Run(const Options& opts, Handle handle,
                 CallFrameInterface* call_frame, DoneCallback done) = 0;

// Synchronous counterparts of the Run() overloads above: they report the
// outcome as a returned Status instead of through a DoneCallback.
// NOTE(review): `opts` is taken by value here (unlike Run()); presumably so
// implementations can adjust a local copy -- TODO confirm.
virtual Status RunSync(Options opts, Handle handle,
                       gtl::ArraySlice<Tensor> args,
                       std::vector<Tensor>* rets) = 0;
virtual Status RunSync(Options opts, Handle handle,
                       CallFrameInterface* call_frame) = 0;
860
// Creates a "kernel" for the given NodeProperties `props`.
//
// On success, returns OK and the caller takes ownership of the
// returned `*kernel`. Otherwise, returns an error.
virtual Status CreateKernel(
    const std::shared_ptr<const NodeProperties>& props,
    OpKernel** kernel) = 0;

// Returns true iff the function named `function_name` is stateful.
//
// NOTE(mrry): This method assumes that the runtime is associated with a
// default function library, and looks up `function_name` in that library.
// It does not support overriding the function library.
virtual bool IsStateful(const std::string& function_name) const = 0;

// Returns the device on which the function executes.
virtual Device* device() = 0;
virtual const Device* device() const = 0;

// Returns the default runner in which the ops should be launched. If the
// device on which the function executes has a private thread pool, returns
// a runner on that device-local thread pool.
virtual std::function<void(std::function<void()>)>* runner() = 0;

// Returns the DeviceMgr from which the device was obtained.
virtual const DeviceMgr* device_mgr() const = 0;

// Returns the function library definition that backs this runtime.
//
// NOTE(mrry): The returned library definition is the default function library
// for this runtime. The caller may override the function library used by the
// runtime to instantiate functions, which will not be reflected in the return
// value of this function.
virtual const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
    const = 0;

// Returns the environment on which the function executes.
virtual Env* env() = 0;

// Returns the ConfigProto passed to the session used to create the function.
// NOTE(review): the trailing `const` qualifies the returned pointer value
// itself; on a by-value return it has no effect for callers (compilers may
// flag it with -Wignored-qualifiers).
virtual const ConfigProto* const config_proto() = 0;

// Returns a debug string showing the definition of the function of
// `handle`. If the function was instantiated with
// InstantiateOptions::include_optimized_graph_in_debug_string set, this
// shows the optimized graph; otherwise the unoptimized function graph.
virtual std::string DebugString(Handle handle) = 0;

// Returns the graph version number.
virtual int graph_def_version() const = 0;
909
// Handle type used by DistributedFunctionLibraryRuntime for functions
// instantiated on a remote target; like Handle, it is an unsigned 64-bit id.
typedef uint64 LocalHandle;

// Creates copies of three objects:
//   - a ProcessFunctionLibraryRuntime (ownership transferred to the caller),
//   - a FunctionLibraryRuntime (owned by the returned
//     ProcessFunctionLibraryRuntime),
//   - a FunctionLibraryDefinition (ownership transferred to the caller).
// Both the ProcessFunctionLibraryRuntime and the FunctionLibraryRuntime
// borrow a pointer to the FunctionLibraryDefinition, so the
// FunctionLibraryDefinition must outlive both.
//
// The `skip_flib_def` argument controls whether the method clones the
// FunctionLibraryDefinition (default behavior) or returns an empty function
// library. The latter is used by tf.data, which manages
// FunctionLibraryDefinitions for its functions independently (and passes
// them into the FunctionLibraryRuntime through an overlay), to avoid runtime
// linear in the number of functions in the current function library.
//
// NOTE(review): the default argument on this virtual is bound by the static
// type at the call site, so overrides should not declare a different default.
virtual Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
                     std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
                     FunctionLibraryRuntime** out_flr,
                     bool skip_flib_def = false) = 0;

// Returns the name of the executor class (in the sense of
// `ExecutorFactory::GetFactory()`) that will be used based on the given
// dynamic `options` and static `attrs`. If none is specified, this method
// will return an empty string, which leaves the decision up to the runtime.
static std::string ExecutorType(const InstantiateOptions& options,
                                AttrSlice attrs);
937 };
938
// Returns the device of the `arg_index`-th function input. Updates
// `composite_devices` if the input device is a composite device.
std::string GetFunctionResourceInputDevice(
    const Tensor& input, const int arg_index, const FunctionDef& function_def,
    absl::flat_hash_map<string, std::vector<string>>* composite_devices);

// Returns a canonicalized string for the instantiation of the function with
// the given name `funcname`, attributes `attrs`, and `options`.
//
// The returned string is guaranteed to be stable within one address space,
// but it may change as the implementation evolves. Therefore, it should not
// be persisted or compared across address spaces.
std::string Canonicalize(
    const std::string& funcname, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options);
// Overload without InstantiateOptions; presumably equivalent to passing
// default-constructed options -- TODO confirm against the definition.
std::string Canonicalize(const std::string& funcname, AttrSlice attrs);
955
956 const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
957 const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1;
958
959 class CustomKernelCreator {
960 public:
~CustomKernelCreator()961 virtual ~CustomKernelCreator() {}
962
963 // Given a NodeDef 'node_def' and the function library runtime 'flr',
964 // validate if the class supports creating such a kernel.
965 virtual bool CanCreateKernel(
966 const FunctionLibraryRuntime& flr,
967 const std::shared_ptr<const NodeProperties>& props) const = 0;
968
969 // Given a supported NodeDef, returns a kernel that computes the node.
970 virtual Status CreateKernel(
971 FunctionLibraryRuntime* flr,
972 const std::shared_ptr<const NodeProperties>& props,
973 std::unique_ptr<OpKernel>* kernel) const = 0;
974 };
975
976 typedef
977 #if !defined(IS_MOBILE_PLATFORM)
978 absl::variant<Tensor, eager::RemoteTensorHandle*>
979 FunctionArg;
980 #else
981 absl::variant<Tensor>
982 FunctionArg;
983 #endif
984
985 // Either a local tensor or the shape of a remote tensor.
986 typedef absl::variant<Tensor, TensorShape> FunctionRet;
987
988 // Used to instantiate and run functions in a distributed system.
989 class DistributedFunctionLibraryRuntime {
990 public:
~DistributedFunctionLibraryRuntime()991 virtual ~DistributedFunctionLibraryRuntime() {}
992
993 // Instantiate a function on a remote target specified in `options.target`, by
994 // sending the name and definition of the function to the remote worker. The
995 // local `handle` is filled for the instantiated function data and can be used
996 // for subsequent run function calls on the remote target.
997 virtual void Instantiate(
998 const std::string& function_name,
999 const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
1000 const FunctionLibraryRuntime::InstantiateOptions& options,
1001 FunctionLibraryRuntime::LocalHandle* handle,
1002 FunctionLibraryRuntime::DoneCallback done) = 0;
1003
1004 // Run an instantiated remote function (specified by `handle`) with a list of
1005 // input Tensors in `args` and get its output Tensors in `rets`. The input
1006 // tensor data will be sent with the function execution request, and must be
1007 // available on the current caller side.
1008 // opts.runner isn't used for execution.
1009 virtual void Run(const FunctionLibraryRuntime::Options& opts,
1010 FunctionLibraryRuntime::LocalHandle handle,
1011 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
1012 FunctionLibraryRuntime::DoneCallback done) = 0;
1013
1014 // Run an instantiated remote function (specified by `handle`) with a list of
1015 // input Tensors or RemoteTensorHandles as `args` and get its output Tensors
1016 // or TensorShapes in `rets`. When using RemoteTensorHandles as function
1017 // inputs or TensorShapes as outputs, the corresponding tensor data will be
1018 // resolved on the remote worker, so it is not required to be locally
1019 // available on the caller side. Using RemoteTensorHandle inputs is not
1020 // supported in TensorFlow v1 runtime.
1021 virtual void Run(const FunctionLibraryRuntime::Options& opts,
1022 FunctionLibraryRuntime::LocalHandle handle,
1023 gtl::ArraySlice<FunctionArg> args,
1024 std::vector<FunctionRet>* rets,
1025 FunctionLibraryRuntime::DoneCallback done) = 0;
1026
1027 // Clean up a previously instantiated function on remote worker.
1028 virtual void CleanUp(uint64 step_id,
1029 FunctionLibraryRuntime::LocalHandle handle,
1030 FunctionLibraryRuntime::DoneCallback done) = 0;
1031
1032 // DeviceMgr with *all* available devices (i.e., local and remote).
1033 virtual DeviceMgr* remote_device_mgr() const = 0;
1034 };
1035
// Extracts the actual type from `attrs` based on its definition
// `arg_def`.
//
// If `arg_def` is an N*T type, *is_type_list is set to false, and
// *dtypes is set to a vector of size N whose elements are all T.
//
// If `arg_def` is a list(type), *is_type_list is set to true, and
// *dtypes is set to the vector of types specified in `attrs` for
// `arg_def`.
//
// Otherwise (`arg_def` is a simple type T), *is_type_list is set to
// false, and *dtypes is set to a single-element vector whose only
// element is T.
//
// NOTE(review): the conditions under which a non-OK Status is returned
// (e.g. missing attrs) are not visible here.
Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
                  bool* is_type_list, DataTypeVector* dtypes);
1051
1052 // To register a gradient function for a builtin op, one should use
1053 // REGISTER_OP_GRADIENT(<op_name>, <c++ grad factory>);
1054 //
// Typically, the c++ grad factory is a plain function that can be
// converted into ::tensorflow::gradient::Creator, which is
// std::function<Status(const AttrSlice&, FunctionDef*)>.
1058 //
// A ::tensorflow::gradient::Creator should populate the FunctionDef* with the
// definition of a function that computes the gradient of <op_name> when
// <op_name> is instantiated with the given attrs.
1062 //
1063 // E.g.,
1064 //
1065 // Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
1066 // bool transpose_a;
1067 // TF_RETURN_IF_ERROR(attrs.Get("transpose_a", &transpose_a));
1068 // bool transpose_b;
1069 // TF_RETURN_IF_ERROR(attrs.Get("transpose_b", &transpose_b));
1070 // DataType dtype;
1071 // TF_RETURN_IF_ERROR(attrs.Get("dtype", &dtype));
1072 // if (!transpose_a && !transpose_b) {
1073 // *g = FunctionDefHelper::Define(
1074 // "MatMulGrad",
// {"x:T", "y:T", "dz:T"}, // Inputs to this function
1076 // {"dx:T", "dy:T"}, // Outputs from this function
1077 // {"T: {float, double}"}, // Attributes needed by this function
1078 // {
1079 // {{"x_t"}, "Transpose", {"x"}, {{"T", "$T"}}},
1080 // {{"y_t"}, "Transpose", {"y"}, {{"T", "$T"}}},
1081 // {{"dx"}, "MatMul", {"dz", "y_t"}, {{"T", "$T"}}},
// {{"dy"}, "MatMul", {"x_t", "dz"}, {{"T", "$T"}}},
1083 // });
1084 // } else {
1085 // ... ...
1086 // }
1087 // return Status::OK();
1088 // }
1089 //
1090 // NOTE: $T is substituted with the type variable "T" when the
1091 // gradient function MatMul is instantiated.
1092 //
1093 // TODO(zhifengc): Better documentation somewhere.
1094
// Macros to define a gradient function factory for a primitive
// operation.
//
// REGISTER_OP_GRADIENT(name, fn) registers `fn` as the gradient creator for
// the op `name`.
#define REGISTER_OP_GRADIENT(name, fn) \
  REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, fn)

// REGISTER_OP_NO_GRADIENT(name) registers a nullptr creator for the op
// `name` (see gradient::GetOpGradientCreator below).
#define REGISTER_OP_NO_GRADIENT(name) \
  REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, nullptr)

// Extra level of indirection: macro arguments that are operands of `##` are
// not macro-expanded, so routing `__COUNTER__` through this helper first
// makes it expand to an integer literal before REGISTER_OP_GRADIENT_UNIQ
// pastes it into a unique variable name.
#define REGISTER_OP_GRADIENT_UNIQ_HELPER(ctr, name, fn) \
  REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)

// Defines a uniquely named, otherwise unused static bool whose initializer
// performs the registration. Because of the `&&` short-circuit,
// ::tensorflow::gradient::RegisterOp is only called when
// SHOULD_REGISTER_OP_GRADIENT is true. SHOULD_REGISTER_OP_GRADIENT is not
// defined in this file; it is presumably provided by
// registration/registration.h.
#define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)      \
  static bool unused_grad_##ctr TF_ATTRIBUTE_UNUSED = \
      SHOULD_REGISTER_OP_GRADIENT &&                  \
      ::tensorflow::gradient::RegisterOp(name, fn)
1110
1111 namespace gradient {
1112 // Register a gradient creator for the "op".
1113 typedef std::function<Status(const AttrSlice& attrs, FunctionDef*)> Creator;
1114 bool RegisterOp(const std::string& op, Creator func);
1115
1116 // Returns OK the gradient creator for the "op" is found (may be
1117 // nullptr if REGISTER_OP_NO_GRADIENT is used.
1118 Status GetOpGradientCreator(const std::string& op, Creator* creator);
1119 }; // namespace gradient
1120
1121 // Declare explicit instantiations of GetAttr
1122 #define GET_ATTR(T) \
1123 extern template Status FunctionLibraryDefinition::GetAttr( \
1124 const Node&, const string&, T*) const; \
1125 extern template Status FunctionLibraryDefinition::GetAttr( \
1126 const NodeDef&, const string&, T*) const;
1127 GET_ATTR(string)
1128 GET_ATTR(bool)
1129 #undef GET_ATTR
1130
1131 } // end namespace tensorflow
1132
1133 #endif // TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_
1134