• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
18 
19 #include "tensorflow/core/framework/attr_value.pb.h"
20 #include "tensorflow/core/framework/function.pb.h"
21 #include "tensorflow/core/framework/graph.pb.h"
22 #include "tensorflow/core/framework/node_def.pb.h"
23 #include "tensorflow/core/framework/tensor.pb.h"
24 #include "tensorflow/core/framework/tensor_shape.pb.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/grappler/mutable_graph_view.h"
27 #include "tensorflow/core/grappler/utils.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 
30 namespace tensorflow {
31 namespace grappler {
32 namespace function_utils {
33 // This namespace contains utility functions for querying and modifying
34 // FunctionDefs.
35 
36 // Describes a FunctionDef input tensor. In FunctionDefs, input tensor strings
37 // have the format node_name:node_output:position (if they derive from nodes),
38 // or input_name (if they derive from an argument).
39 struct FunctionDefTensorDesc {
40   FunctionDefTensorDesc() = default;
41 
42   FunctionDefTensorDesc(const string& node_name, const string& output,
43                         int position);
44 
45   // Parses node_name:node_output:position string into its components.
46   explicit FunctionDefTensorDesc(const string& input);
47 
48   // TODO(rachelim): Add provisions to deal with special formats, like how
49   // GrapplerFunctionItem expands node output range if position is not defined
50   string full_str;
51   string node_name;
52   string node_output;
53   int position = -1;
54 };
55 
56 // Replaces all references to `from` tensor in func's nodes' inputs and retvals
57 // to `to` tensor. This is similar to `MutableGraphView::ReplaceInputs`.
58 void ReplaceReferences(const string& from, const string& to, FunctionDef* func);
59 
60 // Adds a function output to the function def, ensuring that the output key
61 // is unique, and maps to output_tensor_name in the ret dict.
62 void AddFunctionOutputWithUniqueName(StringPiece prefix,
63                                      StringPiece output_tensor_name,
64                                      FunctionDef* function, DataType dt);
65 
66 // Adds a node to a FunctionDef.
67 NodeDef* AddNode(StringPiece name, StringPiece op,
68                  const std::vector<string>& inputs,
69                  const std::vector<std::pair<string, AttrValue>>& attributes,
70                  FunctionDef* fd);
71 
72 // Checks whether the function contains a node with the given name.
73 bool ContainsFunctionNodeWithName(StringPiece name,
74                                   const FunctionDef& function);
75 
76 // Checks whether the function contains a node with the given op.
77 bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
78 
79 // Checks whether the function contains an output with the given name.
80 bool ContainsFunctionOutputWithName(StringPiece name,
81                                     const FunctionDef& function);
82 
83 // Returns the index of the function input with the given name or -1 if the
84 // function node does not exist.
85 int FindFunctionInputWithName(StringPiece name, const FunctionDef& function);
86 
87 // Returns the index of the function output with the given name or -1 if the
88 // function node does not exist.
89 int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function);
90 
91 // Returns the index of the function node with the given name or -1 if the
92 // function node does not exist.
93 int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function);
94 
95 // Returns the index of the function node with the given op or -1 if the
96 // function node does not exist.
97 int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
98 
99 // Sets the function node name using the `prefix` as a prefix while guaranteeing
100 // the name is unique across the functions nodes.
101 void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
102                                NodeDef* node);
103 
104 // Checks if the function is stateful by checking the function graph for
105 // stateful ops. Because the "If" and "While" ops are conservatively marked as
106 // stateful, the check recurses into their graph to determine whether they are
107 // actually stateful. The `skip_assert` argument determines whether the "Assert"
108 // op should be treated as stateful or not.
109 bool IsFunctionStateful(const FunctionLibraryDefinition& library,
110                         const FunctionDef& function_def,
111                         bool skip_assert = false);
112 
113 // Checks if the node is stateful. Because the "If" or "While" ops are
114 // conservatively marked as stateful, the check recurses into their graph to
115 // determine whether they are actually stateful. The `skip_assert` argument
116 // determines whether the "Assert" op  should be treated as stateful or not.
117 bool IsNodeStateful(const FunctionLibraryDefinition& library,
118                     const NodeDef& node, bool skip_assert = false);
119 
120 }  // end namespace function_utils
121 }  // end namespace grappler
122 }  // end namespace tensorflow
123 
124 #endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
125