• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_
18 
19 #include "tensorflow/core/framework/attr_value.pb.h"
20 #include "tensorflow/core/framework/function.pb.h"
21 #include "tensorflow/core/framework/graph.pb.h"
22 #include "tensorflow/core/framework/node_def.pb.h"
23 #include "tensorflow/core/framework/tensor.pb.h"
24 #include "tensorflow/core/framework/tensor_shape.pb.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/graph/graph.h"
27 #include "tensorflow/core/grappler/mutable_graph_view.h"
28 #include "tensorflow/core/grappler/utils.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 
31 namespace tensorflow {
32 namespace grappler {
33 namespace graph_utils {
34 
// Returns the index of the first element in `collection` for which
// `predicate(element)` is true, or -1 if no such element exists (including
// when `collection` is empty).
//
// `collection` may be any range usable in a range-based for loop; elements are
// visited in iteration order and the scan stops at the first match.
template <typename Predicate, typename Collection>
int GetFirstElementIndexWithPredicate(const Predicate& predicate,
                                      const Collection& collection) {
  // Signed counter of the same type as the result, so the returned index does
  // not go through an implicit unsigned-to-int narrowing conversion.
  int idx = 0;
  for (auto&& element : collection) {
    if (predicate(element)) {
      return idx;
    }
    ++idx;
  }
  return -1;
}
49 
50 // Adds a node to the graph.
51 NodeDef* AddNode(StringPiece name, StringPiece op,
52                  const std::vector<string>& inputs,
53                  const std::vector<std::pair<string, AttrValue>>& attributes,
54                  MutableGraphView* graph);
55 
56 // Adds Placeholder node for given type.
57 NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph);
58 
59 // Adds a Const node with the given value to the graph.
60 template <typename T>
AddScalarConstNode(T v,MutableGraphView * graph)61 NodeDef* AddScalarConstNode(T v, MutableGraphView* graph) {
62   // is_same is an idiomatic hack for making it compile if not instantiated.
63   // Replacing with false will result in a compile-time error.
64   static_assert(!std::is_same<T, T>::value,
65                 "Invalid specialization of this method for type T.");
66   return {};
67 }
68 
69 template <>
70 NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph);
71 template <>
72 NodeDef* AddScalarConstNode(double v, MutableGraphView* graph);
73 template <>
74 NodeDef* AddScalarConstNode(float v, MutableGraphView* graph);
75 template <>
76 NodeDef* AddScalarConstNode(int v, MutableGraphView* graph);
77 template <>
78 NodeDef* AddScalarConstNode(int64 v, MutableGraphView* graph);
79 template <>
80 NodeDef* AddScalarConstNode(StringPiece v, MutableGraphView* graph);
81 
82 // Checks whether the two graphs are the same.
83 bool Compare(const GraphDef& g1, const GraphDef& g2);
84 
85 // Checks whether the graph contains a node with the given name.
86 bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph);
87 
88 // Checks whether the library contains a function with the given name.
89 bool ContainsGraphFunctionWithName(StringPiece name,
90                                    const FunctionDefLibrary& library);
91 
92 // Checks whether the graph contains a node with the given op.
93 bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph);
94 
95 // Returns the index of the node with the given name or -1 if the node does
96 // not exist.
97 int FindGraphNodeWithName(StringPiece name, const GraphDef& graph);
98 
99 // Returns the index of the function with the given name or -1 if the function
100 // does not exist.
101 int FindGraphFunctionWithName(StringPiece name,
102                               const FunctionDefLibrary& library);
103 
104 // Returns the index of the first node with the given op or -1 if no such  node
105 // exists.
106 int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph);
107 
108 // Gets the 0th input to a node in the graph.
109 NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph);
110 
111 // Gets the ith input to a node in the graph.
112 NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph,
113                       int64 i);
114 
115 // Returns the list of indices of all nodes with the given op or empty list if
116 // no such node exists.
117 std::vector<int> FindAllGraphNodesWithOp(const string& op,
118                                          const GraphDef& graph);
119 
120 // Sets the node name using `prefix` as a prefix while guaranteeing the name
121 // is unique across the graph.
122 void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node);
123 
124 // Sets the function name using the `prefix` name as a prefix while guaranteeing
125 // the name is unique across the function library.
126 void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
127                                 FunctionDef* function);
128 
129 // Copies attribute having name `attribute_name` from node `from` to node
130 // `to_node`.
131 void CopyAttribute(const string& attribute_name, const NodeDef& from,
132                    NodeDef* to_node);
133 
134 // Concatenates list attribute having name `attribute_name` from `first` and
135 // `second` node, setting it to `to_node`.
136 void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
137                          const NodeDef& second, NodeDef* to_node);
138 
139 // Checks that all nodes in the graphs have unique names, and sets their names
140 // to be unique if they are not already.  This is necessary as Graph does not
141 // have the provisions to deduplicate names, and name deduplication elsewhere
142 // in tensorflow happens in other layers (for example, in the Scope class of the
143 // C++ API). Note that the nodes in the graph are identified by their id,
144 // and renaming nodes does not mutate any edges.
145 Status EnsureNodeNamesUnique(Graph* g);
146 
147 // Returns the sink node (i.e. last node) in the graph.
148 Status FindSinkNode(const GraphDef& graph_def, NodeDef* sink_node);
149 
150 }  // namespace graph_utils
151 }  // namespace grappler
152 }  // namespace tensorflow
153 
154 #endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_
155