• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
18 
19 #include <string>
20 #include <vector>
21 
22 #include "tensorflow/core/framework/attr_value_util.h"
23 #include "tensorflow/core/framework/node_def.pb.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/tensor_shape.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/framework/types.pb.h"
28 #include "tensorflow/core/lib/core/stringpiece.h"
29 #include "tensorflow/core/lib/gtl/array_slice.h"
30 #include "tensorflow/core/lib/gtl/flatmap.h"
31 #include "tensorflow/core/lib/hash/hash.h"
32 #include "tensorflow/core/platform/hash.h"
33 #include "tensorflow/core/platform/protobuf.h"
34 #include "tensorflow/core/platform/status.h"
35 #include "tensorflow/core/platform/stringpiece.h"
36 #include "tensorflow/core/platform/types.h"
37 
38 namespace tensorflow {
39 
40 class AttrSlice;
41 // We forward declare protos so that kernels don't need to depend on them
42 class OpDef;
43 class AttrValue;
44 class NameAttrList;
45 class TensorProto;
46 class TensorShapeProto;
47 
48 // Name of the attribute used to encode node colocation constraints.
49 //
50 // Nodes can be co-located on the same device. Desire for explicit co-location
51 // is described by list(string) attribute containing the name of colocation
52 // groups.
53 extern const char* const kColocationAttrName;
54 
55 // String prefix applied to the operation name for colocation constraints.
56 extern const char* const kColocationGroupPrefix;
57 
58 // Produce a human-readable version of a Node or NodeDef that is more concise
59 // than a text-format proto.
60 string SummarizeNodeDef(const NodeDef& node_def);
61 string SummarizeAttrs(const NodeDef& node_def);
62 string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device);
63 
64 // Produces a formatted string pattern from the node which can uniquely identify
65 // this node upstream to produce an informative error message. The pattern
66 // followed is: {{node <node_name>}}
67 string FormatNodeDefForError(const NodeDef& node_def);
68 string FormatNodeDefForError(
69     StringPiece node_name, bool has_experimental_debug_info,
70     const NodeDef_ExperimentalDebugInfo& experimental_debug_info);
71 
72 typedef protobuf::Map<string, AttrValue> AttrValueMap;
73 
74 // Adds an attr with name <name> and value <value> to *node_def.
75 // The type of the attr is based on the type of value.
76 void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def);
77 void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def);
78 void AddNodeAttr(StringPiece name, StringPiece value, NodeDef* node_def);
79 void AddNodeAttr(StringPiece name, const char* value, NodeDef* node_def);
80 void AddNodeAttr(StringPiece name, int32 value, NodeDef* node_def);
81 void AddNodeAttr(StringPiece name, int64 value, NodeDef* node_def);
82 void AddNodeAttr(StringPiece name, float value, NodeDef* node_def);
83 void AddNodeAttr(StringPiece name, double value, NodeDef* node_def);
84 void AddNodeAttr(StringPiece name, bool value, NodeDef* node_def);
85 void AddNodeAttr(StringPiece name, DataType value, NodeDef* node_def);
86 void AddNodeAttr(StringPiece name, const PartialTensorShape& value,
87                  NodeDef* node_def);
88 void AddNodeAttr(StringPiece name, const Tensor& value, NodeDef* node_def);
89 void AddNodeAttr(StringPiece name, const TensorProto& value, NodeDef* node_def);
90 void AddNodeAttr(StringPiece name, const NameAttrList& value,
91                  NodeDef* node_def);
92 void AddNodeAttr(StringPiece name, gtl::ArraySlice<StringPiece> value,
93                  NodeDef* node_def);
94 void AddNodeAttr(StringPiece name, gtl::ArraySlice<const char*> value,
95                  NodeDef* node_def);
96 void AddNodeAttr(StringPiece name, gtl::ArraySlice<string> value,
97                  NodeDef* node_def);
98 void AddNodeAttr(StringPiece name, gtl::ArraySlice<int32> value,
99                  NodeDef* node_def);
100 void AddNodeAttr(StringPiece name, gtl::ArraySlice<int64> value,
101                  NodeDef* node_def);
102 void AddNodeAttr(StringPiece name, gtl::ArraySlice<float> value,
103                  NodeDef* node_def);
104 void AddNodeAttr(StringPiece name, gtl::ArraySlice<bool> value,
105                  NodeDef* node_def);
106 void AddNodeAttr(StringPiece name, const std::vector<bool>& value,
107                  NodeDef* node_def);
108 void AddNodeAttr(StringPiece name, gtl::ArraySlice<DataType> value,
109                  NodeDef* node_def);
110 void AddNodeAttr(StringPiece name, gtl::ArraySlice<TensorShape> value,
111                  NodeDef* node_def);
112 void AddNodeAttr(StringPiece name, gtl::ArraySlice<PartialTensorShape> value,
113                  NodeDef* node_def);
114 void AddNodeAttr(StringPiece name, gtl::ArraySlice<TensorShapeProto> value,
115                  NodeDef* node_def);
116 void AddNodeAttr(StringPiece name, gtl::ArraySlice<Tensor> value,
117                  NodeDef* node_def);
118 void AddNodeAttr(StringPiece name, gtl::ArraySlice<NameAttrList> value,
119                  NodeDef* node_def);
120 
121 // Version to workaround C++'s "perfect" forwarding not being able to
122 // forward {...} initialization.
123 template <class T>
AddNodeAttr(StringPiece name,std::initializer_list<T> value,NodeDef * node_def)124 void AddNodeAttr(StringPiece name, std::initializer_list<T> value,
125                  NodeDef* node_def) {
126   AddNodeAttr(name, gtl::ArraySlice<T>(value), node_def);
127 }
128 
129 // Adds an attr to an attr value map.
130 void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map);
131 void AddAttr(StringPiece name, bool value, AttrValueMap* map);
132 
133 class AttrSlice {
134  public:
135   AttrSlice(const NodeDef& node_def);  // NOLINT(runtime/explicit)
136 
137   AttrSlice();  // Empty
138   explicit AttrSlice(const AttrValueMap* a);
139 
size()140   int size() const { return attrs_->size(); }
141 
142   // Returns the attr with attr_name if found.  Otherwise, returns
143   // nullptr.
144   const AttrValue* Find(StringPiece attr_name) const;
145 
146   // Returns the attr_value for attr_name if found. Otherwise, returns a
147   // NotFound status.
148   Status Find(StringPiece attr_name, const AttrValue** attr_value) const;
149 
150   // Helper class to avoid allocations in EqualAttrs.
151   // TODO(irving): Will go away once NodeInfo is used.
152   struct Scratch {
153     string a;
154     string b;
155   };
156 
157   // Check if all attrs and attr values match.  Does not take defaults into
158   // account.
159   //
160   // TODO(irving): There is a bug in this routine inherited from its
161   // OptimizerCSE::EqualAttrs precedecessor.  The same tensor attr can be
162   // represented in more than one way as an AttrValue, since TensorProto is
163   // not 1-1.  This bug will go away once I replace everything with NodeInfo,
164   // which stores a Tensor object directly.  The Scratch object will also go
165   // away.
166   bool EqualAttrs(AttrSlice other, Scratch* scratch) const;
167 
168   // If this AttrSlice has an attached NodeDef, summarize it.  This is for
169   // error messages only: we intentionally do not provide direct access to the
170   // NodeDef, since it is not always there.
171   string SummarizeNode() const;
172 
173   // Iteration over all attrs
begin()174   AttrValueMap::const_iterator begin() const { return attrs_->begin(); }
end()175   AttrValueMap::const_iterator end() const { return attrs_->end(); }
176 
177   string DebugString() const;
178 
179  private:
180   const NodeDef* ndef_;
181   const AttrValueMap* attrs_;
182 };
183 
184 // Return true if the attr with the name attr_name is defined in node_def.
185 bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name);
186 
187 // Look up the attr with name attr_name and set *value to its value.  If no
188 // attr with attr_name is found in node_def, or the attr does not have
189 // a matching type, a non-ok status will be returned.
190 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
191                    string* value);  // type: "string"
192 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
193                    tstring* value);  // type: "tstring"
194 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
195                    int64* value);  // type: "int"
196 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
197                    int32* value);  // type: "int"
198 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
199                    float* value);  // type: "float"
200 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
201                    bool* value);  // type: "bool"
202 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
203                    DataType* value);  // type: "type"
204 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
205                    TensorShapeProto* value);  // type: "shape"
206 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
207                    TensorShape* value);  // type: "shape"
208 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
209                    PartialTensorShape* value);  // type: "shape"
210 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
211                    Tensor* value);  // type: "tensor"
212 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
213                    std::vector<string>* value);  // type "list(string)"
214 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
215                    std::vector<tstring>* value);  // type "list(tstring)"
216 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
217                    std::vector<int64>* value);  // type "list(int)"
218 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
219                    std::vector<int32>* value);  // type "list(int)"
220 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
221                    std::vector<float>* value);  // type "list(float)"
222 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
223                    std::vector<bool>* value);  // type "list(bool)"
224 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
225                    std::vector<DataType>* value);  // type "list(type)"
226 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
227                    DataTypeVector* value);  // type "list(type)"
228 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
229                    std::vector<TensorShapeProto>* value);  // type "list(shape)"
230 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
231                    std::vector<TensorShape>* value);  // type "list(shape)"
232 Status GetNodeAttr(
233     const AttrSlice& attrs, StringPiece attr_name,
234     std::vector<PartialTensorShape>* value);  // type "list(shape)"
235 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
236                    std::vector<Tensor>* value);  // type: "list(tensor)"
237 
238 // This version avoids copying the TensorProto.
239 // REQUIRES: Must not use *value beyond the lifetime of node_def.
240 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
241                    const TensorProto** value);  // type: "tensor"
242 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
243                     const TensorProto** value);  // type: "tensor"
244 
245 // This version avoids copying the NameAttrList.
246 // REQUIRES: Must not use *value beyond the lifetime of node_def.
247 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
248                    const NameAttrList** value);  // type: "func"
249 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
250                     const NameAttrList** value);  // type: "func"
251 
252 // These versions copies the NameAttrList(s).
253 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
254                    NameAttrList* value);  // type: "func"
255 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
256                    std::vector<NameAttrList>* value);  // type: "list(func)"
257 
258 // Look up the attr with name attr_name and set *value to its value.  If no
259 // attr with attr_name is found in node_def, or the attr does not have
260 // a matching type, false is returned.
261 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
262                     string* value);  // type: "string"
263 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
264                     int64* value);  // type: "int"
265 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
266                     std::vector<int64>* value);  // type: "int"
267 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
268                     int32* value);  // type: "int"
269 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
270                     float* value);  // type: "float"
271 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
272                     bool* value);  // type: "bool"
273 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
274                     DataType* value);  // type: "type"
275 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
276                     TensorShape* value);  // type: "shape"
277 
278 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
279                     std::vector<string>* value);  // type: "list(string)"
280 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
281                     std::vector<tstring>* value);  // type: "list(tstring)"
282 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
283                     std::vector<int32>* value);  // type: "list(int)"
284 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
285                     std::vector<float>* value);  // type: "list(float)"
286 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
287                     std::vector<bool>* value);  // type: "list(bool)"
288 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
289                     std::vector<DataType>* value);  // type: "list(type)"
290 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
291                     std::vector<TensorShape> value);  // type: "shape"
292 
293 // Overloads of TryGetNodeAttr() that avoid copying the non-POD attribute
294 // values.
295 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
296                     std::vector<const string*>* value);  // type: "list(string)"
297 bool TryGetNodeAttr(
298     const AttrSlice& attrs, StringPiece attr_name,
299     std::vector<const TensorShapeProto*>* value);  // type: "list(shape)"
300 
301 // Look up the attr with name attr_name and return a reference to its value.
302 // If no attr with attr_name is found in node_def, or the attr does not have
303 // a matching type, a reference to an empty string is returned.
304 // REQUIRES: Must not use the returned value beyond the lifetime of node_def.
305 const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name);
306 
307 // Computes the input type for a specific node input.
308 // REQUIRES: ValidateOpDef(op_def).ok()
309 Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
310                         int input_port, DataType* input_type);
311 // Computes the input types for a specific node.
312 // REQUIRES: ValidateOpDef(op_def).ok()
313 Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
314                          DataTypeVector* inputs);
315 // Computes the output type for a specific node output.
316 // REQUIRES: ValidateOpDef(op_def).ok()
317 Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
318                          int output_port, DataType* output_type);
319 // Computes the output types for a specific node.
320 // REQUIRES: ValidateOpDef(op_def).ok()
321 Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
322                           DataTypeVector* outputs);
323 Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def,
324                           DataTypeVector* outputs);
325 
326 // Computes the input and output types for a specific node.
327 // REQUIRES: ValidateOpDef(op_def).ok()
328 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
329                          DataTypeVector* inputs, DataTypeVector* outputs);
330 // Computes the number of outputs for a specific node.
331 // REQUIRES: ValidateOpDef(op_def).ok()
332 Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
333                          int* num_outputs);
334 
335 // Validates that the NodeDef:
336 // * Defines all expected attrs from the OpDef.
337 // * All attrs satisfies constraints from the OpDef.
338 // * Has a signature matching SignatureForNode().
339 // etc.
340 Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def);
341 
342 // Computes the mapping from input/output argument name to the
343 // corresponding input/output index range.  For example,
344 // input "foo" corresponds to input indices
345 //   [ (*inputs)["foo"].first, (*inputs)["foo"].second ).
346 // NOTE(mrry): To reduce allocations when the map is used and save
347 // space, the returned `NameRangeMap` objects borrow the input/output
348 // argument names from `op_def`. The `op_def` must outlive the
349 // returned `NameRangeMap` objects.
350 typedef gtl::FlatMap<StringPiece, std::pair<int, int>, hash<StringPiece>>
351     NameRangeMap;
352 Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def,
353                          NameRangeMap* inputs, NameRangeMap* outputs);
354 // Adds default values to *node_def for unspecified attrs from op_def.
355 void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def);
356 
357 // Validates the syntax of a NodeDef provided externally.
358 //
359 // The following is an EBNF-style syntax for NodeDef objects. Note that
360 // Node objects are actually specified as tensorflow::NodeDef protocol buffers,
361 // which contain many other fields that are not (currently) validated.
362 //
363 // Node         = NodeName, Inputs
364 // Inputs       = ( DataInput * ), ( ControlInput * )
365 // DataInput    = NodeName, ( ":", [1-9], [0-9] * ) ?
366 // ControlInput = "^", NodeName
367 // NodeName     = [A-Za-z0-9.], [A-Za-z0-9_./] *
368 Status ValidateExternalNodeDefSyntax(const NodeDef& node_def);
369 
370 // Returns "status" with formatted NodeDef attached as additional text
371 // in the error message. If 'allow_multiple_formatted_node' is false and there
372 // is already a formatted NodeDef present in 'status', we simply attach the name
373 // of the NodeDef instead of the formatted string.
374 Status AttachDef(const Status& status, const NodeDef& node_def,
375                  bool allow_multiple_formatted_node = false);
376 // Appends the given prefix and suffix to the original node name in order to
377 // make the name unique. If it's an "Enter" node and uniquify_frame_name is
378 // true, use the same way to reset attribute "frame_name".
379 Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
380                                 NodeDef* node_def,
381                                 bool uniquify_frame_name = true);
382 }  // namespace tensorflow
383 
384 #endif  // TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
385