• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
18 
#include <initializer_list>
#include <map>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/hash.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/padding.h"
39 
40 namespace tensorflow {
41 
// AttrSlice is defined further down in this header but is already needed by
// the SummarizeAttrsHelper() declaration, so it is forward declared here.
class AttrSlice;

// We forward declare protos so that kernels don't need to depend on them.
class OpDef;
class AttrValue;
class NameAttrList;
class TensorProto;
class TensorShapeProto;

// Name of the attribute used to encode node colocation constraints.
//
// Nodes can be co-located on the same device. Desire for explicit co-location
// is described by a list(string) attribute containing the names of colocation
// groups.
extern const char* const kColocationAttrName;

// String prefix applied to the operation name for colocation constraints.
extern const char* const kColocationGroupPrefix;

// Constants for host CPU staging op for TPUExecute.
extern const char* const kTpuExecuteStagingOp;
extern const char* const kTpuExecuteStagingNodeName;
63 
64 // Produce a human-readable version of a Node or NodeDef that is more concise
65 // than a text-format proto.
66 //
67 // The parameter `max_inputs_in_summary` specifies how many inputs at most to
68 // serialize in the output (in order not to get a string which is overly large).
69 // The value `-1` specifies that all inputs will be shown.
70 std::string SummarizeNodeDef(const NodeDef& node_def,
71                              int max_inputs_in_summary = -1);
72 std::string SummarizeAttrs(const NodeDef& node_def);
73 std::string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device);
74 
75 // Produces a formatted string pattern from the node which can uniquely identify
76 // this node upstream to produce an informative error message. The pattern
77 // followed is: {{node <node_name>}}
78 std::string FormatNodeDefForError(const NodeDef& node_def);
79 std::string FormatNodeDefForError(
80     StringPiece node_name, bool has_experimental_debug_info,
81     const NodeDef_ExperimentalDebugInfo& experimental_debug_info);
82 
83 typedef protobuf::Map<string, AttrValue> AttrValueMap;
84 
85 // Adds an attr with name <name> and value <value> to *node_def.
86 // The type of the attr is based on the type of value.
87 void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def);
88 void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def);
89 void AddNodeAttr(StringPiece name, StringPiece value, NodeDef* node_def);
90 void AddNodeAttr(StringPiece name, const char* value, NodeDef* node_def);
91 void AddNodeAttr(StringPiece name, int32_t value, NodeDef* node_def);
92 void AddNodeAttr(StringPiece name, int64_t value, NodeDef* node_def);
93 void AddNodeAttr(StringPiece name, float value, NodeDef* node_def);
94 void AddNodeAttr(StringPiece name, double value, NodeDef* node_def);
95 void AddNodeAttr(StringPiece name, bool value, NodeDef* node_def);
96 void AddNodeAttr(StringPiece name, DataType value, NodeDef* node_def);
97 void AddNodeAttr(StringPiece name, const PartialTensorShape& value,
98                  NodeDef* node_def);
99 void AddNodeAttr(StringPiece name, const Tensor& value, NodeDef* node_def);
100 void AddNodeAttr(StringPiece name, const TensorProto& value, NodeDef* node_def);
101 void AddNodeAttr(StringPiece name, const NameAttrList& value,
102                  NodeDef* node_def);
103 void AddNodeAttr(StringPiece name, gtl::ArraySlice<StringPiece> value,
104                  NodeDef* node_def);
105 void AddNodeAttr(StringPiece name, gtl::ArraySlice<const char*> value,
106                  NodeDef* node_def);
107 void AddNodeAttr(StringPiece name, gtl::ArraySlice<string> value,
108                  NodeDef* node_def);
109 void AddNodeAttr(StringPiece name, gtl::ArraySlice<int32> value,
110                  NodeDef* node_def);
111 void AddNodeAttr(StringPiece name, gtl::ArraySlice<int64> value,
112                  NodeDef* node_def);
113 void AddNodeAttr(StringPiece name, gtl::ArraySlice<float> value,
114                  NodeDef* node_def);
115 void AddNodeAttr(StringPiece name, gtl::ArraySlice<bool> value,
116                  NodeDef* node_def);
117 void AddNodeAttr(StringPiece name, const std::vector<bool>& value,
118                  NodeDef* node_def);
119 void AddNodeAttr(StringPiece name, gtl::ArraySlice<DataType> value,
120                  NodeDef* node_def);
121 void AddNodeAttr(StringPiece name, gtl::ArraySlice<TensorShape> value,
122                  NodeDef* node_def);
123 void AddNodeAttr(StringPiece name, gtl::ArraySlice<PartialTensorShape> value,
124                  NodeDef* node_def);
125 void AddNodeAttr(StringPiece name, gtl::ArraySlice<TensorShapeProto> value,
126                  NodeDef* node_def);
127 void AddNodeAttr(StringPiece name, gtl::ArraySlice<Tensor> value,
128                  NodeDef* node_def);
129 void AddNodeAttr(StringPiece name, gtl::ArraySlice<NameAttrList> value,
130                  NodeDef* node_def);
131 
132 // Version to workaround C++'s "perfect" forwarding not being able to
133 // forward {...} initialization.
134 template <class T>
AddNodeAttr(StringPiece name,std::initializer_list<T> value,NodeDef * node_def)135 void AddNodeAttr(StringPiece name, std::initializer_list<T> value,
136                  NodeDef* node_def) {
137   AddNodeAttr(name, gtl::ArraySlice<T>(value), node_def);
138 }
139 
140 // Adds an attr to an attr value map.
141 void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map);
142 void AddAttr(StringPiece name, bool value, AttrValueMap* map);
143 
144 class AttrSlice {
145  public:
146   AttrSlice(const NodeDef& node_def);  // NOLINT(runtime/explicit)
147 
148   AttrSlice();  // Empty
149   explicit AttrSlice(const AttrValueMap* a);
150 
size()151   int size() const { return attrs_->size(); }
152 
153   // Returns the attr with attr_name if found.  Otherwise, returns
154   // nullptr.
155   const AttrValue* Find(StringPiece attr_name) const;
156   const AttrValue* FindByString(const std::string& attr_name) const;
157 
158   // Returns the attr_value for attr_name if found. Otherwise, returns a
159   // NotFound status.
160   Status Find(StringPiece attr_name, const AttrValue** attr_value) const;
161 
162   // Helper class to avoid allocations in EqualAttrs.
163   // TODO(irving): Will go away once NodeInfo is used.
164   struct Scratch {
165     std::string a;
166     std::string b;
167   };
168 
169   // Check if all attrs and attr values match.  Does not take defaults into
170   // account.
171   //
172   // TODO(irving): There is a bug in this routine inherited from its
173   // OptimizerCSE::EqualAttrs predecessor.  The same tensor attr can be
174   // represented in more than one way as an AttrValue, since TensorProto is
175   // not 1-1.  This bug will go away once I replace everything with NodeInfo,
176   // which stores a Tensor object directly.  The Scratch object will also go
177   // away.
178   bool EqualAttrs(AttrSlice other, Scratch* scratch) const;
179 
180   // If this AttrSlice has an attached NodeDef, summarize it.  This is for
181   // error messages only: we intentionally do not provide direct access to the
182   // NodeDef, since it is not always there.
183   std::string SummarizeNode() const;
184 
185   // Iteration over all attrs
begin()186   AttrValueMap::const_iterator begin() const { return attrs_->begin(); }
end()187   AttrValueMap::const_iterator end() const { return attrs_->end(); }
188 
189   std::string DebugString() const;
190 
191  private:
192   const NodeDef* ndef_;
193   const AttrValueMap* attrs_;
194 };
195 
196 // Return true if the attr with the name attr_name is defined in node_def.
197 bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name);
198 
199 // Look up the attr with name attr_name and set *value to its value.  If no
200 // attr with attr_name is found in node_def, or the attr does not have
201 // a matching type, a non-ok status will be returned.
202 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
203                    std::string* value);  // type: "string"
204 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
205                    tstring* value);  // type: "tstring"
206 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
207                    int64* value);  // type: "int"
208 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
209                    int32* value);  // type: "int"
210 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
211                    float* value);  // type: "float"
212 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
213                    bool* value);  // type: "bool"
214 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
215                    DataType* value);  // type: "type"
216 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
217                    TensorShapeProto* value);  // type: "shape"
218 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
219                    TensorShape* value);  // type: "shape"
220 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
221                    PartialTensorShape* value);  // type: "shape"
222 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
223                    Tensor* value);  // type: "tensor"
224 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
225                    std::vector<string>* value);  // type "list(string)"
226 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
227                    std::vector<tstring>* value);  // type "list(tstring)"
228 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
229                    std::vector<int64>* value);  // type "list(int)"
230 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
231                    std::vector<int32>* value);  // type "list(int)"
232 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
233                    std::vector<float>* value);  // type "list(float)"
234 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
235                    std::vector<bool>* value);  // type "list(bool)"
236 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
237                    std::vector<DataType>* value);  // type "list(type)"
238 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
239                    DataTypeVector* value);  // type "list(type)"
240 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
241                    std::vector<TensorShapeProto>* value);  // type "list(shape)"
242 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
243                    std::vector<TensorShape>* value);  // type "list(shape)"
244 Status GetNodeAttr(
245     const AttrSlice& attrs, StringPiece attr_name,
246     std::vector<PartialTensorShape>* value);  // type "list(shape)"
247 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
248                    std::vector<Tensor>* value);  // type: "list(tensor)"
249 
250 // This version avoids copying the TensorProto.
251 // REQUIRES: Must not use *value beyond the lifetime of node_def.
252 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
253                    const TensorProto** value);  // type: "tensor"
254 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
255                     const TensorProto** value);  // type: "tensor"
256 
257 // This version avoids copying the NameAttrList.
258 // REQUIRES: Must not use *value beyond the lifetime of node_def.
259 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
260                    const NameAttrList** value);  // type: "func"
261 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
262                     const NameAttrList** value);  // type: "func"
263 
264 // These versions copies the NameAttrList(s).
265 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
266                    NameAttrList* value);  // type: "func"
267 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
268                    std::vector<NameAttrList>* value);  // type: "list(func)"
269 
270 // Look up the attr with name attr_name and set *value to its value.  If no
271 // attr with attr_name is found in node_def, or the attr does not have
272 // a matching type, false is returned.
273 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
274                     std::string* value);  // type: "string"
275 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
276                     int64* value);  // type: "int"
277 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
278                     std::vector<int64>* value);  // type: "int"
279 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
280                     int32* value);  // type: "int"
281 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
282                     float* value);  // type: "float"
283 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
284                     bool* value);  // type: "bool"
285 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
286                     DataType* value);  // type: "type"
287 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
288                     TensorShape* value);  // type: "shape"
289 
290 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
291                     std::vector<string>* value);  // type: "list(string)"
292 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
293                     std::vector<tstring>* value);  // type: "list(tstring)"
294 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
295                     std::vector<int32>* value);  // type: "list(int)"
296 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
297                     std::vector<float>* value);  // type: "list(float)"
298 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
299                     std::vector<bool>* value);  // type: "list(bool)"
300 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
301                     std::vector<DataType>* value);  // type: "list(type)"
302 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
303                     std::vector<TensorShape> value);  // type: "shape"
304 
305 // Overloads of TryGetNodeAttr() that avoid copying the non-POD attribute
306 // values.
307 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
308                     std::vector<const string*>* value);  // type: "list(string)"
309 bool TryGetNodeAttr(
310     const AttrSlice& attrs, StringPiece attr_name,
311     std::vector<const TensorShapeProto*>* value);  // type: "list(shape)"
312 
313 // Look up the attr with name attr_name and return a reference to its value.
314 // If no attr with attr_name is found in node_def, or the attr does not have
315 // a matching type, a reference to an empty string is returned.
316 // REQUIRES: Must not use the returned value beyond the lifetime of node_def.
317 const std::string& GetNodeAttrString(const AttrSlice& attrs,
318                                      StringPiece attr_name);
319 
320 // Specialization to parse an attribute directly into a Padding enum.
321 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
322                    Padding* value);
323 
324 // Computes the input type for a specific node input.
325 // REQUIRES: ValidateOpDef(op_def).ok()
326 Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
327                         int input_port, DataType* input_type);
328 // Computes the input types for a specific node.
329 // REQUIRES: ValidateOpDef(op_def).ok()
330 Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
331                          DataTypeVector* inputs);
332 // Computes the output type for a specific node output.
333 // REQUIRES: ValidateOpDef(op_def).ok()
334 Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
335                          int output_port, DataType* output_type);
336 // Computes the output types for a specific node.
337 // REQUIRES: ValidateOpDef(op_def).ok()
338 Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
339                           DataTypeVector* outputs);
340 Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def,
341                           DataTypeVector* outputs);
342 
343 // Computes the input and output types for a specific node.
344 // REQUIRES: ValidateOpDef(op_def).ok()
345 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
346                          DataTypeVector* inputs, DataTypeVector* outputs);
347 // Computes the number of outputs for a specific node.
348 // REQUIRES: ValidateOpDef(op_def).ok()
349 Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
350                          int* num_outputs);
351 
352 // Validates that the NodeDef:
353 // * Defines all expected attrs from the OpDef.
354 // * All attrs satisfies constraints from the OpDef.
355 // * Has a signature matching SignatureForNode().
356 // etc.
357 Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def);
358 
359 // Computes the mapping from input/output argument name to the
360 // corresponding input/output index range.  For example,
361 // input "foo" corresponds to input indices
362 //   [ (*inputs)["foo"].first, (*inputs)["foo"].second ).
363 // NOTE(mrry): To reduce allocations when the map is used and save
364 // space, the returned `NameRangeMap` objects borrow the input/output
365 // argument names from `op_def`. The `op_def` must outlive the
366 // returned `NameRangeMap` objects.
367 typedef gtl::FlatMap<StringPiece, std::pair<int, int>, hash<StringPiece>>
368     NameRangeMap;
369 Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def,
370                          NameRangeMap* inputs, NameRangeMap* outputs);
371 // Adds default values to *node_def for unspecified attrs from op_def.
372 void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def);
373 
374 // Validates the syntax of a NodeDef provided externally.
375 //
376 // The following is an EBNF-style syntax for NodeDef objects. Note that
377 // Node objects are actually specified as tensorflow::NodeDef protocol buffers,
378 // which contain many other fields that are not (currently) validated.
379 //
380 // Node         = NodeName, Inputs
381 // Inputs       = ( DataInput * ), ( ControlInput * )
382 // DataInput    = NodeName, ( ":", [1-9], [0-9] * ) ?
383 // ControlInput = "^", NodeName
384 // NodeName     = [A-Za-z0-9.], [A-Za-z0-9_./] *
385 Status ValidateExternalNodeDefSyntax(const NodeDef& node_def);
386 
387 // Returns "status" with formatted NodeDef attached as additional text
388 // in the error message. If 'allow_multiple_formatted_node' is false and there
389 // is already a formatted NodeDef present in 'status', we simply attach the name
390 // of the NodeDef instead of the formatted string.
391 Status AttachDef(const Status& status, const NodeDef& node_def,
392                  bool allow_multiple_formatted_node = false);
393 // Appends the given prefix and suffix to the original node name in order to
394 // make the name unique. If it's an "Enter" node and uniquify_frame_name is
395 // true, use the same way to reset attribute "frame_name".
396 Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
397                                 NodeDef* node_def,
398                                 bool uniquify_frame_name = true);
399 
400 // Appends the given prefix to the colocation group name if the name exists
401 // in `to_match`.
402 Status MaybeAddPrefixToColocationConstraints(
403     const std::unordered_set<string>& match, StringPiece prefix,
404     NodeDef* node_def);
405 
406 // Updates the colocation constraint name with the one provided in the map (if
407 // it exists in the map) for node_def.
408 Status MaybeUpdateColocationConstraintsWithMap(
409     const std::map<absl::string_view, absl::string_view>& node_name_map,
410     NodeDef* node_def);
411 
412 }  // namespace tensorflow
413 
414 #endif  // TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
415