1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
18
19 #include <string>
20 #include <unordered_set>
21 #include <vector>
22
23 #include "tensorflow/core/framework/attr_value_util.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/op_def.pb.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_shape.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/framework/types.pb.h"
30 #include "tensorflow/core/lib/core/stringpiece.h"
31 #include "tensorflow/core/lib/gtl/array_slice.h"
32 #include "tensorflow/core/lib/gtl/flatmap.h"
33 #include "tensorflow/core/lib/hash/hash.h"
34 #include "tensorflow/core/platform/hash.h"
35 #include "tensorflow/core/platform/protobuf.h"
36 #include "tensorflow/core/platform/status.h"
37 #include "tensorflow/core/platform/stringpiece.h"
38 #include "tensorflow/core/platform/types.h"
39 #include "tensorflow/core/util/padding.h"
40
41 namespace tensorflow {
42
43 class AttrSlice;
44 // We forward declare protos so that kernels don't need to depend on them
45 class OpDef;
46 class AttrValue;
47 class NameAttrList;
48 class TensorProto;
49 class TensorShapeProto;
50
51 // Name of the attribute used to encode node colocation constraints.
52 //
53 // Nodes can be co-located on the same device. Desire for explicit co-location
54 // is described by list(string) attribute containing the name of colocation
55 // groups.
56 extern const char* const kColocationAttrName;
57
58 // String prefix applied to the operation name for colocation constraints.
59 extern const char* const kColocationGroupPrefix;
60
61 // Constants for host CPU staging op for TPUExecute.
62 extern const char* const kTpuExecuteStagingOp;
63 extern const char* const kTpuExecuteStagingNodeName;
64
65 // Produce a human-readable version of a Node or NodeDef that is more concise
66 // than a text-format proto.
67 //
68 // The parameter `max_inputs_in_summary` specifies how many inputs at most to
69 // serialize in the output (in order not to get a string which is overly large).
70 // The value `-1` specifies that all inputs will be shown.
71 std::string SummarizeNodeDef(const NodeDef& node_def,
72 int max_inputs_in_summary = -1);
73 std::string SummarizeAttrs(const NodeDef& node_def);
74 std::string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device);
75
76 // Produces a formatted string pattern from the node which can uniquely identify
77 // this node upstream to produce an informative error message. The pattern
78 // followed is: {{node <node_name>}}
79 std::string FormatNodeDefForError(const NodeDef& node_def);
80 std::string FormatNodeDefForError(
81 StringPiece node_name, bool has_experimental_debug_info,
82 const NodeDef_ExperimentalDebugInfo& experimental_debug_info);
83
84 typedef protobuf::Map<string, AttrValue> AttrValueMap;
85
86 // Adds an attr with name <name> and value <value> to *node_def.
87 // The type of the attr is based on the type of value.
88 void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def);
89 void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def);
90 void AddNodeAttr(StringPiece name, StringPiece value, NodeDef* node_def);
91 void AddNodeAttr(StringPiece name, const char* value, NodeDef* node_def);
92 void AddNodeAttr(StringPiece name, int32_t value, NodeDef* node_def);
93 void AddNodeAttr(StringPiece name, int64_t value, NodeDef* node_def);
94 void AddNodeAttr(StringPiece name, float value, NodeDef* node_def);
95 void AddNodeAttr(StringPiece name, double value, NodeDef* node_def);
96 void AddNodeAttr(StringPiece name, bool value, NodeDef* node_def);
97 void AddNodeAttr(StringPiece name, DataType value, NodeDef* node_def);
98 void AddNodeAttr(StringPiece name, const PartialTensorShape& value,
99 NodeDef* node_def);
100 void AddNodeAttr(StringPiece name, const Tensor& value, NodeDef* node_def);
101 void AddNodeAttr(StringPiece name, const TensorProto& value, NodeDef* node_def);
102 void AddNodeAttr(StringPiece name, const NameAttrList& value,
103 NodeDef* node_def);
104 void AddNodeAttr(StringPiece name, gtl::ArraySlice<StringPiece> value,
105 NodeDef* node_def);
106 void AddNodeAttr(StringPiece name, gtl::ArraySlice<const char*> value,
107 NodeDef* node_def);
108 void AddNodeAttr(StringPiece name, gtl::ArraySlice<string> value,
109 NodeDef* node_def);
110 void AddNodeAttr(StringPiece name, gtl::ArraySlice<int32> value,
111 NodeDef* node_def);
112 void AddNodeAttr(StringPiece name, gtl::ArraySlice<int64_t> value,
113 NodeDef* node_def);
114 void AddNodeAttr(StringPiece name, gtl::ArraySlice<float> value,
115 NodeDef* node_def);
116 void AddNodeAttr(StringPiece name, gtl::ArraySlice<bool> value,
117 NodeDef* node_def);
118 void AddNodeAttr(StringPiece name, const std::vector<bool>& value,
119 NodeDef* node_def);
120 void AddNodeAttr(StringPiece name, gtl::ArraySlice<DataType> value,
121 NodeDef* node_def);
122 void AddNodeAttr(StringPiece name, gtl::ArraySlice<TensorShape> value,
123 NodeDef* node_def);
124 void AddNodeAttr(StringPiece name, gtl::ArraySlice<PartialTensorShape> value,
125 NodeDef* node_def);
126 void AddNodeAttr(StringPiece name, gtl::ArraySlice<TensorShapeProto> value,
127 NodeDef* node_def);
128 void AddNodeAttr(StringPiece name, gtl::ArraySlice<Tensor> value,
129 NodeDef* node_def);
130 void AddNodeAttr(StringPiece name, gtl::ArraySlice<NameAttrList> value,
131 NodeDef* node_def);
132
133 // Version to workaround C++'s "perfect" forwarding not being able to
134 // forward {...} initialization.
135 template <class T>
AddNodeAttr(StringPiece name,std::initializer_list<T> value,NodeDef * node_def)136 void AddNodeAttr(StringPiece name, std::initializer_list<T> value,
137 NodeDef* node_def) {
138 AddNodeAttr(name, gtl::ArraySlice<T>(value), node_def);
139 }
140
141 // Adds an attr to an attr value map.
142 void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map);
143 void AddAttr(StringPiece name, bool value, AttrValueMap* map);
144
145 class AttrSlice {
146 public:
147 AttrSlice(const NodeDef& node_def); // NOLINT(runtime/explicit)
148
149 AttrSlice(); // Empty
150 explicit AttrSlice(const AttrValueMap* a);
151
size()152 int size() const { return attrs()->size(); }
153
154 // Returns the attr with attr_name if found. Otherwise, returns
155 // nullptr.
156 const AttrValue* Find(StringPiece attr_name) const;
157 const AttrValue* FindByString(const std::string& attr_name) const;
158
159 // Returns the attr_value for attr_name if found. Otherwise, returns a
160 // NotFound status.
161 Status Find(StringPiece attr_name, const AttrValue** attr_value) const;
162 Status FindByString(const std::string& attr_name,
163 const AttrValue** attr_value) const;
164
165 // Helper class to avoid allocations in EqualAttrs.
166 // TODO(irving): Will go away once NodeInfo is used.
167 struct Scratch {
168 std::string a;
169 std::string b;
170 };
171
172 // Check if all attrs and attr values match. Does not take defaults into
173 // account.
174 //
175 // TODO(irving): There is a bug in this routine inherited from its
176 // OptimizerCSE::EqualAttrs predecessor. The same tensor attr can be
177 // represented in more than one way as an AttrValue, since TensorProto is
178 // not 1-1. This bug will go away once I replace everything with NodeInfo,
179 // which stores a Tensor object directly. The Scratch object will also go
180 // away.
181 bool EqualAttrs(AttrSlice other, Scratch* scratch) const;
182
183 // If this AttrSlice has an attached NodeDef, summarize it. This is for
184 // error messages only: we intentionally do not provide direct access to the
185 // NodeDef, since it is not always there.
186 std::string SummarizeNode() const;
187
188 // Iteration over all attrs
begin()189 AttrValueMap::const_iterator begin() const { return attrs()->begin(); }
end()190 AttrValueMap::const_iterator end() const { return attrs()->end(); }
191
192 std::string DebugString() const;
193
194 private:
attrs()195 const AttrValueMap* attrs() const {
196 return ndef_ != nullptr ? &ndef_->attr() : attrs_;
197 }
198
199 Status CheckFind(StringPiece attr_name, const AttrValue* attr_value) const;
200
201 const NodeDef* ndef_;
202 const AttrValueMap* attrs_;
203 };
204
205 // Return true if the attr with the name attr_name is defined in node_def.
206 bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name);
207
208 // Look up the attr with name attr_name and set *value to its value. If no
209 // attr with attr_name is found in node_def, or the attr does not have
210 // a matching type, a non-ok status will be returned.
211 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
212 std::string* value); // type: "string"
213 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
214 tstring* value); // type: "tstring"
215 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
216 int64_t* value); // type: "int"
217 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
218 int32* value); // type: "int"
219 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
220 float* value); // type: "float"
221 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
222 bool* value); // type: "bool"
223 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
224 DataType* value); // type: "type"
225 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
226 TensorShapeProto* value); // type: "shape"
227 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
228 TensorShape* value); // type: "shape"
229 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
230 PartialTensorShape* value); // type: "shape"
231 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
232 Tensor* value); // type: "tensor"
233 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
234 std::vector<string>* value); // type "list(string)"
235 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
236 std::vector<tstring>* value); // type "list(tstring)"
237 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
238 std::vector<int64_t>* value); // type "list(int)"
239 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
240 std::vector<int32>* value); // type "list(int)"
241 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
242 std::vector<float>* value); // type "list(float)"
243 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
244 std::vector<bool>* value); // type "list(bool)"
245 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
246 std::vector<DataType>* value); // type "list(type)"
247 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
248 DataTypeVector* value); // type "list(type)"
249 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
250 std::vector<TensorShapeProto>* value); // type "list(shape)"
251 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
252 std::vector<TensorShape>* value); // type "list(shape)"
253 Status GetNodeAttr(
254 const AttrSlice& attrs, StringPiece attr_name,
255 std::vector<PartialTensorShape>* value); // type "list(shape)"
256 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
257 std::vector<Tensor>* value); // type: "list(tensor)"
258
259 template <typename T>
GetNodeAttr(const NodeDef & ndef,absl::string_view attr_name)260 StatusOr<T> GetNodeAttr(const NodeDef& ndef, absl::string_view attr_name) {
261 T val;
262 TF_RETURN_IF_ERROR(GetNodeAttr(ndef, attr_name, &val));
263 return val;
264 }
265
266 // This version avoids copying the TensorProto.
267 // REQUIRES: Must not use *value beyond the lifetime of the NodeDef or AttrValueMap backing attrs.
268 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
269 const TensorProto** value); // type: "tensor"
270 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
271 const TensorProto** value); // type: "tensor"
272
273 // This version avoids copying the NameAttrList.
274 // REQUIRES: Must not use *value beyond the lifetime of the NodeDef or AttrValueMap backing attrs.
275 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
276 const NameAttrList** value); // type: "func"
277 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
278 const NameAttrList** value); // type: "func"
279
280 // These versions copy the NameAttrList(s).
281 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
282 NameAttrList* value); // type: "func"
283 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
284 std::vector<NameAttrList>* value); // type: "list(func)"
285
286 // Look up the attr with name attr_name and set *value to its value. If no
287 // attr with attr_name is found in attrs, or the attr does not have
288 // a matching type, false is returned.
289 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
290 std::string* value); // type: "string"
291 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
292 int64_t* value); // type: "int"
293 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
294 std::vector<int64_t>* value); // type: "list(int)"
295 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
296 int32* value); // type: "int"
297 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
298 float* value); // type: "float"
299 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
300 bool* value); // type: "bool"
301 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
302 DataType* value); // type: "type"
303 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
304 TensorShape* value); // type: "shape"
305
306 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
307 std::vector<string>* value); // type: "list(string)"
308 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
309 std::vector<tstring>* value); // type: "list(tstring)"
310 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
311 std::vector<int32>* value); // type: "list(int)"
312 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
313 std::vector<float>* value); // type: "list(float)"
314 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
315 std::vector<bool>* value); // type: "list(bool)"
316 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
317 std::vector<DataType>* value); // type: "list(type)"
// NOTE(review): unlike every other overload here, the next one takes `value`
// by value, not by pointer, so a result written to it cannot reach the
// caller. Presumably std::vector<TensorShape>* was intended -- TODO confirm
// against the definition before changing the signature.
318 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
319 std::vector<TensorShape> value); // type: "list(shape)"
320
321 // Overloads of TryGetNodeAttr() that avoid copying the non-POD attribute
322 // values.
323 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
324 std::vector<const string*>* value); // type: "list(string)"
325 bool TryGetNodeAttr(
326 const AttrSlice& attrs, StringPiece attr_name,
327 std::vector<const TensorShapeProto*>* value); // type: "list(shape)"
328
329 // Look up the attr with name attr_name and return a reference to its value.
330 // If no attr with attr_name is found in node_def, or the attr does not have
331 // a matching type, a reference to an empty string is returned.
332 // REQUIRES: Must not use the returned value beyond the lifetime of node_def.
333 const std::string& GetNodeAttrString(const AttrSlice& attrs,
334 StringPiece attr_name);
335
336 // Specialization to parse an attribute directly into a Padding enum.
337 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
338 Padding* value);
339
340 // Computes the input type for a specific node input.
341 // REQUIRES: ValidateOpDef(op_def).ok()
342 Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
343 int input_port, DataType* input_type);
344 // Computes the input types for a specific node.
345 // REQUIRES: ValidateOpDef(op_def).ok()
346 Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
347 DataTypeVector* inputs);
348 // Computes the output type for a specific node output.
349 // REQUIRES: ValidateOpDef(op_def).ok()
350 Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
351 int output_port, DataType* output_type);
352 // Computes the output types for a specific node.
353 // REQUIRES: ValidateOpDef(op_def).ok()
354 Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
355 DataTypeVector* outputs);
356 Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def,
357 DataTypeVector* outputs);
358
359 // Computes the input and output types for a specific node.
360 // REQUIRES: ValidateOpDef(op_def).ok()
361 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
362 DataTypeVector* inputs, DataTypeVector* outputs);
363 // Computes the number of outputs for a specific node.
364 // REQUIRES: ValidateOpDef(op_def).ok()
365 Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
366 int* num_outputs);
367
368 // Map a node/op's input/output port_id to arg_id.
369 //
370 // The port_id refers to the n-th tensor of the node, while the arg_id refers to
371 // the n-th arg of the op. These two can be different if an op's arg is a list
372 // of tensors.
373 //
374 // We return -1 for any invalid port_id (i.e., no corresponding arg_id).
375 int OpPortIdToArgId(const NodeDef& node,
376 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
377 int port_id);
378
379 // Validates that the NodeDef:
380 // * Defines all expected attrs from the OpDef.
381 // * All attrs satisfy the constraints from the OpDef.
382 // * Has a signature matching SignatureForNode().
383 // etc.
384 Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def);
385
386 // Computes the mapping from input/output argument name to the
387 // corresponding input/output index range. For example,
388 // input "foo" corresponds to input indices
389 // [ (*inputs)["foo"].first, (*inputs)["foo"].second ).
390 // NOTE(mrry): To reduce allocations when the map is used and save
391 // space, the returned `NameRangeMap` objects borrow the input/output
392 // argument names from `op_def`. The `op_def` must outlive the
393 // returned `NameRangeMap` objects.
394 typedef gtl::FlatMap<StringPiece, std::pair<int, int>, hash<StringPiece>>
395 NameRangeMap;
396 Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def,
397 NameRangeMap* inputs, NameRangeMap* outputs);
398 // Adds default values to *node_def for unspecified attrs from op_def.
399 void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def);
400
401 // Remove attributes from node_def when the value is the default from the
402 // op_def.
403 void StripDefaultsFromNodeDef(const OpDef& op_def, NodeDef* node_def);
404
405 // Validates the syntax of a NodeDef provided externally.
406 //
407 // The following is an EBNF-style syntax for NodeDef objects. Note that
408 // Node objects are actually specified as tensorflow::NodeDef protocol buffers,
409 // which contain many other fields that are not (currently) validated.
410 //
411 // Node = NodeName, Inputs
412 // Inputs = ( DataInput * ), ( ControlInput * )
413 // DataInput = NodeName, ( ":", [1-9], [0-9] * ) ?
414 // ControlInput = "^", NodeName
415 // NodeName = [A-Za-z0-9.], [A-Za-z0-9_./] *
416 Status ValidateExternalNodeDefSyntax(const NodeDef& node_def);
417
418 // Returns "status" with formatted NodeDef attached as additional text
419 // in the error message. If 'allow_multiple_formatted_node' is false and there
420 // is already a formatted NodeDef present in 'status', we simply attach the name
421 // of the NodeDef instead of the formatted string.
422 Status AttachDef(const Status& status, const NodeDef& node_def,
423 bool allow_multiple_formatted_node = false);
424 // Appends the given prefix and suffix to the original node name in order to
425 // make the name unique. If it's an "Enter" node and uniquify_frame_name is
426 // true, use the same way to reset attribute "frame_name".
427 Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
428 NodeDef* node_def,
429 bool uniquify_frame_name = true);
430
431 // Appends the given prefix to the colocation group name if the name exists
432 // in `match`.
433 Status MaybeAddPrefixToColocationConstraints(
434 const std::unordered_set<string>& match, StringPiece prefix,
435 NodeDef* node_def);
436
437 // Updates the colocation constraint name with the one provided in the map (if
438 // it exists in the map) for node_def.
439 Status MaybeUpdateColocationConstraintsWithMap(
440 const std::map<absl::string_view, absl::string_view>& node_name_map,
441 NodeDef* node_def);
442
443 } // namespace tensorflow
444
445 #endif // TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
446