• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
18 
19 #include <string>
20 #include <unordered_set>
21 #include <vector>
22 
23 #include "tensorflow/core/framework/attr_value_util.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/tensor_shape.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/lib/core/stringpiece.h"
30 #include "tensorflow/core/lib/gtl/array_slice.h"
31 #include "tensorflow/core/lib/gtl/flatmap.h"
32 #include "tensorflow/core/lib/hash/hash.h"
33 #include "tensorflow/core/platform/hash.h"
34 #include "tensorflow/core/platform/protobuf.h"
35 #include "tensorflow/core/platform/status.h"
36 #include "tensorflow/core/platform/stringpiece.h"
37 #include "tensorflow/core/platform/types.h"
38 #include "tensorflow/core/util/padding.h"
39 
40 namespace tensorflow {
41 
// Declared here, defined further below.
class AttrSlice;

// Protos are only forward declared so that kernels do not need to depend on
// them.
class AttrValue;
class NameAttrList;
class OpDef;
class TensorProto;
class TensorShapeProto;

// Name of the attribute that encodes node colocation constraints.
//
// Nodes can be co-located on the same device. An explicit wish for
// co-location is expressed through a list(string) attribute holding the names
// of the colocation groups.
extern const char* const kColocationAttrName;

// String prefix applied to an operation name to build a colocation
// constraint.
extern const char* const kColocationGroupPrefix;
59 
60 // Produce a human-readable version of a Node or NodeDef that is more concise
61 // than a text-format proto.
62 //
63 // The parameter `max_inputs_in_summary` specifies how many inputs at most to
64 // serialize in the output (in order not to get a string which is overly large).
65 // The value `-1` specifies that all inputs will be shown.
66 std::string SummarizeNodeDef(const NodeDef& node_def,
67                              int max_inputs_in_summary = -1);
68 std::string SummarizeAttrs(const NodeDef& node_def);
69 std::string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device);
70 
71 // Produces a formatted string pattern from the node which can uniquely identify
72 // this node upstream to produce an informative error message. The pattern
73 // followed is: {{node <node_name>}}
74 std::string FormatNodeDefForError(const NodeDef& node_def);
75 std::string FormatNodeDefForError(
76     StringPiece node_name, bool has_experimental_debug_info,
77     const NodeDef_ExperimentalDebugInfo& experimental_debug_info);
78 
79 typedef protobuf::Map<string, AttrValue> AttrValueMap;
80 
81 // Adds an attr with name <name> and value <value> to *node_def.
82 // The type of the attr is based on the type of value.
83 void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def);
84 void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def);
85 void AddNodeAttr(StringPiece name, StringPiece value, NodeDef* node_def);
86 void AddNodeAttr(StringPiece name, const char* value, NodeDef* node_def);
87 void AddNodeAttr(StringPiece name, int32 value, NodeDef* node_def);
88 void AddNodeAttr(StringPiece name, int64 value, NodeDef* node_def);
89 void AddNodeAttr(StringPiece name, float value, NodeDef* node_def);
90 void AddNodeAttr(StringPiece name, double value, NodeDef* node_def);
91 void AddNodeAttr(StringPiece name, bool value, NodeDef* node_def);
92 void AddNodeAttr(StringPiece name, DataType value, NodeDef* node_def);
93 void AddNodeAttr(StringPiece name, const PartialTensorShape& value,
94                  NodeDef* node_def);
95 void AddNodeAttr(StringPiece name, const Tensor& value, NodeDef* node_def);
96 void AddNodeAttr(StringPiece name, const TensorProto& value, NodeDef* node_def);
97 void AddNodeAttr(StringPiece name, const NameAttrList& value,
98                  NodeDef* node_def);
99 void AddNodeAttr(StringPiece name, gtl::ArraySlice<StringPiece> value,
100                  NodeDef* node_def);
101 void AddNodeAttr(StringPiece name, gtl::ArraySlice<const char*> value,
102                  NodeDef* node_def);
103 void AddNodeAttr(StringPiece name, gtl::ArraySlice<string> value,
104                  NodeDef* node_def);
105 void AddNodeAttr(StringPiece name, gtl::ArraySlice<int32> value,
106                  NodeDef* node_def);
107 void AddNodeAttr(StringPiece name, gtl::ArraySlice<int64> value,
108                  NodeDef* node_def);
109 void AddNodeAttr(StringPiece name, gtl::ArraySlice<float> value,
110                  NodeDef* node_def);
111 void AddNodeAttr(StringPiece name, gtl::ArraySlice<bool> value,
112                  NodeDef* node_def);
113 void AddNodeAttr(StringPiece name, const std::vector<bool>& value,
114                  NodeDef* node_def);
115 void AddNodeAttr(StringPiece name, gtl::ArraySlice<DataType> value,
116                  NodeDef* node_def);
117 void AddNodeAttr(StringPiece name, gtl::ArraySlice<TensorShape> value,
118                  NodeDef* node_def);
119 void AddNodeAttr(StringPiece name, gtl::ArraySlice<PartialTensorShape> value,
120                  NodeDef* node_def);
121 void AddNodeAttr(StringPiece name, gtl::ArraySlice<TensorShapeProto> value,
122                  NodeDef* node_def);
123 void AddNodeAttr(StringPiece name, gtl::ArraySlice<Tensor> value,
124                  NodeDef* node_def);
125 void AddNodeAttr(StringPiece name, gtl::ArraySlice<NameAttrList> value,
126                  NodeDef* node_def);
127 
128 // Version to workaround C++'s "perfect" forwarding not being able to
129 // forward {...} initialization.
130 template <class T>
AddNodeAttr(StringPiece name,std::initializer_list<T> value,NodeDef * node_def)131 void AddNodeAttr(StringPiece name, std::initializer_list<T> value,
132                  NodeDef* node_def) {
133   AddNodeAttr(name, gtl::ArraySlice<T>(value), node_def);
134 }
135 
136 // Adds an attr to an attr value map.
137 void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map);
138 void AddAttr(StringPiece name, bool value, AttrValueMap* map);
139 
140 class AttrSlice {
141  public:
142   AttrSlice(const NodeDef& node_def);  // NOLINT(runtime/explicit)
143 
144   AttrSlice();  // Empty
145   explicit AttrSlice(const AttrValueMap* a);
146 
size()147   int size() const { return attrs_->size(); }
148 
149   // Returns the attr with attr_name if found.  Otherwise, returns
150   // nullptr.
151   const AttrValue* Find(StringPiece attr_name) const;
152   const AttrValue* FindByString(const std::string& attr_name) const;
153 
154   // Returns the attr_value for attr_name if found. Otherwise, returns a
155   // NotFound status.
156   Status Find(StringPiece attr_name, const AttrValue** attr_value) const;
157 
158   // Helper class to avoid allocations in EqualAttrs.
159   // TODO(irving): Will go away once NodeInfo is used.
160   struct Scratch {
161     std::string a;
162     std::string b;
163   };
164 
165   // Check if all attrs and attr values match.  Does not take defaults into
166   // account.
167   //
168   // TODO(irving): There is a bug in this routine inherited from its
169   // OptimizerCSE::EqualAttrs predecessor.  The same tensor attr can be
170   // represented in more than one way as an AttrValue, since TensorProto is
171   // not 1-1.  This bug will go away once I replace everything with NodeInfo,
172   // which stores a Tensor object directly.  The Scratch object will also go
173   // away.
174   bool EqualAttrs(AttrSlice other, Scratch* scratch) const;
175 
176   // If this AttrSlice has an attached NodeDef, summarize it.  This is for
177   // error messages only: we intentionally do not provide direct access to the
178   // NodeDef, since it is not always there.
179   std::string SummarizeNode() const;
180 
181   // Iteration over all attrs
begin()182   AttrValueMap::const_iterator begin() const { return attrs_->begin(); }
end()183   AttrValueMap::const_iterator end() const { return attrs_->end(); }
184 
185   std::string DebugString() const;
186 
187  private:
188   const NodeDef* ndef_;
189   const AttrValueMap* attrs_;
190 };
191 
192 // Return true if the attr with the name attr_name is defined in node_def.
193 bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name);
194 
195 // Look up the attr with name attr_name and set *value to its value.  If no
196 // attr with attr_name is found in node_def, or the attr does not have
197 // a matching type, a non-ok status will be returned.
198 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
199                    std::string* value);  // type: "string"
200 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
201                    tstring* value);  // type: "tstring"
202 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
203                    int64* value);  // type: "int"
204 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
205                    int32* value);  // type: "int"
206 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
207                    float* value);  // type: "float"
208 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
209                    bool* value);  // type: "bool"
210 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
211                    DataType* value);  // type: "type"
212 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
213                    TensorShapeProto* value);  // type: "shape"
214 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
215                    TensorShape* value);  // type: "shape"
216 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
217                    PartialTensorShape* value);  // type: "shape"
218 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
219                    Tensor* value);  // type: "tensor"
220 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
221                    std::vector<string>* value);  // type "list(string)"
222 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
223                    std::vector<tstring>* value);  // type "list(tstring)"
224 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
225                    std::vector<int64>* value);  // type "list(int)"
226 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
227                    std::vector<int32>* value);  // type "list(int)"
228 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
229                    std::vector<float>* value);  // type "list(float)"
230 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
231                    std::vector<bool>* value);  // type "list(bool)"
232 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
233                    std::vector<DataType>* value);  // type "list(type)"
234 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
235                    DataTypeVector* value);  // type "list(type)"
236 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
237                    std::vector<TensorShapeProto>* value);  // type "list(shape)"
238 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
239                    std::vector<TensorShape>* value);  // type "list(shape)"
240 Status GetNodeAttr(
241     const AttrSlice& attrs, StringPiece attr_name,
242     std::vector<PartialTensorShape>* value);  // type "list(shape)"
243 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
244                    std::vector<Tensor>* value);  // type: "list(tensor)"
245 
246 // This version avoids copying the TensorProto.
247 // REQUIRES: Must not use *value beyond the lifetime of node_def.
248 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
249                    const TensorProto** value);  // type: "tensor"
250 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
251                     const TensorProto** value);  // type: "tensor"
252 
253 // This version avoids copying the NameAttrList.
254 // REQUIRES: Must not use *value beyond the lifetime of node_def.
255 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
256                    const NameAttrList** value);  // type: "func"
257 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
258                     const NameAttrList** value);  // type: "func"
259 
260 // These versions copies the NameAttrList(s).
261 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
262                    NameAttrList* value);  // type: "func"
263 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
264                    std::vector<NameAttrList>* value);  // type: "list(func)"
265 
266 // Look up the attr with name attr_name and set *value to its value.  If no
267 // attr with attr_name is found in node_def, or the attr does not have
268 // a matching type, false is returned.
269 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
270                     std::string* value);  // type: "string"
271 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
272                     int64* value);  // type: "int"
273 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
274                     std::vector<int64>* value);  // type: "int"
275 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
276                     int32* value);  // type: "int"
277 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
278                     float* value);  // type: "float"
279 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
280                     bool* value);  // type: "bool"
281 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
282                     DataType* value);  // type: "type"
283 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
284                     TensorShape* value);  // type: "shape"
285 
286 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
287                     std::vector<string>* value);  // type: "list(string)"
288 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
289                     std::vector<tstring>* value);  // type: "list(tstring)"
290 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
291                     std::vector<int32>* value);  // type: "list(int)"
292 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
293                     std::vector<float>* value);  // type: "list(float)"
294 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
295                     std::vector<bool>* value);  // type: "list(bool)"
296 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
297                     std::vector<DataType>* value);  // type: "list(type)"
298 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
299                     std::vector<TensorShape> value);  // type: "shape"
300 
301 // Overloads of TryGetNodeAttr() that avoid copying the non-POD attribute
302 // values.
303 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
304                     std::vector<const string*>* value);  // type: "list(string)"
305 bool TryGetNodeAttr(
306     const AttrSlice& attrs, StringPiece attr_name,
307     std::vector<const TensorShapeProto*>* value);  // type: "list(shape)"
308 
309 // Look up the attr with name attr_name and return a reference to its value.
310 // If no attr with attr_name is found in node_def, or the attr does not have
311 // a matching type, a reference to an empty string is returned.
312 // REQUIRES: Must not use the returned value beyond the lifetime of node_def.
313 const std::string& GetNodeAttrString(const AttrSlice& attrs,
314                                      StringPiece attr_name);
315 
316 // Specialization to parse an attribute directly into a Padding enum.
317 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
318                    Padding* value);
319 
320 // Computes the input type for a specific node input.
321 // REQUIRES: ValidateOpDef(op_def).ok()
322 Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
323                         int input_port, DataType* input_type);
324 // Computes the input types for a specific node.
325 // REQUIRES: ValidateOpDef(op_def).ok()
326 Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
327                          DataTypeVector* inputs);
328 // Computes the output type for a specific node output.
329 // REQUIRES: ValidateOpDef(op_def).ok()
330 Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
331                          int output_port, DataType* output_type);
332 // Computes the output types for a specific node.
333 // REQUIRES: ValidateOpDef(op_def).ok()
334 Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
335                           DataTypeVector* outputs);
336 Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def,
337                           DataTypeVector* outputs);
338 
339 // Computes the input and output types for a specific node.
340 // REQUIRES: ValidateOpDef(op_def).ok()
341 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
342                          DataTypeVector* inputs, DataTypeVector* outputs);
343 // Computes the number of outputs for a specific node.
344 // REQUIRES: ValidateOpDef(op_def).ok()
345 Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
346                          int* num_outputs);
347 
348 // Validates that the NodeDef:
349 // * Defines all expected attrs from the OpDef.
350 // * All attrs satisfies constraints from the OpDef.
351 // * Has a signature matching SignatureForNode().
352 // etc.
353 Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def);
354 
355 // Computes the mapping from input/output argument name to the
356 // corresponding input/output index range.  For example,
357 // input "foo" corresponds to input indices
358 //   [ (*inputs)["foo"].first, (*inputs)["foo"].second ).
359 // NOTE(mrry): To reduce allocations when the map is used and save
360 // space, the returned `NameRangeMap` objects borrow the input/output
361 // argument names from `op_def`. The `op_def` must outlive the
362 // returned `NameRangeMap` objects.
363 typedef gtl::FlatMap<StringPiece, std::pair<int, int>, hash<StringPiece>>
364     NameRangeMap;
365 Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def,
366                          NameRangeMap* inputs, NameRangeMap* outputs);
367 // Adds default values to *node_def for unspecified attrs from op_def.
368 void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def);
369 
370 // Validates the syntax of a NodeDef provided externally.
371 //
372 // The following is an EBNF-style syntax for NodeDef objects. Note that
373 // Node objects are actually specified as tensorflow::NodeDef protocol buffers,
374 // which contain many other fields that are not (currently) validated.
375 //
376 // Node         = NodeName, Inputs
377 // Inputs       = ( DataInput * ), ( ControlInput * )
378 // DataInput    = NodeName, ( ":", [1-9], [0-9] * ) ?
379 // ControlInput = "^", NodeName
380 // NodeName     = [A-Za-z0-9.], [A-Za-z0-9_./] *
381 Status ValidateExternalNodeDefSyntax(const NodeDef& node_def);
382 
383 // Returns "status" with formatted NodeDef attached as additional text
384 // in the error message. If 'allow_multiple_formatted_node' is false and there
385 // is already a formatted NodeDef present in 'status', we simply attach the name
386 // of the NodeDef instead of the formatted string.
387 Status AttachDef(const Status& status, const NodeDef& node_def,
388                  bool allow_multiple_formatted_node = false);
389 // Appends the given prefix and suffix to the original node name in order to
390 // make the name unique. If it's an "Enter" node and uniquify_frame_name is
391 // true, use the same way to reset attribute "frame_name".
392 Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
393                                 NodeDef* node_def,
394                                 bool uniquify_frame_name = true);
395 
396 // Appends the given prefix to the colocation group name if the name exists
397 // in `to_match`.
398 Status MaybeAddPrefixToColocationConstraints(
399     const std::unordered_set<string>& match, StringPiece prefix,
400     NodeDef* node_def);
401 
402 }  // namespace tensorflow
403 
404 #endif  // TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
405