• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
18 
19 #include "tensorflow/core/framework/attr_value.pb.h"
20 #include "tensorflow/core/framework/function.pb.h"
21 #include "tensorflow/core/framework/graph.pb.h"
22 #include "tensorflow/core/framework/node_def.pb.h"
23 #include "tensorflow/core/framework/tensor.pb.h"
24 #include "tensorflow/core/framework/tensor_shape.pb.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/grappler/mutable_graph_view.h"
27 #include "tensorflow/core/grappler/utils.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 
30 namespace tensorflow {
31 namespace grappler {
32 namespace function_utils {
33 // This namespace contains utility functions for querying and modifying
34 // FunctionDefs.
35 
36 // Describes a FunctionDef input tensor. In FunctionDefs, an input tensor string
37 // has the format `node_name:node_output:position` (if it derives from a node),
38 // or `input_name` (if it derives from a function argument).
39 struct FunctionDefTensorDesc {
40   FunctionDefTensorDesc() = default;  // Leaves `position` at its default, -1.
41 
42   FunctionDefTensorDesc(const string& node_name, const string& output,
43                         int position);  // Builds a descriptor from components.
44 
45   // Parses a `node_name:node_output:position` string into its components.
46   explicit FunctionDefTensorDesc(const string& input);
47 
48   // TODO(rachelim): Add provisions to deal with special formats, like how
49   // GrapplerFunctionItem expands node output range if position is not defined.
50   string full_str;     // Presumably the complete tensor string (confirm in .cc).
51   string node_name;    // `node_name` component of the tensor string.
52   string node_output;  // `node_output` component of the tensor string.
53   int position = -1;   // `position` component; defaults to -1.
54 };
55 
56 // Rewrites all references to tensor `from` in `func`'s node inputs and retvals
57 // so they refer to tensor `to`. Similar to `MutableGraphView::ReplaceInputs`.
58 void ReplaceReferences(const string& from, const string& to, FunctionDef* func);
59 
60 // Adds a function output to `fdef`, ensuring that the output key is unique and
61 // that it maps to `output_tensor_name` in the ret dict.
62 void AddFunctionOutputWithUniqueName(StringPiece prefix,
63                                      StringPiece output_tensor_name,
64                                      FunctionDef* fdef, DataType dtype);
65 
66 // Adds an input `name` to `fdef`; presumably returns the new arg (confirm).
67 OpDef_ArgDef* AddFunctionInput(const string& name, FunctionDef* fdef,
68                                DataType dtype);
69 
70 // Adds a node to FunctionDef `fd`; presumably returns the new node (confirm).
71 NodeDef* AddNode(StringPiece name, StringPiece op,
72                  const std::vector<string>& inputs,
73                  const std::vector<std::pair<string, AttrValue>>& attributes,
74                  FunctionDef* fd);
75 
76 // Checks whether `function` contains a node whose name is `name`.
77 bool ContainsFunctionNodeWithName(StringPiece name,
78                                   const FunctionDef& function);
79 
80 // Checks whether `function` contains a node whose op is `op`.
81 bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
82 
83 // Checks whether `function` contains an output whose name is `name`.
84 bool ContainsFunctionOutputWithName(StringPiece name,
85                                     const FunctionDef& function);
86 
87 // Returns the index of the function input with the given name or -1 if the
88 // function input does not exist.
89 int FindFunctionInputWithName(StringPiece name, const FunctionDef& function);
90 
91 // Returns the index of the function output with the given name or -1 if the
92 // function output does not exist.
93 int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function);
94 
95 // Returns the index of the node in `function` named `name`, or -1 if no such
96 // node exists.
97 int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function);
98 
99 // Returns the index of the node in `function` with op `op`, or -1 if no such
100 // node exists.
101 int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
102 
103 // Sets the name of `node` using `prefix` as a prefix, while guaranteeing that
104 // the name is unique across the nodes of `function`.
105 void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
106                                NodeDef* node);
107 
108 // Checks if the function is stateful by checking the function graph for
109 // stateful ops. Because the "If" and "While" ops are conservatively marked as
110 // stateful, the check recurses into their graphs to determine whether they are
111 // actually stateful. The `skip_assert` argument (default: false) determines
112 // whether the "Assert" op should be treated as stateful or not.
113 bool IsFunctionStateful(const FunctionLibraryDefinition& library,
114                         const FunctionDef& function_def,
115                         bool skip_assert = false);
116 
117 // Checks if `node` is stateful. Because the "If" and "While" ops are
118 // conservatively marked as stateful, the check recurses into their graphs to
119 // determine whether they are actually stateful. `skip_assert` (default: false)
120 // determines whether the "Assert" op should be treated as stateful or not.
121 bool IsNodeStateful(const FunctionLibraryDefinition& library,
122                     const NodeDef& node, bool skip_assert = false);
123 
124 }  // end namespace function_utils
125 }  // end namespace grappler
126 }  // end namespace tensorflow
127 
128 #endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
129