/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_
#define TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_

#include <vector>

// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/protobuf/config.pb.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
#endif  // IS_MOBILE_PLATFORM

namespace tensorflow {

class CancellationManager;
class CollectiveExecutor;
class DeviceSet;
class Graph;
class GraphDef;
class OpKernel;
class ProcessFunctionLibraryRuntime;
class ResourceMgr;
class Rendezvous;
class ScopedStepContainer;
class StepStatsCollectorInterface;
class Node;

// FunctionDefHelper::Create is a convenient helper to construct a
// FunctionDef proto.
// E.g.,
//   FunctionDef my_func = FunctionDefHelper::Create(
//     "my_func_name",
//     {"x:T", "y:T" /* one string per argument */},
//     {"z:T" /* one string per return value */},
//     {"T: {float, double}" /* one string per attribute */},
//     {
//        {{"o"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
//        /* one entry per function node */
//     },
//     /* Mapping between function returns and function node outputs. */
//     {{"z", "o:z"}});
//
// For the old Function::Node approach, use FunctionDefHelper::Define()
// E.g.,
//   FunctionDef my_func = FunctionDefHelper::Define(
//     "my_func_name",
//     {"x:T", "y:T" /* one string per argument */},
//     {"z:T" /* one string per return value */},
//     {"T: {float, double}" /* one string per attribute */},
//     {
//        {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
//        /* one entry per function node */
//     });
class FunctionDefHelper {
 public:
  // AttrValueWrapper has converting constructors for the type T so that
  // it's easy to construct a simple AttrValue proto.
  //
  // If T is a string type (const char*, string, or StringPiece), and
  // it starts with "$", we construct an AttrValue of "placeholder".
  //
  // E.g.,
  //   std::pair<string, AttrValueWrapper> x = {"T", "$T"};
  // is a named attr value placeholder.
  struct AttrValueWrapper {
    AttrValue proto;

    AttrValueWrapper() {}

    template <typename T>
    AttrValueWrapper(T val) {  // NOLINT(runtime/explicit)
      SetAttrValue(val, &proto);
    }

   private:
    void InitFromString(StringPiece val);
  };

  // Constructs an AttrValue.func given the "name" and "attrs".
  static AttrValueWrapper FunctionRef(
      const std::string& name,
      gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs);
  static AttrValueWrapper FunctionRef(const std::string& name) {
    return FunctionRef(name, {});
  }

  // Node is used to construct FunctionDef.Node using initialization
  // lists. E.g.,
  //  Node n = {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}};  // z = x * y
  struct Node {
    // When constructing a NodeDef, the first entry in ret is used as
    // the node name, the remaining values are ignored.
    std::vector<string> ret;
    std::string op;
    std::vector<string> arg;
    std::vector<std::pair<string, AttrValueWrapper>> attr;
    std::vector<string> dep;
    std::string device;

    NodeDef ToNodeDef() const;
  };

  // Creates a FunctionDef from the given parameters. Node inputs must use
  // function encoding (node_name:output_name[:output_index]).
  // - `ret_def` holds a mapping from the function output names from `out_def`
  //   to the node outputs from `node_def`.
  // - `control_ret_def` holds a mapping from the function control
  //   output names to the nodes from `node_def`.
  static FunctionDef Create(
      const std::string& function_name, gtl::ArraySlice<string> in_def,
      gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
      gtl::ArraySlice<Node> node_def,
      gtl::ArraySlice<std::pair<string, string>> ret_def,
      gtl::ArraySlice<std::pair<string, string>> control_ret_def);

  // Creates a FunctionDef from the given parameters. Node inputs must use
  // function encoding (node_name:output_name[:output_index]).
  // - `ret_def` holds a mapping from the function output names from `out_def`
  //   to the node outputs from `node_def`.
  static FunctionDef Create(const std::string& function_name,
                            gtl::ArraySlice<string> in_def,
                            gtl::ArraySlice<string> out_def,
                            gtl::ArraySlice<string> attr_def,
                            gtl::ArraySlice<Node> node_def,
                            gtl::ArraySlice<std::pair<string, string>> ret_def);

  // TODO(josh11b): Get rid of these and transition to the one above.
  static FunctionDef Define(const std::string& function_name,
                            gtl::ArraySlice<string> arg_def,
                            gtl::ArraySlice<string> ret_def,
                            gtl::ArraySlice<string> attr_def,
                            gtl::ArraySlice<Node> node_def);

  // Defines an anonymous function. I.e., its name is not relevant.
  static FunctionDef Define(gtl::ArraySlice<string> arg_def,
                            gtl::ArraySlice<string> ret_def,
                            gtl::ArraySlice<string> attr_def,
                            gtl::ArraySlice<Node> node_def);

  // Helpers to construct a constant scalar.
  template <typename T>
  static Node Const(const std::string& name, const T& val) {
    Node n = {{name}, "Const"};
    const DataType dtype = DataTypeToEnum<T>::value;
    n.attr.push_back({"dtype", dtype});
    Tensor t(dtype, TensorShape({}));
    t.scalar<T>()() = val;
    n.attr.push_back({"value", t});
    return n;
  }

  template <typename T>
  static Node Const(const std::string& name, gtl::ArraySlice<T> vals) {
    Node n = {{name}, "Const"};
    const DataType dtype = DataTypeToEnum<T>::value;
    n.attr.push_back({"dtype", dtype});
    int64 num = vals.size();
    Tensor t(dtype, TensorShape({num}));
    for (size_t i = 0; i < vals.size(); ++i) {
      t.flat<T>()(i) = vals[i];
    }
    n.attr.push_back({"value", t});
    return n;
  }
};
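
// Illustrative use of the Const helpers above (a sketch, not part of the API;
// the node names "two" and "shape" are placeholders):
//
//   // Scalar float constant named "two".
//   FunctionDefHelper::Node c1 = FunctionDefHelper::Const("two", 2.0f);
//   // 1-D int32 constant named "shape" holding {2, 3}.
//   FunctionDefHelper::Node c2 =
//       FunctionDefHelper::Const<int32>("shape", {2, 3});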

template <>
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(const char* val) {
  InitFromString(val);
}

template <>
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(
    const std::string& val) {
  InitFromString(val);
}

template <>
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) {
  InitFromString(val);
}

// Instantiate a function.
//
// "fdef" encodes a TF function with some attrs in fdef.signature.attr
// containing placeholders. InstantiateFunction binds these
// placeholders and produces an instantiated function encoded in
// "result.gdef". The value to substitute for a placeholder is given by
// "attr_values", which is a map from a placeholder name to an attr
// value.
//
// InstantiateFunction calls "get_function" to find signatures of other
// functions and primitive ops.

// GetFunctionSignature(func name, opdef) returns OK if the func name is found
// and opdef is filled with a pointer to the corresponding signature
// (an OpDef proto). Otherwise, returns an error.
typedef std::function<Status(const string&, const OpDef**)>
    GetFunctionSignature;

struct InstantiationResult {
  DataTypeVector arg_types;
  DataTypeVector ret_types;
  std::vector<NodeDef> nodes;
};
Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
                           GetFunctionSignature get_function,
                           InstantiationResult* result);
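
// Illustrative sketch: instantiating `fdef` against the global op registry
// (assumes `fdef` and the AttrValueMap `attr_values` are defined elsewhere):
//
//   InstantiationResult result;
//   GetFunctionSignature get_sig = [](const string& op, const OpDef** sig) {
//     return OpRegistry::Global()->LookUpOpDef(op, sig);
//   };
//   TF_RETURN_IF_ERROR(
//       InstantiateFunction(fdef, AttrSlice(&attr_values), get_sig, &result));
//   // result.nodes now holds the instantiated NodeDefs, and
//   // result.arg_types/result.ret_types hold the bound input/output types.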

// Returns a debug string for a function definition.
//
// The returned text is multi-line. It is intended to be
// human-readable rather than being friendly to parsers. It is _NOT_
// intended to be the canonical string representation of "func_def".
// In particular, it may not include all information presented in
// "func_def" (e.g., comments, description of the function arguments,
// etc.)
std::string DebugString(const FunctionDef& func_def);
std::string DebugString(const GraphDef& instantiated_func_def);
std::string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes);

// Returns a debug string for a top level graph (the main program and
// its supporting functions defined in its library).
std::string DebugStringWhole(const GraphDef& gdef);

// Returns true if f1 == f2. Compares all fields, including descriptions. Order
// of NodeDefs doesn't matter.
bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2);

// Returns a hash of `fdef` that is consistent with FunctionDefsEqual. In
// other words, if two fdefs compare equal, their hash values will be the
// same.
uint64 FunctionDefHash(const FunctionDef& fdef);

class CallFrameInterface {
 public:
  virtual ~CallFrameInterface() {}

  virtual size_t num_args() const = 0;
  virtual size_t num_retvals() const = 0;

  virtual Status GetArg(int index, const Tensor** val) = 0;

  // Optimized implementation of `GetArg()` that allows the caller to take
  // ownership of the tensor. This method may only be called once per
  // value of `index` and `CallFrameInterface` instance.
  //
  // REQUIRES: `this->CanConsumeArg(index) == true`.
  virtual void ConsumeArg(int index, Tensor* val) {
    LOG(ERROR) << "This `CallFrameInterface` implementation does not support "
                  "consuming arguments.";
  }
  virtual bool CanConsumeArg(int index) const { return false; }

  virtual Status SetRetval(int index, const Tensor& val) = 0;
};

// Represents a function call frame. I.e., the data structure used to
// pass arguments to a function and retrieve its results.
//
// The runtime must arrange accesses to one FunctionCallFrame such that
//   1. SetArgs() happens before any GetArg();
//   2. GetRetvals() happens after all SetRetval()s.
class FunctionCallFrame : public CallFrameInterface {
 public:
  FunctionCallFrame(DataTypeSlice arg_types, DataTypeSlice ret_types);
  ~FunctionCallFrame() override;

  // Caller methods.
  Status SetArgs(gtl::ArraySlice<Tensor> args);
  Status GetRetvals(std::vector<Tensor>* rets) const;

  // Moves the return values from the frame to rets. If allow_dead_tensors is
  // false it will fail if any of the retvals do not have a value.
  Status ConsumeRetvals(std::vector<Tensor>* rets, bool allow_dead_tensors);

  size_t num_args() const override { return arg_types_.size(); }
  size_t num_retvals() const override { return ret_types_.size(); }

  // Callee methods.
  Status GetArg(int index, const Tensor** val) override;
  Status SetRetval(int index, const Tensor& val) override;

 private:
  DataTypeVector arg_types_;
  DataTypeVector ret_types_;
  gtl::InlinedVector<Tensor, 4> args_;
  struct Retval {
    bool has_val = false;
    Tensor val;
  };
  gtl::InlinedVector<Retval, 4> rets_;

  TF_DISALLOW_COPY_AND_ASSIGN(FunctionCallFrame);
};
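
// Illustrative caller-side sketch for a function taking and returning one
// float (not part of the API; the callee side is elided):
//
//   FunctionCallFrame frame({DT_FLOAT}, {DT_FLOAT});
//   Tensor arg(DT_FLOAT, TensorShape({}));
//   arg.scalar<float>()() = 1.0f;
//   TF_RETURN_IF_ERROR(frame.SetArgs({arg}));
//   // ... run the callee, which calls GetArg()/SetRetval() on the frame ...
//   std::vector<Tensor> rets;
//   TF_RETURN_IF_ERROR(frame.GetRetvals(&rets));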

// Language agnostic stack traces.
class AbstractStackTrace {
 public:
  struct TracePrintingOptions {
    // Show inline the contents of each stack line.
    bool show_line_contents = false;

    // Drop the common largest prefix of all filenames in stack frames.
    bool filter_common_prefix = false;

    // Do not show internal frames.
    bool drop_internal_frames = false;
  };

  virtual ~AbstractStackTrace() {}

  // The returned span is alive as long as the AbstractStackTrace is alive.
  virtual absl::Span<StackFrame const> ToFrames() const = 0;

  // Returns the last stack frame from user code, attempting to ignore the
  // framework code. Returns an empty frame if no such stack frame was found.
  virtual StackFrame LastUserFrame() const = 0;
  virtual std::string ToString(const TracePrintingOptions& opts) const = 0;
};

using StackTracesMap =
    std::unordered_map<std::string,
                       std::shared_ptr<tensorflow::AbstractStackTrace>>;

// Helper to maintain a map between function names in a given
// FunctionDefLibrary and function definitions.
//
// This class is thread-safe.
class FunctionLibraryDefinition : public OpRegistryInterface {
 public:
  // Ops created for function arguments bear the name given by `kArgOp`; those
  // created for return values bear the name given by `kRetOp`.
  static constexpr const char* const kArgOp = "_Arg";
  static constexpr const char* const kDeviceArgOp = "_DeviceArg";
  static constexpr const char* const kRetOp = "_Retval";
  static constexpr const char* const kDeviceRetOp = "_DeviceRetval";
  static constexpr const char* const kIntsOnDeviceAttr =
      "experimental_ints_on_device";
  static constexpr const char* const kSharedRendezvousAttr =
      "shared_rendezvous";

  static constexpr const char* const kGradientOp = "SymbolicGradient";
  static constexpr const char* const kFuncAttr = "f";

  // Note: This constructor grabs `lib_def`'s lock in shared mode.
  FunctionLibraryDefinition(const FunctionLibraryDefinition& lib_def);
  FunctionLibraryDefinition(const OpRegistryInterface* default_registry,
                            const FunctionDefLibrary& lib_def);
  ~FunctionLibraryDefinition() override;

  FunctionLibraryDefinition& operator=(const FunctionLibraryDefinition&) =
      delete;

  // Returns true if the library contains `func`, false otherwise.
  bool Contains(const std::string& func) const;

  // Returns nullptr if "func" is not defined in "lib_def". Otherwise,
  // returns its definition proto.
  //
  // NB: This function returns a borrowed pointer, which can be invalidated by
  // a subsequent call to `ReplaceFunction()` with the given name.
  const FunctionDef* Find(const std::string& func) const TF_LOCKS_EXCLUDED(mu_);

  // Adds function definition 'fdef' to this function library.
  // Returns status 'ok' on success, or error otherwise. This is a no-op if
  // 'fdef' already exists in this function library.
  // If 'fdef' is successfully added to the library, it will be accessible
  // from 'LookUp' and included in the proto returned by 'ToProto'.
  // This operation is atomic.
  //
  // Associates `stack_traces` with the added function, making them
  // retrievable later via 'GetStackTraces'.
  Status AddFunctionDef(const FunctionDef& fdef,
                        const StackTracesMap& stack_traces = {})
      TF_LOCKS_EXCLUDED(mu_);

  // Adds gradient definition 'grad' to this function library.
  // This is a no-op if 'grad' already exists in this function library.
  // If 'grad' is successfully added, it will be accessible via 'FindGradient'
  // and included in the proto returned by 'ToProto'.
  // This operation is atomic.
  Status AddGradientDef(const GradientDef& grad) TF_LOCKS_EXCLUDED(mu_);

  // Replaces the function corresponding to `func` with `fdef`. Returns
  // a non-OK status if "func" was not found in the library, OK otherwise.
  // Please be careful when replacing a function: make sure all previous
  // pointers returned by `Find()` are no longer in use.
  Status ReplaceFunction(const std::string& func, const FunctionDef& fdef)
      TF_LOCKS_EXCLUDED(mu_);

  // Replaces the gradient corresponding to `grad.function_name()`. Returns
  // a non-OK status if "grad.function_name()" was not found in the library, OK
  // otherwise.
  Status ReplaceGradient(const GradientDef& grad) TF_LOCKS_EXCLUDED(mu_);

  // Removes the function corresponding to 'func'. Returns a non-OK status if
  // 'func' was not found in the library, OK otherwise.
  // Please be careful when removing a function: make sure there are no other
  // nodes using the function, and all previous pointers returned by `Find()`
  // are no longer in use.
  Status RemoveFunction(const std::string& func) TF_LOCKS_EXCLUDED(mu_);

  // Removes all the functions and gradient functions.
  void Clear() TF_LOCKS_EXCLUDED(mu_);

  // Adds the functions and gradients in 'other' to this function library.
  // Duplicate functions and gradients are ignored.
  // This operation is atomic.
  Status AddLibrary(const FunctionLibraryDefinition& other)
      TF_LOCKS_EXCLUDED(mu_);

  // Adds the functions and gradients in 'lib_def' to this function library.
  // Duplicate functions and gradients are ignored.
  // This operation is atomic.
  Status AddLibrary(const FunctionDefLibrary& lib_def) TF_LOCKS_EXCLUDED(mu_);

  // If the gradient function for 'func' is specified explicitly in
  // the library, returns the gradient function name. Otherwise,
  // returns an empty string.
  std::string FindGradient(const std::string& func) const
      TF_LOCKS_EXCLUDED(mu_);

  // OpRegistryInterface method. Useful for constructing a Graph.
  //
  // If "op" is defined in the library, returns its signature.
  // Otherwise, assumes "op" is a primitive op and returns its op
  // signature and shape inference function.
  //
  // NB: This function outputs a borrowed pointer, which can be invalidated by
  // a subsequent call to `ReplaceFunction()` with the given name.
  Status LookUp(const std::string& op_type_name,
                const OpRegistrationData** op_reg_data) const override
      TF_LOCKS_EXCLUDED(mu_);

  // Generates a new function name with the specified prefix that is unique
  // across this library.
  std::string UniqueFunctionName(StringPiece prefix) const
      TF_LOCKS_EXCLUDED(mu_);

  // Given a node def 'ndef', inspects attributes of the callee
  // function to derive the attribute 'value' for 'attr'. Returns OK
  // iff the attribute is given by the function's definition.
  // TODO(irving): Remove; keep only the const Node& version.
  template <typename T>
  Status GetAttr(const NodeDef& ndef, const std::string& attr, T* value) const;

  // Given a node, inspects attributes of the callee function to derive the
  // attribute 'value' for 'attr'. Returns OK iff the attribute is given by the
  // function's definition.
  template <typename T>
  Status GetAttr(const Node& node, const std::string& attr, T* value) const;

  // Returns a proto representation of the state of this function library.
  FunctionDefLibrary ToProto() const TF_LOCKS_EXCLUDED(mu_);

  size_t num_functions() const {
    tf_shared_lock l(mu_);
    return function_defs_.size();
  }

  // Returns all the function names in the FunctionLibraryDefinition.
  std::vector<string> ListFunctionNames() const TF_LOCKS_EXCLUDED(mu_);

  const OpRegistryInterface* default_registry() const {
    return default_registry_;
  }

  // Returns a copy of `*this` with only the subset of functions that are
  // reachable from the nodes of `graph` or `func`.
  FunctionLibraryDefinition ReachableDefinitions(const GraphDef& graph) const;
  FunctionLibraryDefinition ReachableDefinitions(const FunctionDef& func) const;

  // Copies the function named `func` from `other` to this
  // FunctionLibraryDefinition.
  // REQUIRES: `this->default_registry() == other.default_registry()`.
  // Returns OK on success, or error otherwise. This is a no-op if a function
  // named `func` already exists in this function library and has the same
  // implementation as in `other`. If the implementations conflict, an invalid
  // argument error is returned.
  Status CopyFunctionDefFrom(const std::string& func,
                             const FunctionLibraryDefinition& other)
      TF_LOCKS_EXCLUDED(mu_);

  // Returns the debug stack traces for the given function, or an empty map if
  // none are found.
  const StackTracesMap& GetStackTraces(const std::string& func_name) const {
    tf_shared_lock l(mu_);
    std::shared_ptr<FunctionDefAndOpRegistration> entry = FindHelper(func_name);
    if (entry) {
      return entry->stack_traces;
    }
    static const auto* empty_map = new StackTracesMap;
    return *empty_map;
  }

 private:
  // Shape inference for functions is handled separately by ShapeRefiner.

  struct FunctionDefAndOpRegistration {
    explicit FunctionDefAndOpRegistration(
        const FunctionDef& fdef_in, const StackTracesMap& stack_traces = {});

    const FunctionDef fdef;
    const OpRegistrationData op_registration_data;
    const StackTracesMap stack_traces;
  };

  std::shared_ptr<FunctionDefAndOpRegistration> FindHelper(
      const string& func) const TF_SHARED_LOCKS_REQUIRED(mu_);
  std::string FindGradientHelper(const std::string& func) const
      TF_SHARED_LOCKS_REQUIRED(mu_);

  Status AddHelper(std::shared_ptr<FunctionDefAndOpRegistration> registration,
                   bool* added) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Same as AddFunctionDef/AddGradientDef except these methods set
  // `added` to true if the `fdef`/`grad` were actually added to this.
  Status AddFunctionDefHelper(const FunctionDef& fdef,
                              const StackTracesMap& stack_traces, bool* added)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  Status AddGradientDefHelper(const GradientDef& grad, bool* added)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Helper function for GetAttr. Returns the FunctionDef* to get the
  // attr from.
  const FunctionDef* GetAttrImpl(const NodeDef& ndef) const
      TF_LOCKS_EXCLUDED(mu_);

  // Remove all functions in `funcs` and all gradients of functions in
  // `funcs_with_grads` from this library.
  void Remove(const std::vector<string>& funcs,
              const std::vector<string>& funcs_with_grads)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Remove `func` from the library. Returns a non-OK Status unless `func` is
  // in the library. This should only be called when there is a guarantee that
  // the function being removed hasn't been retrieved with `Find`.
  Status RemoveFunctionHelper(const std::string& func)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Remove the gradient of function `func` from the library. Returns a non-OK
  // Status unless `func` has a gradient.
  Status RemoveGradient(const std::string& func)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  mutable mutex mu_;
  const OpRegistryInterface* const default_registry_;
  gtl::FlatMap<string, std::shared_ptr<FunctionDefAndOpRegistration>>
      function_defs_ TF_GUARDED_BY(mu_);
  gtl::FlatMap<string, string> func_grad_ TF_GUARDED_BY(mu_);
};
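
// Illustrative sketch of building and querying a library (the function name
// "my_func" is a placeholder):
//
//   FunctionDefLibrary lib_proto;
//   *lib_proto.add_function() = FunctionDefHelper::Create(/* "my_func"... */);
//   FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib_proto);
//   if (lib_def.Contains("my_func")) {
//     const FunctionDef* fdef = lib_def.Find("my_func");  // borrowed
//     // ... use *fdef, but drop the pointer before ReplaceFunction() or
//     // RemoveFunction() is called for "my_func".
//   }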

// Forward declare. Defined in common_runtime/function.h
struct FunctionBody;

// Forward declare. Defined in common_runtime/device.h
class Device;
// Forward declare. Defined in common_runtime/device_mgr.h
class DeviceMgr;

// Index of an _Arg node.
struct FunctionArgIndex {
  explicit FunctionArgIndex(const int index) : index(index) {}
  FunctionArgIndex(const int index, const int sub_index)
      : index(index), sub_index(sub_index) {}

  // The value of the attribute "Index" of the _Arg node.
  int index;
  // Set only when the _Arg node represents multiple arguments (e.g. an _Arg
  // node is replicated to multiple devices/subgraphs). Use sub-index to
  // distinguish arguments with the same index.
  int sub_index = -1;
};

class FunctionLibraryRuntime {
 public:
  virtual ~FunctionLibraryRuntime() {}

  // Instantiate a function with the given "attrs".
  //
  // Returns OK and fills in "handle" if the instantiation succeeds.
  // Otherwise returns an error and "handle" is undefined.
  struct InstantiateOptions {
    // The canonical device name of the device on which the function
    // should be instantiated. If empty, the function will be
    // instantiated on the local device.
    std::string target;

    // Should the function be instantiated as a multi-device function?
    bool is_multi_device_function = false;

    // If true, graph passes will be skipped when instantiating the function
    // since they have already run on the main function side.
    bool is_component_function = false;

    // For multi-device functions, a vector of canonical device names for the
    // function's inputs. The device of a resource input must be the device
    // backing the resource, not the CPU device backing the resource handle.
    // Must have the same length as the number of inputs to the function.
    std::vector<string> input_devices;

    // For multi-device functions, a vector of canonical device names for the
    // function's outputs.
    //
    // (a) If specified (must have the same length as number of outputs):
    //
    // The specified devices will be assigned to the Retval nodes inserted into
    // the function body graph in place of function outputs. An output device
    // may be given as an empty string, in which case the Retval device
    // assignment will be inferred later, when the function graph is placed
    // before partitioning (this is required for resource outputs). The Placer
    // will respect colocation constraints.
    //
    // (b) If not specified:
    //
    // The function runtime will infer the Retval device by following input
    // edges until it reaches a node with a device specification. This device
    // specification must identify a unique device; a general specification
    // like "job:foo" matching multiple devices will result in an error.
    //
    // IMPORTANT: Resource outputs
    //
    // Multi-device functions might return resources on a device different
    // from the function call device. If the output device is not specified
    // for a resource output, and the node producing that resource is a
    // function call, the runtime will leave the device specification empty
    // and will rely on the Placer to infer the correct device.
    std::vector<string> output_devices;

    // If set, it indicates the original output indices of a component
    // function.
    absl::optional<std::vector<int>> ret_indices = absl::nullopt;

    // Maps from a CompositeDevice name to a list of underlying physical
    // devices.
    absl::flat_hash_map<string, const std::vector<string>*> composite_devices;

    // This interface is EXPERIMENTAL and subject to change.
    //
    // For multi-device functions, a mapping from _Arg node index to type and
    // shape for input resources.
    // REQUIRES: if input_resource_dtypes_and_shapes.count(i) > 0 then the i-th
    // argument type must be DT_RESOURCE.
    std::unordered_map<int, DtypeAndPartialTensorShape>
        input_resource_dtypes_and_shapes;

    // This interface is EXPERIMENTAL and subject to change.
    //
    // If non-null, the runtime will use `lib_def` to resolve function(s) named
    // in `function_name` and `attrs`. Otherwise, the runtime will use its
    // internal library.
    //
    // NOTE(mrry): If provided, all functions defined in `lib_def` must be
    // self-contained, and cannot refer to functions defined in other
    // libraries.
    const FunctionLibraryDefinition* lib_def = nullptr;

    // This interface is EXPERIMENTAL and subject to change.
    //
    // If non-empty, the runtime will use `state_handle` to identify
    // cached state related to the instantiated function. Two functions
    // of the same name and attrs, instantiated with the same
    // `state_handle`, will have the same handle and share the same
    // state (in stateful kernels); two functions with different
    // values for `state_handle` will have independent state.
    std::string state_handle;

    // This interface is EXPERIMENTAL and subject to change.
    //
    // Instantiates the function using an executor of the given type. If
    // empty, the default TensorFlow executor will be used.
    std::string executor_type;

    // If true, the runtime will attempt to create kernels for the function at
    // instantiation time, rather than on the first run. This can be used to
    // surface errors earlier.
    bool create_kernels_eagerly = false;

    // This interface is EXPERIMENTAL and subject to change.
    //
    // Instantiates the function with the provided config_proto.
    ConfigProto config_proto;

    // If provided, this optimization function will be invoked before
    // the placer for multi-device functions.
    std::function<Status(std::vector<string> /*ret_node_names*/,
                         std::vector<string> /*keep_node_names*/,
                         FunctionLibraryDefinition*, const DeviceSet&,
                         Device* /*cpu_device*/, std::unique_ptr<Graph>*)>
        optimize_graph_fn;

    // If set, partitioned functions will be added to `graph_collector`.
    // `graph_collector` must be alive during the call to Instantiate.
    GraphCollector* graph_collector = nullptr;

    // Indicates whether the multi-device function backend should default the
    // placement of ops without a requested device to `target`.
    bool default_device_to_target = true;

    // If true, the optimized Graph will be stored so that
    // `FunctionLibraryRuntime::DebugString(handle)` contains the optimized
    // Graph. Otherwise, the unoptimized function Graph will be returned.
    bool include_optimized_graph_in_debug_string = false;
  };
  typedef uint64 Handle;
  virtual Status Instantiate(const std::string& function_name, AttrSlice attrs,
                             const InstantiateOptions& options,
                             Handle* handle) = 0;
  Status Instantiate(const std::string& function_name, AttrSlice attrs,
                     Handle* handle) {
    auto opts = absl::make_unique<InstantiateOptions>();
    return Instantiate(function_name, attrs, *opts, handle);
  }

  // Releases state associated with the handle.
  virtual Status ReleaseHandle(Handle handle) = 0;

  // Returns the function body for the instantiated function given its
  // handle 'h'. Returns nullptr if "h" is not found.
  //
  // *this keeps the ownership of the returned object, which remains alive
  // as long as *this.
  virtual const FunctionBody* GetFunctionBody(Handle h) = 0;

  // Returns the return types for the function identified by handle `h`.
  virtual Status GetRetTypes(Handle h, DataTypeVector* ret_types) = 0;

  // Asynchronously invokes the instantiated function identified by
  // "handle".
  //
  // If function execution succeeds, "done" is called with OK and
  // "*rets" is filled with the function's return values. Otherwise,
  // "done" is called with an error status.
  //
  // Does not take ownership of "rets".
  // In the cross-process scenario, the runner isn't used for making the async
  // RPC calls.
  struct Options {
    Options() {}
    explicit Options(const int64 step_id) : step_id(step_id) {}
    // Choose a step ID that is guaranteed not to clash with any
    // Session-generated step ID. DirectSession only generates
    // non-negative step IDs (contiguous, starting from 0), and
    // MasterSession generates 56-bit random step IDs whose MSB is
    // always 0, so a negative random step ID should suffice.
    const int64 step_id = -std::abs(static_cast<int64>(random::New64()));

    // op_id of the function running in eager mode. Set when we want to copy
    // remote outputs lazily. All components of a remote multi-device function
    // should use the same op_id, in order to correctly map remote output
    // tensors to the remote TensorHandles in the default device.
    absl::optional<int64> op_id = absl::nullopt;

    RendezvousInterface* rendezvous = nullptr;
    CancellationManager* cancellation_manager = nullptr;
    CollectiveExecutor* collective_executor = nullptr;
    ScopedStepContainer* step_container = nullptr;
    StepStatsCollectorInterface* stats_collector = nullptr;

    std::function<void(std::function<void()>)>* runner = nullptr;

    // Parameters for remote function execution.
    bool remote_execution = false;
    std::string source_device = "";  // Fully specified device name.

    // Allocator attributes specifying where the args are / rets should be
    // put. These should either be {} or match the length of args / retvals.
    // If {}, the default allocator attributes will be assumed for all args /
    // retvals.
    std::vector<AllocatorAttributes> args_alloc_attrs;
    std::vector<AllocatorAttributes> rets_alloc_attrs;

    // If true, we create a new IntraProcessRendezvous, else use the existing
    // one.
    bool create_rendezvous = false;

    // If true, allow returning dead tensors.
    bool allow_dead_tensors = false;

    // If true, hint that all kernels should be treated as "inexpensive", and
    // hence executed on the scheduling thread.
    bool run_all_kernels_inline = false;

    // Returns a human-readable representation of this.
    std::string DebugString() const;
  };
  typedef std::function<void(const Status&)> DoneCallback;
  virtual void Run(const Options& opts, Handle handle,
                   gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                   DoneCallback done) = 0;
  virtual void Run(const Options& opts, Handle handle,
                   CallFrameInterface* call_frame, DoneCallback done) = 0;

  virtual Status RunSync(Options opts, Handle handle,
                         gtl::ArraySlice<Tensor> args,
                         std::vector<Tensor>* rets) = 0;
  virtual Status RunSync(Options opts, Handle handle,
                         CallFrameInterface* call_frame) = 0;

  // Creates a "kernel" for the given NodeProperties "props".
  //
  // If it succeeds, returns OK and the caller takes ownership of the
  // returned "*kernel". Otherwise, returns an error.
  virtual Status CreateKernel(
      const std::shared_ptr<const NodeProperties>& props,
      OpKernel** kernel) = 0;

  // Returns true iff the function named `function_name` is stateful.
  //
  // NOTE(mrry): This method assumes that the runtime is associated with a
  // default function library, and looks up `function_name` in that library.
  // It does not support overriding the function library.
  virtual bool IsStateful(const std::string& function_name) const = 0;

  // Returns the device on which the function executes.
  virtual Device* device() = 0;
  virtual const Device* device() const = 0;

  // Returns the default runner in which the ops should be launched. If the
  // device on which the function executes has a private thread pool, returns
  // the runner of the device-local thread pool.
  virtual std::function<void(std::function<void()>)>* runner() = 0;

  // Get the DeviceMgr from which the device was obtained.
  virtual const DeviceMgr* device_mgr() const = 0;

  // Returns the function library definition that backs this runtime.
  //
  // NOTE(mrry): The returned library definition is the default function
  // library for this runtime. The caller may override the function library
  // used by the runtime to instantiate functions, which will not be reflected
  // in the return value of this function.
  virtual const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const = 0;

  // Returns the environment on which the function executes.
  virtual Env* env() = 0;

  // Returns the ConfigProto passed to the session used to create the
  // function.
  virtual const ConfigProto* const config_proto() = 0;

  // Returns a debug string showing the definition of the function of
  // 'handle'.
  virtual std::string DebugString(Handle handle) = 0;

  // Returns the graph version number.
  virtual int graph_def_version() const = 0;

  typedef uint64 LocalHandle;

  // Creates a copy of ProcessFunctionLibraryRuntime (transferring ownership to
  // the caller), FunctionLibraryRuntime (owned by the returned
  // ProcessFunctionLibraryRuntime), and FunctionLibraryDefinition
  // (transferring ownership to the caller). Note that both the
  // ProcessFunctionLibraryRuntime and the FunctionLibraryRuntime borrow a
  // pointer to the FunctionLibraryDefinition, so the
  // FunctionLibraryDefinition should outlive both.
  //
  // The `skip_flib_def` argument controls whether the method should clone the
  // FunctionLibraryDefinition (default behavior) or return an empty function
  // library. The latter is used by tf.data, which manages
  // FunctionLibraryDefinitions for its functions independently (and passes
  // these into the FunctionLibraryRuntime through an overlay), to avoid linear
  // runtime w.r.t. the number of functions in the current function library.
  virtual Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
                       std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
                       FunctionLibraryRuntime** out_flr,
                       bool skip_flib_def = false) = 0;

  // Returns the name of the executor class (in the sense of
  // `ExecutorFactory::GetFactory()`) that will be used based on the given
  // dynamic `options` and static `attrs`. If none is specified, this method
  // will return an empty string, which leaves the decision up to the runtime.
  static std::string ExecutorType(const InstantiateOptions& options,
                                  AttrSlice attrs);
};
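
// Illustrative end-to-end sketch (assumes `flr` points to a
// FunctionLibraryRuntime whose library defines "my_func", and `arg` is an
// input Tensor; names are placeholders):
//
//   FunctionLibraryRuntime::Handle handle;
//   TF_RETURN_IF_ERROR(flr->Instantiate("my_func", AttrSlice(), &handle));
//   FunctionLibraryRuntime::Options opts;
//   std::vector<Tensor> rets;
//   Status status;
//   Notification done;
//   flr->Run(opts, handle, {arg}, &rets,
//            [&](const Status& s) { status = s; done.Notify(); });
//   done.WaitForNotification();
//   TF_RETURN_IF_ERROR(status);
//   TF_RETURN_IF_ERROR(flr->ReleaseHandle(handle));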

// Returns the device of the `arg_index`-th function input. Updates
// `composite_devices` if the input device is a composite device.
std::string GetFunctionResourceInputDevice(
    const Tensor& input, const int arg_index, const FunctionDef& function_def,
    absl::flat_hash_map<string, std::vector<string>>* composite_devices);

// Returns a canonicalized string for the instantiation of the
// function of the given "name", attributes "attrs", and "options".
//
// The returned string is guaranteed to be stable within one address
// space. But it may change as the implementation
// evolves. Therefore, it should not be persisted or compared across
// address spaces.
std::string Canonicalize(
    const std::string& funcname, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options);
std::string Canonicalize(const std::string& funcname, AttrSlice attrs);

const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1;

class CustomKernelCreator {
 public:
  virtual ~CustomKernelCreator() {}

  // Given a NodeDef 'node_def' and the function library runtime 'flr',
  // validates whether the class supports creating such a kernel.
  virtual bool CanCreateKernel(
      const FunctionLibraryRuntime& flr,
      const std::shared_ptr<const NodeProperties>& props) const = 0;

  // Given a supported NodeDef, returns a kernel that computes the node.
  virtual Status CreateKernel(
      FunctionLibraryRuntime* flr,
      const std::shared_ptr<const NodeProperties>& props,
      std::unique_ptr<OpKernel>* kernel) const = 0;
};

typedef
#if !defined(IS_MOBILE_PLATFORM)
    absl::variant<Tensor, eager::RemoteTensorHandle*>
        FunctionArg;
#else
    absl::variant<Tensor>
        FunctionArg;
#endif

// Either a local tensor or the shape of a remote tensor.
typedef absl::variant<Tensor, TensorShape> FunctionRet;

// Used to instantiate and run functions in a distributed system.
class DistributedFunctionLibraryRuntime {
 public:
  virtual ~DistributedFunctionLibraryRuntime() {}

  // Instantiates a function on a remote target specified in `options.target`,
  // by sending the name and definition of the function to the remote worker.
  // The local `handle` is filled in for the instantiated function data and
  // can be used for subsequent Run calls on the remote target.
  virtual void Instantiate(
      const std::string& function_name,
      const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
      const FunctionLibraryRuntime::InstantiateOptions& options,
      FunctionLibraryRuntime::LocalHandle* handle,
      FunctionLibraryRuntime::DoneCallback done) = 0;

  // Runs an instantiated remote function (specified by `handle`) with a list
  // of input Tensors in `args` and gets its output Tensors in `rets`. The
  // input tensor data will be sent with the function execution request, and
  // must be available on the current caller side.
  // opts.runner isn't used for execution.
  virtual void Run(const FunctionLibraryRuntime::Options& opts,
                   FunctionLibraryRuntime::LocalHandle handle,
                   gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                   FunctionLibraryRuntime::DoneCallback done) = 0;

  // Runs an instantiated remote function (specified by `handle`) with a list
  // of input Tensors or RemoteTensorHandles as `args` and gets its output
  // Tensors or TensorShapes in `rets`. When using RemoteTensorHandles as
  // function inputs or TensorShapes as outputs, the corresponding tensor data
  // will be resolved on the remote worker, so it is not required to be
  // locally available on the caller side. Using RemoteTensorHandle inputs is
  // not supported in the TensorFlow v1 runtime.
  virtual void Run(const FunctionLibraryRuntime::Options& opts,
                   FunctionLibraryRuntime::LocalHandle handle,
                   gtl::ArraySlice<FunctionArg> args,
                   std::vector<FunctionRet>* rets,
                   FunctionLibraryRuntime::DoneCallback done) = 0;

  // Cleans up a previously instantiated function on the remote worker.
  virtual void CleanUp(uint64 step_id,
                       FunctionLibraryRuntime::LocalHandle handle,
                       FunctionLibraryRuntime::DoneCallback done) = 0;

  // DeviceMgr with *all* available devices (i.e., local and remote).
  virtual DeviceMgr* remote_device_mgr() const = 0;
};

// Extracts the actual type from "attr_values" based on its definition
// "arg_def".
//
// If "arg_def" is a N*T type, *is_type_list is set to false, and
// *dtypes is set to be a vector of size N whose elements are all T.
//
// If "arg_def" is a list(type), *is_type_list is set to true, and
// *dtypes is set to be a vector of the types specified in attrs for
// arg_def.
//
// Otherwise (arg_def is a simple type T), *is_type_list is set to
// false, and *dtypes is set to a single element vector, whose only
// element is T.
Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
                  bool* is_type_list, DataTypeVector* dtypes);
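
// Worked example (illustrative): for an arg_def declared as "x: N*T" with
// attrs {N=3, T=DT_FLOAT}, *is_type_list is set to false and *dtypes to
// {DT_FLOAT, DT_FLOAT, DT_FLOAT}. For an arg_def declared as "x: list(type)"
// whose type-list attr is bound to [DT_INT32, DT_BOOL], *is_type_list is set
// to true and *dtypes to {DT_INT32, DT_BOOL}.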

// To register a gradient function for a builtin op, one should use
//   REGISTER_OP_GRADIENT(<op_name>, <c++ grad factory>);
//
// Typically, the c++ grad factory is a plain function that can be
// converted into ::tensorflow::gradient::Creator, which is
// std::function<Status(const AttrSlice&, FunctionDef*)>.
//
// A ::tensorflow::gradient::Creator should populate the FunctionDef* with a
// definition of a function which computes the gradient for the
// <op_name> when the <op_name> is instantiated with the given attrs.
//
// E.g.,
//
// Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
//   bool transpose_a;
//   TF_RETURN_IF_ERROR(attrs.Get("transpose_a", &transpose_a));
//   bool transpose_b;
//   TF_RETURN_IF_ERROR(attrs.Get("transpose_b", &transpose_b));
//   DataType dtype;
//   TF_RETURN_IF_ERROR(attrs.Get("dtype", &dtype));
//   if (!transpose_a && !transpose_b) {
//     *g = FunctionDefHelper::Define(
//         "MatMulGrad",
//         {"x:T", "y:T", "dz:T"},    // Inputs to this function
//         {"dx:T", "dy:T"},          // Outputs from this function
//         {"T: {float, double}"},    // Attributes needed by this function
//         {
//             {{"x_t"}, "Transpose", {"x"}, {{"T", "$T"}}},
//             {{"y_t"}, "Transpose", {"y"}, {{"T", "$T"}}},
//             {{"dx"}, "MatMul", {"dz", "y_t"}, {{"T", "$T"}}},
//             {{"dy"}, "MatMul", {"x_t", "dz"}, {{"T", "$T"}}},
//         });
//   } else {
//     ... ...
//   }
//   return Status::OK();
// }
//
// NOTE: $T is substituted with the type variable "T" when the
// gradient function MatMulGrad is instantiated.
//
// TODO(zhifengc): Better documentation somewhere.

// Macros to define a gradient function factory for a primitive
// operation.
#define REGISTER_OP_GRADIENT(name, fn) \
  REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, fn)

#define REGISTER_OP_NO_GRADIENT(name) \
  REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, nullptr)

#define REGISTER_OP_GRADIENT_UNIQ_HELPER(ctr, name, fn) \
  REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)

#define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)      \
  static bool unused_grad_##ctr TF_ATTRIBUTE_UNUSED = \
      SHOULD_REGISTER_OP_GRADIENT &&                  \
      ::tensorflow::gradient::RegisterOp(name, fn)

namespace gradient {
// Registers a gradient creator for the "op".
typedef std::function<Status(const AttrSlice& attrs, FunctionDef*)> Creator;
bool RegisterOp(const std::string& op, Creator func);

// Returns OK if the gradient creator for the "op" is found (the creator may
// be nullptr if REGISTER_OP_NO_GRADIENT was used).
Status GetOpGradientCreator(const std::string& op, Creator* creator);
}  // namespace gradient
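
// Illustrative registrations (a sketch; "MyOp", "MyLoggingOp", and MyOpGrad
// are placeholders):
//
//   Status MyOpGrad(const AttrSlice& attrs, FunctionDef* g);  // grad factory
//   REGISTER_OP_GRADIENT("MyOp", MyOpGrad);
//
//   // Ops known to have no gradient can be marked explicitly:
//   REGISTER_OP_NO_GRADIENT("MyLoggingOp");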

// Declare explicit instantiations of GetAttr
#define GET_ATTR(T)                                          \
  extern template Status FunctionLibraryDefinition::GetAttr( \
      const Node&, const string&, T*) const;                 \
  extern template Status FunctionLibraryDefinition::GetAttr( \
      const NodeDef&, const string&, T*) const;
GET_ATTR(string)
GET_ATTR(bool)
#undef GET_ATTR

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_