• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
18 
#include <initializer_list>
#include <map>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/hash.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/padding.h"
40 
41 namespace tensorflow {
42 
// Forward declarations.
//
// `AttrSlice` is defined further down in this header but is already named by
// declarations that precede its definition (e.g. SummarizeAttrsHelper).
//
// NOTE(review): the proto classes below are forward declared with the intent
// that kernels need not depend on the generated proto headers. This header
// nevertheless #includes node_def.pb.h and op_def.pb.h directly, so (at least
// for OpDef) the forward declaration does not currently remove that
// dependency.
class AttrSlice;
class OpDef;
class AttrValue;
class NameAttrList;
class TensorProto;
class TensorShapeProto;
50 
// Attr that encodes a node's colocation constraints. Nodes may ask to be
// placed on the same device; that request is expressed as a list(string)
// attr whose entries name the colocation groups the node belongs to.
extern const char* const kColocationAttrName;

// Prefix that is put in front of an operation name to form the name used in
// colocation constraints.
extern const char* const kColocationGroupPrefix;

// Names used for the host-CPU staging op of TPUExecute.
extern const char* const kTpuExecuteStagingOp;
extern const char* const kTpuExecuteStagingNodeName;
64 
65 // Produce a human-readable version of a Node or NodeDef that is more concise
66 // than a text-format proto.
67 //
68 // The parameter `max_inputs_in_summary` specifies how many inputs at most to
69 // serialize in the output (in order not to get a string which is overly large).
70 // The value `-1` specifies that all inputs will be shown.
71 std::string SummarizeNodeDef(const NodeDef& node_def,
72                              int max_inputs_in_summary = -1);
73 std::string SummarizeAttrs(const NodeDef& node_def);
74 std::string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device);
75 
76 // Produces a formatted string pattern from the node which can uniquely identify
77 // this node upstream to produce an informative error message. The pattern
78 // followed is: {{node <node_name>}}
79 std::string FormatNodeDefForError(const NodeDef& node_def);
80 std::string FormatNodeDefForError(
81     StringPiece node_name, bool has_experimental_debug_info,
82     const NodeDef_ExperimentalDebugInfo& experimental_debug_info);
83 
84 typedef protobuf::Map<string, AttrValue> AttrValueMap;
85 
86 // Adds an attr with name <name> and value <value> to *node_def.
87 // The type of the attr is based on the type of value.
88 void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def);
89 void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def);
90 void AddNodeAttr(StringPiece name, StringPiece value, NodeDef* node_def);
91 void AddNodeAttr(StringPiece name, const char* value, NodeDef* node_def);
92 void AddNodeAttr(StringPiece name, int32_t value, NodeDef* node_def);
93 void AddNodeAttr(StringPiece name, int64_t value, NodeDef* node_def);
94 void AddNodeAttr(StringPiece name, float value, NodeDef* node_def);
95 void AddNodeAttr(StringPiece name, double value, NodeDef* node_def);
96 void AddNodeAttr(StringPiece name, bool value, NodeDef* node_def);
97 void AddNodeAttr(StringPiece name, DataType value, NodeDef* node_def);
98 void AddNodeAttr(StringPiece name, const PartialTensorShape& value,
99                  NodeDef* node_def);
100 void AddNodeAttr(StringPiece name, const Tensor& value, NodeDef* node_def);
101 void AddNodeAttr(StringPiece name, const TensorProto& value, NodeDef* node_def);
102 void AddNodeAttr(StringPiece name, const NameAttrList& value,
103                  NodeDef* node_def);
104 void AddNodeAttr(StringPiece name, gtl::ArraySlice<StringPiece> value,
105                  NodeDef* node_def);
106 void AddNodeAttr(StringPiece name, gtl::ArraySlice<const char*> value,
107                  NodeDef* node_def);
108 void AddNodeAttr(StringPiece name, gtl::ArraySlice<string> value,
109                  NodeDef* node_def);
110 void AddNodeAttr(StringPiece name, gtl::ArraySlice<int32> value,
111                  NodeDef* node_def);
112 void AddNodeAttr(StringPiece name, gtl::ArraySlice<int64_t> value,
113                  NodeDef* node_def);
114 void AddNodeAttr(StringPiece name, gtl::ArraySlice<float> value,
115                  NodeDef* node_def);
116 void AddNodeAttr(StringPiece name, gtl::ArraySlice<bool> value,
117                  NodeDef* node_def);
118 void AddNodeAttr(StringPiece name, const std::vector<bool>& value,
119                  NodeDef* node_def);
120 void AddNodeAttr(StringPiece name, gtl::ArraySlice<DataType> value,
121                  NodeDef* node_def);
122 void AddNodeAttr(StringPiece name, gtl::ArraySlice<TensorShape> value,
123                  NodeDef* node_def);
124 void AddNodeAttr(StringPiece name, gtl::ArraySlice<PartialTensorShape> value,
125                  NodeDef* node_def);
126 void AddNodeAttr(StringPiece name, gtl::ArraySlice<TensorShapeProto> value,
127                  NodeDef* node_def);
128 void AddNodeAttr(StringPiece name, gtl::ArraySlice<Tensor> value,
129                  NodeDef* node_def);
130 void AddNodeAttr(StringPiece name, gtl::ArraySlice<NameAttrList> value,
131                  NodeDef* node_def);
132 
133 // Version to workaround C++'s "perfect" forwarding not being able to
134 // forward {...} initialization.
135 template <class T>
AddNodeAttr(StringPiece name,std::initializer_list<T> value,NodeDef * node_def)136 void AddNodeAttr(StringPiece name, std::initializer_list<T> value,
137                  NodeDef* node_def) {
138   AddNodeAttr(name, gtl::ArraySlice<T>(value), node_def);
139 }
140 
141 // Adds an attr to an attr value map.
142 void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map);
143 void AddAttr(StringPiece name, bool value, AttrValueMap* map);
144 
145 class AttrSlice {
146  public:
147   AttrSlice(const NodeDef& node_def);  // NOLINT(runtime/explicit)
148 
149   AttrSlice();  // Empty
150   explicit AttrSlice(const AttrValueMap* a);
151 
size()152   int size() const { return attrs()->size(); }
153 
154   // Returns the attr with attr_name if found.  Otherwise, returns
155   // nullptr.
156   const AttrValue* Find(StringPiece attr_name) const;
157   const AttrValue* FindByString(const std::string& attr_name) const;
158 
159   // Returns the attr_value for attr_name if found. Otherwise, returns a
160   // NotFound status.
161   Status Find(StringPiece attr_name, const AttrValue** attr_value) const;
162   Status FindByString(const std::string& attr_name,
163                       const AttrValue** attr_value) const;
164 
165   // Helper class to avoid allocations in EqualAttrs.
166   // TODO(irving): Will go away once NodeInfo is used.
167   struct Scratch {
168     std::string a;
169     std::string b;
170   };
171 
172   // Check if all attrs and attr values match.  Does not take defaults into
173   // account.
174   //
175   // TODO(irving): There is a bug in this routine inherited from its
176   // OptimizerCSE::EqualAttrs predecessor.  The same tensor attr can be
177   // represented in more than one way as an AttrValue, since TensorProto is
178   // not 1-1.  This bug will go away once I replace everything with NodeInfo,
179   // which stores a Tensor object directly.  The Scratch object will also go
180   // away.
181   bool EqualAttrs(AttrSlice other, Scratch* scratch) const;
182 
183   // If this AttrSlice has an attached NodeDef, summarize it.  This is for
184   // error messages only: we intentionally do not provide direct access to the
185   // NodeDef, since it is not always there.
186   std::string SummarizeNode() const;
187 
188   // Iteration over all attrs
begin()189   AttrValueMap::const_iterator begin() const { return attrs()->begin(); }
end()190   AttrValueMap::const_iterator end() const { return attrs()->end(); }
191 
192   std::string DebugString() const;
193 
194  private:
attrs()195   const AttrValueMap* attrs() const {
196     return ndef_ != nullptr ? &ndef_->attr() : attrs_;
197   }
198 
199   Status CheckFind(StringPiece attr_name, const AttrValue* attr_value) const;
200 
201   const NodeDef* ndef_;
202   const AttrValueMap* attrs_;
203 };
204 
205 // Return true if the attr with the name attr_name is defined in node_def.
206 bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name);
207 
208 // Look up the attr with name attr_name and set *value to its value.  If no
209 // attr with attr_name is found in node_def, or the attr does not have
210 // a matching type, a non-ok status will be returned.
211 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
212                    std::string* value);  // type: "string"
213 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
214                    tstring* value);  // type: "tstring"
215 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
216                    int64_t* value);  // type: "int"
217 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
218                    int32* value);  // type: "int"
219 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
220                    float* value);  // type: "float"
221 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
222                    bool* value);  // type: "bool"
223 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
224                    DataType* value);  // type: "type"
225 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
226                    TensorShapeProto* value);  // type: "shape"
227 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
228                    TensorShape* value);  // type: "shape"
229 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
230                    PartialTensorShape* value);  // type: "shape"
231 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
232                    Tensor* value);  // type: "tensor"
233 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
234                    std::vector<string>* value);  // type "list(string)"
235 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
236                    std::vector<tstring>* value);  // type "list(tstring)"
237 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
238                    std::vector<int64_t>* value);  // type "list(int)"
239 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
240                    std::vector<int32>* value);  // type "list(int)"
241 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
242                    std::vector<float>* value);  // type "list(float)"
243 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
244                    std::vector<bool>* value);  // type "list(bool)"
245 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
246                    std::vector<DataType>* value);  // type "list(type)"
247 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
248                    DataTypeVector* value);  // type "list(type)"
249 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
250                    std::vector<TensorShapeProto>* value);  // type "list(shape)"
251 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
252                    std::vector<TensorShape>* value);  // type "list(shape)"
253 Status GetNodeAttr(
254     const AttrSlice& attrs, StringPiece attr_name,
255     std::vector<PartialTensorShape>* value);  // type "list(shape)"
256 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
257                    std::vector<Tensor>* value);  // type: "list(tensor)"
258 
259 template <typename T>
GetNodeAttr(const NodeDef & ndef,absl::string_view attr_name)260 StatusOr<T> GetNodeAttr(const NodeDef& ndef, absl::string_view attr_name) {
261   T val;
262   TF_RETURN_IF_ERROR(GetNodeAttr(ndef, attr_name, &val));
263   return val;
264 }
265 
266 // This version avoids copying the TensorProto.
267 // REQUIRES: Must not use *value beyond the lifetime of node_def.
268 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
269                    const TensorProto** value);  // type: "tensor"
270 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
271                     const TensorProto** value);  // type: "tensor"
272 
273 // This version avoids copying the NameAttrList.
274 // REQUIRES: Must not use *value beyond the lifetime of node_def.
275 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
276                    const NameAttrList** value);  // type: "func"
277 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
278                     const NameAttrList** value);  // type: "func"
279 
280 // These versions copies the NameAttrList(s).
281 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
282                    NameAttrList* value);  // type: "func"
283 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
284                    std::vector<NameAttrList>* value);  // type: "list(func)"
285 
286 // Look up the attr with name attr_name and set *value to its value.  If no
287 // attr with attr_name is found in node_def, or the attr does not have
288 // a matching type, false is returned.
289 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
290                     std::string* value);  // type: "string"
291 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
292                     int64_t* value);  // type: "int"
293 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
294                     std::vector<int64_t>* value);  // type: "int"
295 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
296                     int32* value);  // type: "int"
297 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
298                     float* value);  // type: "float"
299 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
300                     bool* value);  // type: "bool"
301 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
302                     DataType* value);  // type: "type"
303 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
304                     TensorShape* value);  // type: "shape"
305 
306 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
307                     std::vector<string>* value);  // type: "list(string)"
308 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
309                     std::vector<tstring>* value);  // type: "list(tstring)"
310 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
311                     std::vector<int32>* value);  // type: "list(int)"
312 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
313                     std::vector<float>* value);  // type: "list(float)"
314 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
315                     std::vector<bool>* value);  // type: "list(bool)"
316 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
317                     std::vector<DataType>* value);  // type: "list(type)"
318 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
319                     std::vector<TensorShape> value);  // type: "shape"
320 
321 // Overloads of TryGetNodeAttr() that avoid copying the non-POD attribute
322 // values.
323 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
324                     std::vector<const string*>* value);  // type: "list(string)"
325 bool TryGetNodeAttr(
326     const AttrSlice& attrs, StringPiece attr_name,
327     std::vector<const TensorShapeProto*>* value);  // type: "list(shape)"
328 
329 // Look up the attr with name attr_name and return a reference to its value.
330 // If no attr with attr_name is found in node_def, or the attr does not have
331 // a matching type, a reference to an empty string is returned.
332 // REQUIRES: Must not use the returned value beyond the lifetime of node_def.
333 const std::string& GetNodeAttrString(const AttrSlice& attrs,
334                                      StringPiece attr_name);
335 
336 // Specialization to parse an attribute directly into a Padding enum.
337 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
338                    Padding* value);
339 
340 // Computes the input type for a specific node input.
341 // REQUIRES: ValidateOpDef(op_def).ok()
342 Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
343                         int input_port, DataType* input_type);
344 // Computes the input types for a specific node.
345 // REQUIRES: ValidateOpDef(op_def).ok()
346 Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
347                          DataTypeVector* inputs);
348 // Computes the output type for a specific node output.
349 // REQUIRES: ValidateOpDef(op_def).ok()
350 Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
351                          int output_port, DataType* output_type);
352 // Computes the output types for a specific node.
353 // REQUIRES: ValidateOpDef(op_def).ok()
354 Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
355                           DataTypeVector* outputs);
356 Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def,
357                           DataTypeVector* outputs);
358 
359 // Computes the input and output types for a specific node.
360 // REQUIRES: ValidateOpDef(op_def).ok()
361 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
362                          DataTypeVector* inputs, DataTypeVector* outputs);
363 // Computes the number of outputs for a specific node.
364 // REQUIRES: ValidateOpDef(op_def).ok()
365 Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
366                          int* num_outputs);
367 
368 // Map a node/op's input/output port_id to arg_id.
369 //
370 // The port_id refers to the n-th tensor of the node, while the arg_id refers to
371 // the n-th arg of the op. These two can be different if an op's arg is a list
372 // of tensors.
373 //
374 // We return -1 for any invalid port_id (i.e., no corresponding arg_id).
375 int OpPortIdToArgId(const NodeDef& node,
376                     const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
377                     int port_id);
378 
379 // Validates that the NodeDef:
380 // * Defines all expected attrs from the OpDef.
381 // * All attrs satisfies constraints from the OpDef.
382 // * Has a signature matching SignatureForNode().
383 // etc.
384 Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def);
385 
386 // Computes the mapping from input/output argument name to the
387 // corresponding input/output index range.  For example,
388 // input "foo" corresponds to input indices
389 //   [ (*inputs)["foo"].first, (*inputs)["foo"].second ).
390 // NOTE(mrry): To reduce allocations when the map is used and save
391 // space, the returned `NameRangeMap` objects borrow the input/output
392 // argument names from `op_def`. The `op_def` must outlive the
393 // returned `NameRangeMap` objects.
394 typedef gtl::FlatMap<StringPiece, std::pair<int, int>, hash<StringPiece>>
395     NameRangeMap;
396 Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def,
397                          NameRangeMap* inputs, NameRangeMap* outputs);
398 // Adds default values to *node_def for unspecified attrs from op_def.
399 void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def);
400 
401 // Remove attributes from node_def when the value is the default from the
402 // op_def.
403 void StripDefaultsFromNodeDef(const OpDef& op_def, NodeDef* node_def);
404 
405 // Validates the syntax of a NodeDef provided externally.
406 //
407 // The following is an EBNF-style syntax for NodeDef objects. Note that
408 // Node objects are actually specified as tensorflow::NodeDef protocol buffers,
409 // which contain many other fields that are not (currently) validated.
410 //
411 // Node         = NodeName, Inputs
412 // Inputs       = ( DataInput * ), ( ControlInput * )
413 // DataInput    = NodeName, ( ":", [1-9], [0-9] * ) ?
414 // ControlInput = "^", NodeName
415 // NodeName     = [A-Za-z0-9.], [A-Za-z0-9_./] *
416 Status ValidateExternalNodeDefSyntax(const NodeDef& node_def);
417 
418 // Returns "status" with formatted NodeDef attached as additional text
419 // in the error message. If 'allow_multiple_formatted_node' is false and there
420 // is already a formatted NodeDef present in 'status', we simply attach the name
421 // of the NodeDef instead of the formatted string.
422 Status AttachDef(const Status& status, const NodeDef& node_def,
423                  bool allow_multiple_formatted_node = false);
424 // Appends the given prefix and suffix to the original node name in order to
425 // make the name unique. If it's an "Enter" node and uniquify_frame_name is
426 // true, use the same way to reset attribute "frame_name".
427 Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
428                                 NodeDef* node_def,
429                                 bool uniquify_frame_name = true);
430 
431 // Appends the given prefix to the colocation group name if the name exists
432 // in `to_match`.
433 Status MaybeAddPrefixToColocationConstraints(
434     const std::unordered_set<string>& match, StringPiece prefix,
435     NodeDef* node_def);
436 
437 // Updates the colocation constraint name with the one provided in the map (if
438 // it exists in the map) for node_def.
439 Status MaybeUpdateColocationConstraintsWithMap(
440     const std::map<absl::string_view, absl::string_view>& node_name_map,
441     NodeDef* node_def);
442 
443 }  // namespace tensorflow
444 
445 #endif  // TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
446