• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
18 
19 #include <string>
20 #include <vector>
21 
22 #include "tensorflow/core/framework/attr_value_util.h"
23 #include "tensorflow/core/framework/types.h"
24 #include "tensorflow/core/lib/core/stringpiece.h"
25 #include "tensorflow/core/lib/gtl/flatmap.h"
26 #include "tensorflow/core/lib/hash/hash.h"
27 #include "tensorflow/core/platform/protobuf.h"
28 
29 namespace tensorflow {
30 
class Node;
struct NodeDebugInfo;

// We forward declare protos so that kernels don't need to depend on them.
class NodeDef;
class OpDef;

// Name of the attribute used to encode node colocation constraints.
//
// Nodes can be co-located on the same device. A desire for explicit
// co-location is expressed through a list(string) attribute containing the
// names of the colocation groups.
extern const char* const kColocationAttrName;

// String prefix applied to the operation name for colocation constraints.
extern const char* const kColocationGroupPrefix;
47 
48 // Produce a human-readable version of a Node or NodeDef that is more concise
49 // than a text-format proto.
50 string SummarizeNode(const Node& node);
51 string SummarizeNodeDef(const NodeDef& node_def);
52 string SummarizeAttrs(const NodeDef& node_def);
53 
54 // Produces a formatted string pattern from the node which can uniquely identify
55 // this node upstream to produce an informative error message. The pattern
56 // followed is: {{node <node_name>}}
57 string FormatNodeForError(const Node& node);
58 string FormatNodeDefForError(const NodeDef& node_def);
59 
60 // Merges the original node names from the debug information of 'from' to the
61 // debug information of 'to'.
62 void MergeDebugInfo(const NodeDebugInfo& from, Node* to);
63 void MergeDebugInfo(const NodeDebugInfo& from, NodeDef* to);
64 void MergeDebugInfo(const NodeDef& from, NodeDef* to);
65 
66 typedef protobuf::Map<string, AttrValue> AttrValueMap;
67 
68 // Adds an attr with name <name> and value <value> to *node_def.
69 // The type of the attr is based on the type of value.
70 void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def);
71 void AddNodeAttr(StringPiece name, StringPiece value, NodeDef* node_def);
72 void AddNodeAttr(StringPiece name, const char* value, NodeDef* node_def);
73 void AddNodeAttr(StringPiece name, int32 value, NodeDef* node_def);
74 void AddNodeAttr(StringPiece name, int64 value, NodeDef* node_def);
75 void AddNodeAttr(StringPiece name, float value, NodeDef* node_def);
76 void AddNodeAttr(StringPiece name, double value, NodeDef* node_def);
77 void AddNodeAttr(StringPiece name, bool value, NodeDef* node_def);
78 void AddNodeAttr(StringPiece name, DataType value, NodeDef* node_def);
79 void AddNodeAttr(StringPiece name, const PartialTensorShape& value,
80                  NodeDef* node_def);
81 void AddNodeAttr(StringPiece name, const Tensor& value, NodeDef* node_def);
82 void AddNodeAttr(StringPiece name, const TensorProto& value, NodeDef* node_def);
83 void AddNodeAttr(StringPiece name, const NameAttrList& value,
84                  NodeDef* node_def);
85 void AddNodeAttr(StringPiece name, gtl::ArraySlice<StringPiece> value,
86                  NodeDef* node_def);
87 void AddNodeAttr(StringPiece name, gtl::ArraySlice<const char*> value,
88                  NodeDef* node_def);
89 void AddNodeAttr(StringPiece name, gtl::ArraySlice<string> value,
90                  NodeDef* node_def);
91 void AddNodeAttr(StringPiece name, gtl::ArraySlice<int32> value,
92                  NodeDef* node_def);
93 void AddNodeAttr(StringPiece name, gtl::ArraySlice<int64> value,
94                  NodeDef* node_def);
95 void AddNodeAttr(StringPiece name, gtl::ArraySlice<float> value,
96                  NodeDef* node_def);
97 void AddNodeAttr(StringPiece name, gtl::ArraySlice<bool> value,
98                  NodeDef* node_def);
99 void AddNodeAttr(StringPiece name, const std::vector<bool>& value,
100                  NodeDef* node_def);
101 void AddNodeAttr(StringPiece name, gtl::ArraySlice<DataType> value,
102                  NodeDef* node_def);
103 void AddNodeAttr(StringPiece name, gtl::ArraySlice<TensorShape> value,
104                  NodeDef* node_def);
105 void AddNodeAttr(StringPiece name, gtl::ArraySlice<PartialTensorShape> value,
106                  NodeDef* node_def);
107 void AddNodeAttr(StringPiece name, gtl::ArraySlice<TensorShapeProto> value,
108                  NodeDef* node_def);
109 void AddNodeAttr(StringPiece name, gtl::ArraySlice<Tensor> value,
110                  NodeDef* node_def);
111 void AddNodeAttr(StringPiece name, gtl::ArraySlice<NameAttrList> value,
112                  NodeDef* node_def);
113 
114 // Version to workaround C++'s "perfect" forwarding not being able to
115 // forward {...} initialization.
116 template <class T>
AddNodeAttr(StringPiece name,std::initializer_list<T> value,NodeDef * node_def)117 void AddNodeAttr(StringPiece name, std::initializer_list<T> value,
118                  NodeDef* node_def) {
119   AddNodeAttr(name, gtl::ArraySlice<T>(value), node_def);
120 }
121 
122 // Adds an attr to an attr value map.
123 void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map);
124 void AddAttr(StringPiece name, bool value, AttrValueMap* map);
125 
126 class AttrSlice {
127  public:
128   AttrSlice(const NodeDef& node_def);  // NOLINT(runtime/explicit)
129 
130   AttrSlice();  // Empty
131   explicit AttrSlice(const AttrValueMap* a);
132 
size()133   int size() const { return attrs_->size(); }
134 
135   // Returns the attr with attr_name if found.  Otherwise, returns
136   // nullptr.
137   const AttrValue* Find(StringPiece attr_name) const;
138 
139   // Returns the attr_value for attr_name if found. Otherwise, returns a
140   // NotFound status.
141   Status Find(StringPiece attr_name, const AttrValue** attr_value) const;
142 
143   // Helper class to avoid allocations in EqualAttrs.
144   // TODO(irving): Will go away once NodeInfo is used.
145   struct Scratch {
146     string a;
147     string b;
148   };
149 
150   // Check if all attrs and attr values match.  Does not take defaults into
151   // account.
152   //
153   // TODO(irving): There is a bug in this routine inherited from its
154   // OptimizerCSE::EqualAttrs precedecessor.  The same tensor attr can be
155   // represented in more than one way as an AttrValue, since TensorProto is
156   // not 1-1.  This bug will go away once I replace everything with NodeInfo,
157   // which stores a Tensor object directly.  The Scratch object will also go
158   // away.
159   bool EqualAttrs(AttrSlice other, Scratch* scratch) const;
160 
161   // If this AttrSlice has an attached NodeDef, summarize it.  This is for
162   // error messages only: we intentionally do not provide direct access to the
163   // NodeDef, since it is not always there.
164   string SummarizeNode() const;
165 
166   // Iteration over all attrs
begin()167   AttrValueMap::const_iterator begin() const { return attrs_->begin(); }
end()168   AttrValueMap::const_iterator end() const { return attrs_->end(); }
169 
170  private:
171   const NodeDef* ndef_;
172   const AttrValueMap* attrs_;
173 };
174 
175 // Return true if the attr with the name attr_name is defined in node_def.
176 bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name);
177 
178 // Look up the attr with name attr_name and set *value to its value.  If no
179 // attr with attr_name is found in node_def, or the attr does not have
180 // a matching type, a non-ok status will be returned.
181 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
182                    string* value);  // type: "string"
183 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
184                    int64* value);  // type: "int"
185 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
186                    int32* value);  // type: "int"
187 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
188                    float* value);  // type: "float"
189 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
190                    bool* value);  // type: "bool"
191 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
192                    DataType* value);  // type: "type"
193 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
194                    TensorShapeProto* value);  // type: "shape"
195 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
196                    TensorShape* value);  // type: "shape"
197 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
198                    PartialTensorShape* value);  // type: "shape"
199 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
200                    Tensor* value);  // type: "tensor"
201 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
202                    std::vector<string>* value);  // type "list(string)"
203 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
204                    std::vector<int64>* value);  // type "list(int)"
205 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
206                    std::vector<int32>* value);  // type "list(int)"
207 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
208                    std::vector<float>* value);  // type "list(float)"
209 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
210                    std::vector<bool>* value);  // type "list(bool)"
211 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
212                    std::vector<DataType>* value);  // type "list(type)"
213 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
214                    DataTypeVector* value);  // type "list(type)"
215 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
216                    std::vector<TensorShapeProto>* value);  // type "list(shape)"
217 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
218                    std::vector<TensorShape>* value);  // type "list(shape)"
219 Status GetNodeAttr(
220     const AttrSlice& attrs, StringPiece attr_name,
221     std::vector<PartialTensorShape>* value);  // type "list(shape)"
222 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
223                    std::vector<Tensor>* value);  // type: "list(tensor)"
224 
225 // This version avoids copying the TensorProto.
226 // REQUIRES: Must not use *value beyond the lifetime of node_def.
227 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
228                    const TensorProto** value);  // type: "tensor"
229 
230 // This version avoids copying the NameAttrList.
231 // REQUIRES: Must not use *value beyond the lifetime of node_def.
232 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
233                    const NameAttrList** value);  // type: "func"
234 
235 // These versions copies the NameAttrList(s).
236 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
237                    NameAttrList* value);  // type: "func"
238 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
239                    std::vector<NameAttrList>* value);  // type: "list(func)"
240 
241 // Look up the attr with name attr_name and set *value to its value.  If no
242 // attr with attr_name is found in node_def, or the attr does not have
243 // a matching type, false is returned.
244 bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
245                        string* value);  // type: "string"
246 bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,
247                        std::vector<string>* value);  // type: "string"
248 
249 // Look up the attr with name attr_name and return a reference to its value.
250 // If no attr with attr_name is found in node_def, or the attr does not have
251 // a matching type, a reference to an empty string is returned.
252 // REQUIRES: Must not use the returned value beyond the lifetime of node_def.
253 const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name);
254 
255 // Computes the input type for a specific node input.
256 // REQUIRES: ValidateOpDef(op_def).ok()
257 Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
258                         int input_port, DataType* input_type);
259 // Computes the input types for a specific node.
260 // REQUIRES: ValidateOpDef(op_def).ok()
261 Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
262                          DataTypeVector* inputs);
263 // Computes the output type for a specific node output.
264 // REQUIRES: ValidateOpDef(op_def).ok()
265 Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
266                          int output_port, DataType* output_type);
267 // Computes the output types for a specific node.
268 // REQUIRES: ValidateOpDef(op_def).ok()
269 Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
270                           DataTypeVector* outputs);
271 // Computes the input and output types for a specific node.
272 // REQUIRES: ValidateOpDef(op_def).ok()
273 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
274                          DataTypeVector* inputs, DataTypeVector* outputs);
275 // Computes the number of outputs for a specific node.
276 // REQUIRES: ValidateOpDef(op_def).ok()
277 Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
278                          int* num_outputs);
279 
280 // Validates that the NodeDef:
281 // * Defines all expected attrs from the OpDef.
282 // * All attrs satisfies constraints from the OpDef.
283 // * Has a signature matching SignatureForNode().
284 // etc.
285 Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def);
286 
287 // Computes the mapping from input/output argument name to the
288 // corresponding input/output index range.  For example,
289 // input "foo" corresponds to input indices
290 //   [ (*inputs)["foo"].first, (*inputs)["foo"].second ).
291 // NOTE(mrry): To reduce allocations when the map is used and save
292 // space, the returned `NameRangeMap` objects borrow the input/output
293 // argument names from `op_def`. The `op_def` must outlive the
294 // returned `NameRangeMap` objects.
295 typedef gtl::FlatMap<StringPiece, std::pair<int, int>, hash<StringPiece>>
296     NameRangeMap;
297 Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
298                          NameRangeMap* inputs, NameRangeMap* outputs);
299 Status NameRangesForNode(const Node& node, const OpDef& op_def,
300                          NameRangeMap* inputs, NameRangeMap* outputs);
301 
302 // Adds default values to *node_def for unspecified attrs from op_def.
303 void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def);
304 
305 // Validates the syntax of a NodeDef provided externally.
306 //
307 // The following is an EBNF-style syntax for NodeDef objects. Note that
308 // Node objects are actually specified as tensorflow::NodeDef protocol buffers,
309 // which contain many other fields that are not (currently) validated.
310 //
311 // Node         = NodeName, Inputs
312 // Inputs       = ( DataInput * ), ( ControlInput * )
313 // DataInput    = NodeName, ( ":", [1-9], [0-9] * ) ?
314 // ControlInput = "^", NodeName
315 // NodeName     = [A-Za-z0-9.], [A-Za-z0-9_./] *
316 Status ValidateExternalNodeDefSyntax(const NodeDef& node_def);
317 
318 // Returns "status" with formatted NodeDef attached as additional text
319 // in the error message. If 'allow_multiple_formatted_node' is false and there
320 // is already a formatted NodeDef present in 'status', we simply attach the name
321 // of the NodeDef instead of the formatted string.
322 Status AttachDef(const Status& status, const NodeDef& node_def,
323                  bool allow_multiple_formatted_node = false);
324 Status AttachDef(const Status& status, const Node& node,
325                  bool allow_multiple_formatted_node = false);
326 
327 // Appends the given prefix and suffix to the original node name in order to
328 // make the name unique. If it's an "Enter" node, use the same way to reset
329 // attribute "frame_name".
330 Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
331                                 NodeDef* node_def);
332 }  // namespace tensorflow
333 
334 #endif  // TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
335