1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/node_def_util.h"
17
#include <algorithm>
#include <limits>
#include <unordered_map>
#include <vector>
21
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/attr_value_util.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/framework/op_def.pb.h"
28 #include "tensorflow/core/framework/op_def_util.h"
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/framework/tensor.pb.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/framework/tensor_shape.pb.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/lib/gtl/map_util.h"
36 #include "tensorflow/core/platform/errors.h"
37 #include "tensorflow/core/platform/scanner.h"
38 #include "tensorflow/core/platform/status.h"
39 #include "tensorflow/core/platform/strcat.h"
40 #include "tensorflow/core/platform/stringpiece.h"
41 #include "tensorflow/core/platform/types.h"
42
43 namespace tensorflow {
44
// Attr name used to record colocation information on a NodeDef.
// NOTE(review): meaning inferred from the name; confirm against its users.
const char* const kColocationAttrName = "_class";
// Prefix that marks a colocation-group entry.
// NOTE(review): meaning inferred from the name; confirm against its users.
const char* const kColocationGroupPrefix = "loc:@";
// For TPU distributed rewrite, TPU args are collected and "staged" on the local
// host using an IdentityN TF op. Some args may result from a remote source.
// When all arg tensors are available, the TPUExecute op can be invoked. See
// DistributedTPURewritePass for more details.
const char* const kTpuExecuteStagingOp = "IdentityN";
const char* const kTpuExecuteStagingNodeName = "_variable_copy";
53
AttrSlice()54 AttrSlice::AttrSlice() : ndef_(nullptr) {
55 static const AttrValueMap* const kEmptyAttrValueMap = new AttrValueMap;
56 attrs_ = kEmptyAttrValueMap;
57 }
58
AttrSlice(const NodeDef & node_def)59 AttrSlice::AttrSlice(const NodeDef& node_def)
60 : ndef_(&node_def), attrs_(&ndef_->attr()) {}
61
AttrSlice(const AttrValueMap * a)62 AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {}
63
SummarizeAttrsHelper(AttrSlice attrs,StringPiece device)64 string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) {
65 string ret;
66
67 // We sort the attrs so the output is deterministic.
68 std::vector<string> attr_names;
69 attr_names.reserve(attrs.size());
70 for (const auto& attr : attrs) {
71 attr_names.push_back(attr.first);
72 }
73 std::sort(attr_names.begin(), attr_names.end());
74 bool first = true;
75 for (const string& attr_name : attr_names) {
76 if (!first) strings::StrAppend(&ret, ", ");
77 first = false;
78 strings::StrAppend(&ret, attr_name, "=",
79 SummarizeAttrValue(*attrs.Find(attr_name)));
80 }
81
82 // Consider the device to be a final attr with name "_device".
83 if (!device.empty()) {
84 if (!first) strings::StrAppend(&ret, ", ");
85 first = false;
86 strings::StrAppend(&ret, "_device=\"", device, "\"");
87 }
88 return ret;
89 }
90
SummarizeNode() const91 string AttrSlice::SummarizeNode() const {
92 return ndef_ ? SummarizeNodeDef(*ndef_)
93 : strings::StrCat(
94 "[", SummarizeAttrsHelper(*this, StringPiece()), "]");
95 }
96
DebugString() const97 string AttrSlice::DebugString() const {
98 std::vector<string> attr_key_vals;
99 attr_key_vals.reserve(attrs_->size());
100 for (const auto& it : *this) {
101 const string& name = it.first;
102 const AttrValue& attr_value = it.second;
103 attr_key_vals.push_back(
104 absl::StrCat(name, "=", SummarizeAttrValue(attr_value)));
105 }
106 return absl::StrJoin(attr_key_vals, ", ");
107 }
108
SummarizeNodeDef(const NodeDef & node_def,int max_inputs_in_summary)109 string SummarizeNodeDef(const NodeDef& node_def, int max_inputs_in_summary) {
110 string ret = strings::StrCat(errors::FormatNodeNameForError(node_def.name()),
111 " = ", node_def.op(), "[");
112 strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device()));
113 strings::StrAppend(&ret, "](");
114
115 // Output inputs, including control inputs, verbatim.
116 bool first = true;
117 for (const string& input : node_def.input()) {
118 if (!first) strings::StrAppend(&ret, ", ");
119 first = false;
120 if (max_inputs_in_summary-- == 0) {
121 strings::StrAppend(&ret, "...");
122 break;
123 }
124 strings::StrAppend(&ret, input);
125 }
126 strings::StrAppend(&ret, ")");
127 return ret;
128 }
129
SummarizeAttrs(const NodeDef & node_def)130 string SummarizeAttrs(const NodeDef& node_def) {
131 return SummarizeAttrsHelper(node_def, node_def.device());
132 }
133
FormatNodeDefForError(StringPiece node_name,bool has_experimental_debug_info,const NodeDef_ExperimentalDebugInfo & experimental_debug_info)134 string FormatNodeDefForError(
135 StringPiece node_name, bool has_experimental_debug_info,
136 const NodeDef_ExperimentalDebugInfo& experimental_debug_info) {
137 return !has_experimental_debug_info ||
138 experimental_debug_info.original_node_names().empty()
139 ? errors::FormatNodeNameForError(string(node_name))
140 : errors::FormatNodeNamesForError(
141 experimental_debug_info.original_node_names());
142 }
143
FormatNodeDefForError(const NodeDef & node_def)144 string FormatNodeDefForError(const NodeDef& node_def) {
145 return FormatNodeDefForError(node_def.name(),
146 node_def.has_experimental_debug_info(),
147 node_def.experimental_debug_info());
148 }
149
Find(StringPiece attr_name) const150 const AttrValue* AttrSlice::Find(StringPiece attr_name) const {
151 // Currently, the collection used for NodeDef::attr() (google::protobuf::Map)
152 // requires that the keys used for lookups have type 'const string&'. Because
153 // this method takes a StringPiece, it is necessary to allocate a temporary
154 // string, copy attr_name to it, and then use that temporary string for the
155 // lookup. This causes an excessive number of short-lived allocations, and for
156 // large graphs, this can be a significant cost.
157 //
158 // Because most nodes have a small number of attributes, a simple linear scan
159 // is generally more efficient than a hashed lookup. If google::protobuf::Map
160 // changes so that it supports efficient lookups using StringPiece instead of
161 // const string&, then this code could be changed to use attrs_->find() again.
162
163 for (const auto& attr : *attrs_) {
164 if (attr.first == attr_name) {
165 return &attr.second;
166 }
167 }
168 return nullptr;
169 }
170
FindByString(const string & attr_name) const171 const AttrValue* AttrSlice::FindByString(const string& attr_name) const {
172 auto iter = attrs_->find(attr_name);
173 if (iter != attrs_->end()) {
174 return &iter->second;
175 } else {
176 return nullptr;
177 }
178 }
179
Find(StringPiece attr_name,const AttrValue ** attr_value) const180 Status AttrSlice::Find(StringPiece attr_name,
181 const AttrValue** attr_value) const {
182 *attr_value = Find(attr_name);
183 if (*attr_value != nullptr) {
184 return Status::OK();
185 }
186 Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:");
187 // Skip AttachDef for internal attrs since it is a little bit
188 // expensive and it is common for them to correctly not be included
189 // in a NodeDef.
190 if (!absl::StartsWith(attr_name, "_") && ndef_ != nullptr) {
191 s = AttachDef(s, *ndef_);
192 }
193 return s;
194 }
195
EqualAttrs(AttrSlice other,Scratch * scratch) const196 bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
197 if (size() != other.size()) return false;
198
199 for (const auto& attr : *other.attrs_) {
200 auto iter = attrs_->find(attr.first);
201 if (iter == attrs_->end()) return false;
202 // TODO(irving): Comparing AttrValues by proto is slightly buggy, since
203 // TensorProto is a nonunique representation of Tensor. This bug will go
204 // away once AttrSlice switches over to NodeInfo.
205 iter->second.SerializeToString(&scratch->a);
206 attr.second.SerializeToString(&scratch->b);
207 if (scratch->a != scratch->b) return false;
208 }
209 return true;
210 }
211
// Defines two GetNodeAttr() overloads for TYPE:
//  * GetNodeAttr(attrs, attr_name, TYPE* value) looks the attr up, requires
//    its attr type to be ATTR_TYPE, binds `v` to attr_value->FIELD(), runs the
//    injected code, then stores CAST into *value.
//  * GetNodeAttr(attrs, attr_name, std::vector<TYPE>* value) requires attr
//    type "list(" ATTR_TYPE ")"; for each element `v` of
//    attr_value->list().FIELD() it runs the injected code and appends CAST via
//    value->APPEND_OP(...). Existing elements of *value are kept, and if the
//    injected code returns early *value may hold a partial result.
// Both return the error from AttrSlice::Find() if the attr is missing, and
// the error from AttrValueHasType() if its type does not match.
//
// The ... is to allow the caller to inject some value validation code. Use
// just ; if no additional validation code is needed. Inside it, `v` is the
// current value and `attr_name` is the requested attr name.
#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...)         \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,           \
                     TYPE* value) {                                           \
    const AttrValue* attr_value;                                              \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
    TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, ATTR_TYPE));             \
    const auto& v = attr_value->FIELD();                                      \
    __VA_ARGS__;                                                              \
    *value = CAST;                                                            \
    return Status::OK();                                                      \
  }                                                                           \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,           \
                     std::vector<TYPE>* value) {                              \
    const AttrValue* attr_value;                                              \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
    TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")")); \
    value->reserve(attr_value->list().FIELD().size());                        \
    for (const auto& v : attr_value->list().FIELD()) {                        \
      __VA_ARGS__;                                                            \
      value->APPEND_OP(CAST);                                                 \
    }                                                                         \
    return Status::OK();                                                      \
  }
237
// Defines two TryGetNodeAttr() overloads for TYPE that mirror DEFINE_GET_ATTR
// but report failure as `false` instead of a Status. False is returned when
// the attr is missing or its attr type is not ATTR_TYPE (or
// "list(" ATTR_TYPE ")" for the vector overload); the injected code (...) may
// also `return false;`. As with DEFINE_GET_ATTR, the vector overload appends
// to *value and may leave it partially filled if the injected code returns
// part-way through the list. Inside the injected code, `v` is the current
// value and `attr_name` is the requested attr name.
#define DEFINE_TRY_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
  bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,      \
                      TYPE* value) {                                      \
    const AttrValue* attr_value = attrs.Find(attr_name);                  \
    if (attr_value == nullptr) {                                          \
      return false;                                                       \
    }                                                                     \
    Status s = AttrValueHasType(*attr_value, ATTR_TYPE);                  \
    if (!s.ok()) {                                                        \
      return false;                                                       \
    }                                                                     \
    const auto& v = attr_value->FIELD();                                  \
    __VA_ARGS__;                                                          \
    *value = CAST;                                                        \
    return true;                                                          \
  }                                                                       \
  bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,      \
                      std::vector<TYPE>* value) {                         \
    const AttrValue* attr_value = attrs.Find(attr_name);                  \
    if (attr_value == nullptr) {                                          \
      return false;                                                       \
    }                                                                     \
    Status s = AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")");      \
    if (!s.ok()) {                                                        \
      return false;                                                       \
    }                                                                     \
    value->reserve(attr_value->list().FIELD().size());                    \
    for (const auto& v : attr_value->list().FIELD()) {                    \
      __VA_ARGS__;                                                        \
      value->APPEND_OP(CAST);                                             \
    }                                                                     \
    return true;                                                          \
  }
271 DEFINE_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
272 DEFINE_TRY_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
273 DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
274 DEFINE_TRY_GET_ATTR(string, s, "string", emplace_back, v, ;)
275 DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;)
276 DEFINE_TRY_GET_ATTR(int64, i, "int", emplace_back, v, ;)
277 DEFINE_GET_ATTR(
278 int32, i, "int", emplace_back, static_cast<int32>(v),
279 if (static_cast<int64>(static_cast<int32>(v)) != v) {
280 return errors::InvalidArgument("Attr ", attr_name, " has value ", v,
281 " out of range for an int32");
282 })
283 DEFINE_TRY_GET_ATTR(
284 int32, i, "int", emplace_back, static_cast<int32>(v),
285 if (static_cast<int64>(static_cast<int32>(v)) != v) {
286 static int log_counter = 0;
287 if (log_counter < 10) {
288 log_counter++;
289 LOG(WARNING) << "Attr " << attr_name << " has value " << v
290 << " out of range for an int32";
291 }
292 return false;
293 })
294 DEFINE_GET_ATTR(float, f, "float", emplace_back, v, ;)
295 DEFINE_TRY_GET_ATTR(float, f, "float", emplace_back, v, ;)
296 // std::vector<bool> specialization does not have emplace_back until
297 // c++14, so we have to use push_back (see
298 // http://en.cppreference.com/w/cpp/container/vector/emplace_back)
299 DEFINE_GET_ATTR(bool, b, "bool", push_back, v, ;)
300 DEFINE_TRY_GET_ATTR(bool, b, "bool", push_back, v, ;)
301 DEFINE_GET_ATTR(DataType, type, "type", emplace_back, static_cast<DataType>(v),
302 ;)
303 DEFINE_TRY_GET_ATTR(DataType, type, "type", emplace_back,
304 static_cast<DataType>(v),
305 ;)
306 DEFINE_GET_ATTR(TensorShapeProto, shape, "shape", emplace_back, v, ;)
307 DEFINE_GET_ATTR(TensorShape, shape, "shape", emplace_back, TensorShape(v),
308 TF_RETURN_IF_ERROR(TensorShape::IsValidShape(v));)
309 DEFINE_TRY_GET_ATTR(
310 TensorShape, shape, "shape", emplace_back, TensorShape(v),
311 if (!TensorShape::IsValidShape(v).ok()) {
312 static int log_counter = 0;
313 if (log_counter < 10) {
314 log_counter++;
315 LOG(WARNING) << "Attr " << attr_name << " has invalid shape value "
316 << v.DebugString();
317 }
318 return false;
319 })
320 DEFINE_GET_ATTR(PartialTensorShape, shape, "shape", emplace_back,
321 PartialTensorShape(v),
322 TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(v));)
323 DEFINE_GET_ATTR(
324 Tensor, tensor, "tensor", emplace_back, t, Tensor t; if (!t.FromProto(v)) {
325 return errors::InvalidArgument("Attr ", attr_name, " has value ",
326 v.ShortDebugString(),
327 " that can't be converted to a Tensor");
328 })
329 DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;);
330 #undef DEFINE_GET_ATTR
331
HasNodeAttr(const NodeDef & node_def,StringPiece attr_name)332 bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) {
333 return node_def.attr().find(string(attr_name)) != node_def.attr().end();
334 }
335
336 static const string& kEmptyString = *new string();
337
GetNodeAttrString(const AttrSlice & attrs,StringPiece attr_name)338 const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name) {
339 const AttrValue* attr_value = attrs.Find(attr_name);
340 if (attr_value == nullptr) {
341 return kEmptyString;
342 }
343 Status s = AttrValueHasType(*attr_value, "string");
344 if (!s.ok()) {
345 return kEmptyString;
346 }
347 return attr_value->s();
348 }
349
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,std::vector<const string * > * value)350 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
351 std::vector<const string*>* value) {
352 const AttrValue* attr_value = attrs.Find(attr_name);
353 if (attr_value == nullptr) {
354 return false;
355 }
356 Status s = AttrValueHasType(*attr_value, "list(string)");
357 if (!s.ok()) {
358 return false;
359 }
360 value->reserve(attr_value->list().s().size());
361 for (const auto& v : attr_value->list().s()) {
362 value->push_back(&v);
363 }
364 return true;
365 }
366
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,std::vector<const TensorShapeProto * > * value)367 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
368 std::vector<const TensorShapeProto*>* value) {
369 const AttrValue* attr_value = attrs.Find(attr_name);
370 if (attr_value == nullptr) {
371 return false;
372 }
373 Status s = AttrValueHasType(*attr_value, "list(shape)");
374 if (!s.ok()) {
375 return false;
376 }
377 value->reserve(attr_value->list().shape().size());
378 for (const auto& v : attr_value->list().shape()) {
379 value->push_back(&v);
380 }
381 return true;
382 }
383
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,DataTypeVector * value)384 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
385 DataTypeVector* value) {
386 const AttrValue* attr_value;
387 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
388 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(type)"));
389 for (const auto& v : attr_value->list().type()) {
390 value->push_back(static_cast<DataType>(v));
391 }
392 return Status::OK();
393 }
394
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const TensorProto ** value)395 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
396 const TensorProto** value) {
397 const AttrValue* attr_value;
398 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
399 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor"));
400 *value = &attr_value->tensor();
401 return Status::OK();
402 }
403
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const TensorProto ** value)404 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
405 const TensorProto** value) {
406 const AttrValue* attr_value = attrs.Find(attr_name);
407 if (attr_value == nullptr) {
408 return false;
409 }
410 Status s = AttrValueHasType(*attr_value, "tensor");
411 if (!s.ok()) {
412 return false;
413 }
414 *value = &attr_value->tensor();
415 return true;
416 }
417
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const NameAttrList ** value)418 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
419 const NameAttrList** value) {
420 const AttrValue* attr_value;
421 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
422 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func"));
423 *value = &attr_value->func();
424 return Status::OK();
425 }
426
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const NameAttrList ** value)427 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
428 const NameAttrList** value) {
429 const AttrValue* attr_value = attrs.Find(attr_name);
430 if (attr_value == nullptr) {
431 return false;
432 }
433 Status s = AttrValueHasType(*attr_value, "func");
434 if (!s.ok()) {
435 return false;
436 }
437 *value = &attr_value->func();
438 return true;
439 }
440
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,Padding * value)441 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
442 Padding* value) {
443 string str_value;
444 TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_name, &str_value));
445 return GetPaddingFromString(str_value, value);
446 }
447
448 namespace { // Helper for InOutTypesForNode().
449
450 template <class NodeDefOrAttrSlice>
AddArgToSig(const NodeDefOrAttrSlice & node_or_attrs,const OpDef::ArgDef & arg_def,DataTypeVector * sig)451 Status AddArgToSig(const NodeDefOrAttrSlice& node_or_attrs,
452 const OpDef::ArgDef& arg_def, DataTypeVector* sig) {
453 const int original_size = sig->size();
454 if (!arg_def.number_attr().empty()) {
455 // Same type repeated "repeats" times.
456 int64_t repeats = -1;
457 TF_RETURN_IF_ERROR(
458 GetNodeAttr(node_or_attrs, arg_def.number_attr(), &repeats));
459 // We can't handle outputs that are larger than int32 sizes.
460 if (static_cast<int64>(static_cast<int32>(repeats)) != repeats) {
461 return errors::InvalidArgument("Number of outputs is too big: ", repeats);
462 }
463 if (repeats < 0) {
464 return errors::InvalidArgument("Value for number_attr() ", repeats,
465 " < 0");
466 }
467
468 if (!arg_def.type_attr().empty()) {
469 DataType dtype;
470 TF_RETURN_IF_ERROR(
471 GetNodeAttr(node_or_attrs, arg_def.type_attr(), &dtype));
472 for (int i = 0; i < repeats; ++i) {
473 sig->push_back(dtype);
474 }
475 } else if (arg_def.type() != DT_INVALID) {
476 for (int i = 0; i < repeats; ++i) {
477 sig->push_back(arg_def.type());
478 }
479 } else {
480 return errors::InvalidArgument("Missing type or type_attr field in ",
481 arg_def.ShortDebugString());
482 }
483 } else if (!arg_def.type_attr().empty()) {
484 const AttrValue* attr_value;
485 TF_RETURN_IF_ERROR(
486 AttrSlice(node_or_attrs).Find(arg_def.type_attr(), &attr_value));
487 sig->push_back(attr_value->type());
488 } else if (!arg_def.type_list_attr().empty()) {
489 const AttrValue* attr_value;
490 TF_RETURN_IF_ERROR(
491 AttrSlice(node_or_attrs).Find(arg_def.type_list_attr(), &attr_value));
492 for (int dtype : attr_value->list().type()) {
493 sig->push_back(static_cast<DataType>(dtype));
494 }
495 } else if (arg_def.type() != DT_INVALID) {
496 sig->push_back(arg_def.type());
497 } else {
498 return errors::InvalidArgument("No type fields in ",
499 arg_def.ShortDebugString());
500 }
501 if (arg_def.is_ref()) {
502 // For all types that were added by this function call, make them refs.
503 for (size_t i = original_size; i < sig->size(); ++i) {
504 if (IsRefType((*sig)[i])) {
505 return errors::InvalidArgument(
506 "Requested reference to a reference type: ",
507 arg_def.ShortDebugString());
508 }
509 (*sig)[i] = MakeRefType((*sig)[i]);
510 }
511 }
512 return Status::OK();
513 }
514
515 } // namespace
516
InputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int input_port,DataType * input_type)517 Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
518 int input_port, DataType* input_type) {
519 DataTypeVector input_types;
520 for (const auto& arg : op_def.input_arg()) {
521 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &input_types));
522 int input_types_size = input_types.size();
523 if (input_types_size > input_port) {
524 const DataType dtype = input_types[input_port];
525 *input_type = dtype;
526 return Status::OK();
527 }
528 }
529 return errors::InvalidArgument("Input ", input_port, " not found for node ",
530 node_def.name());
531 }
532
InputTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * inputs)533 Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
534 DataTypeVector* inputs) {
535 for (const auto& arg : op_def.input_arg()) {
536 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
537 }
538 return Status::OK();
539 }
540
OutputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int output_port,DataType * output_type)541 Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
542 int output_port, DataType* output_type) {
543 DataTypeVector output_types;
544 for (const auto& arg : op_def.output_arg()) {
545 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &output_types));
546 int output_types_size = output_types.size();
547 if (output_types_size > output_port) {
548 const DataType dtype = output_types[output_port];
549 *output_type = dtype;
550 return Status::OK();
551 }
552 }
553 return errors::InvalidArgument("Output ", output_port, " not found for node ",
554 node_def.name());
555 }
556
OutputTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * outputs)557 Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
558 DataTypeVector* outputs) {
559 for (const auto& arg : op_def.output_arg()) {
560 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs));
561 }
562 return Status::OK();
563 }
564
OutputTypesForNode(const AttrSlice & attrs,const OpDef & op_def,DataTypeVector * outputs)565 Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def,
566 DataTypeVector* outputs) {
567 for (const auto& arg : op_def.output_arg()) {
568 TF_RETURN_IF_ERROR(AddArgToSig(attrs, arg, outputs));
569 }
570 return Status::OK();
571 }
572
InOutTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * inputs,DataTypeVector * outputs)573 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
574 DataTypeVector* inputs, DataTypeVector* outputs) {
575 TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs));
576 return OutputTypesForNode(node_def, op_def, outputs);
577 }
578
NumOutputsForNode(const NodeDef & node_def,const OpDef & op_def,int * num_outputs)579 Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
580 int* num_outputs) {
581 DataTypeVector outputs;
582 TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, op_def, &outputs));
583 *num_outputs = outputs.size();
584 return Status::OK();
585 }
586
ValidateNodeDef(const NodeDef & node_def,const OpDef & op_def)587 Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
588 if (node_def.op() != op_def.name()) {
589 return errors::InvalidArgument(
590 "NodeDef op '", node_def.op(), "' does not match ",
591 SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def));
592 }
593
594 bool seen_control = false;
595 size_t num_inputs = 0;
596 // TODO(josh11b): Unify the input field validation.
597 for (const string& input : node_def.input()) {
598 if (absl::StartsWith(input, "^")) {
599 seen_control = true;
600 if (input.find(':') != string::npos) {
601 return errors::InvalidArgument("Control input '", input,
602 "' must not have ':' in NodeDef: ",
603 FormatNodeDefForError(node_def));
604 }
605 } else if (seen_control) {
606 return errors::InvalidArgument("Non-control input '", input,
607 "' after control input in NodeDef: ",
608 FormatNodeDefForError(node_def));
609 } else {
610 ++num_inputs;
611 }
612 }
613
614 std::unordered_map<string, const OpDef::AttrDef*> op_attrs;
615 for (const auto& attr : op_def.attr()) {
616 if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) {
617 return errors::InvalidArgument("OpDef has duplicate attr name '",
618 attr.name(),
619 "': ", SummarizeOpDef(op_def));
620 }
621 }
622 for (const auto& attr : node_def.attr()) {
623 // Allow internal optional attributes with names starting with "_".
624 if (absl::StartsWith(attr.first, "_")) {
625 continue;
626 }
627 auto iter = op_attrs.find(attr.first);
628 if (iter == op_attrs.end()) {
629 LOG_EVERY_N_SEC(ERROR, 5)
630 << "NodeDef mentions attribute " << attr.first
631 << " which is not in the op definition: " << SummarizeOpDef(op_def)
632 << " This may be expected if your graph generating binary is newer "
633 << " than this binary. Unknown attributes will be ignored."
634 << " NodeDef: " << FormatNodeDefForError(node_def);
635 continue;
636 }
637
638 // If attr value is placeholder, do not check it.
639 if (attr.second.placeholder().empty()) {
640 TF_RETURN_WITH_CONTEXT_IF_ERROR(
641 ValidateAttrValue(attr.second, *iter->second),
642 "; NodeDef: ", FormatNodeDefForError(node_def), "; ",
643 SummarizeOpDef(op_def));
644 }
645
646 // Keep track of which attr names have (not) been found in the NodeDef.
647 op_attrs.erase(iter);
648 }
649
650 // Were all attrs in the OpDef found in the NodeDef?
651 if (!op_attrs.empty()) {
652 string attrs;
653 for (const auto& attr_pair : op_attrs) {
654 if (!attrs.empty()) strings::StrAppend(&attrs, "', '");
655 strings::StrAppend(&attrs, attr_pair.first);
656 }
657 return errors::InvalidArgument(
658 "NodeDef missing attr", op_attrs.size() == 1 ? " '" : "s '", attrs,
659 "' from ", SummarizeOpDef(op_def),
660 "; NodeDef: ", FormatNodeDefForError(node_def));
661 }
662
663 // Validate the number of inputs.
664 DataTypeVector inputs, outputs;
665 TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, op_def, &inputs, &outputs));
666
667 if (num_inputs != inputs.size()) {
668 return errors::InvalidArgument(
669 "NodeDef expected inputs '", DataTypeVectorString(inputs),
670 "' do not match ", num_inputs, " inputs specified; ",
671 SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def));
672 }
673
674 return Status::OK();
675 }
676
677 namespace { // Helpers for NameRangesForNode()
678
ComputeArgRange(const AttrSlice & attrs,const OpDef::ArgDef & arg_def,const OpDef & op_def,int * num)679 Status ComputeArgRange(const AttrSlice& attrs, const OpDef::ArgDef& arg_def,
680 const OpDef& op_def, int* num) {
681 if (!arg_def.number_attr().empty()) {
682 // Same type repeated "num" times.
683 return GetNodeAttr(attrs, arg_def.number_attr(), num);
684 } else if (!arg_def.type_list_attr().empty()) {
685 const AttrValue* attr_value;
686 TF_RETURN_IF_ERROR(attrs.Find(arg_def.type_list_attr(), &attr_value));
687 *num = attr_value->list().type_size();
688 } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) {
689 *num = 1;
690 } else {
691 return errors::InvalidArgument(
692 "Argument '", arg_def.name(),
693 "' incorrectly specified in op definition: ", SummarizeOpDef(op_def));
694 }
695 return Status::OK();
696 }
697
NameRangesHelper(const AttrSlice & attrs,const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const OpDef & op_def,NameRangeMap * result)698 Status NameRangesHelper(const AttrSlice& attrs,
699 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
700 const OpDef& op_def, NameRangeMap* result) {
701 int start = 0;
702 int num;
703 for (const auto& arg : args) {
704 TF_RETURN_IF_ERROR(ComputeArgRange(attrs, arg, op_def, &num));
705 (*result)[arg.name()] = std::make_pair(start, start + num);
706 start += num;
707 }
708 return Status::OK();
709 }
710
711 } // namespace
712
NameRangesForNode(const AttrSlice & attrs,const OpDef & op_def,NameRangeMap * inputs,NameRangeMap * outputs)713 Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def,
714 NameRangeMap* inputs, NameRangeMap* outputs) {
715 if (inputs != nullptr) {
716 TF_RETURN_IF_ERROR(
717 NameRangesHelper(attrs, op_def.input_arg(), op_def, inputs));
718 }
719 if (outputs != nullptr) {
720 return NameRangesHelper(attrs, op_def.output_arg(), op_def, outputs);
721 }
722 return Status::OK();
723 }
724
AddDefaultsToNodeDef(const OpDef & op_def,NodeDef * node_def)725 void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
726 for (const auto& attr_def : op_def.attr()) {
727 AttrSlice attrs(*node_def);
728 if (attr_def.has_default_value() && !attrs.Find(attr_def.name())) {
729 AddNodeAttr(attr_def.name(), attr_def.default_value(), node_def);
730 }
731 }
732 }
733
namespace {

using ::tensorflow::tstring;
using ::tensorflow::strings::Scanner;

// Returns true iff `sp` is a valid node name: one or more components joined
// by '>', where each component is one LETTER_DIGIT_DOT character followed by
// any number of LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE characters.
bool IsValidNodeName(StringPiece sp) {
  Scanner scanner(sp);
  scanner.One(Scanner::LETTER_DIGIT_DOT)
      .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);

  while (true) {
    if (!scanner.GetResult())  // Some error in previous iteration.
      return false;
    if (scanner.empty())  // No error, but nothing left, good.
      return true;

    // Absorb another name/namespace, starting with a '>'
    scanner.One(Scanner::RANGLE)
        .One(Scanner::LETTER_DIGIT_DOT)
        .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
  }
}

// Returns true iff `sp` is a valid data input name: a node name (same grammar
// as IsValidNodeName) optionally followed by ":<port>", where <port> is "0"
// or a digit string that does not start with '0'.
// NOTE(review): because the loop keeps consuming after a port, strings such
// as "a:1:2" and "a:1>b" are also accepted; confirm whether that is intended.
bool IsValidDataInputName(StringPiece sp) {
  // Data inputs are op_name, op_name:0, or op_name:12345.
  Scanner scan(sp);
  scan.One(Scanner::LETTER_DIGIT_DOT)
      .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);

  while (true) {
    if (!scan.GetResult())  // Some error in previous iteration.
      return false;
    if (scan.empty())  // No error, but nothing left, good.
      return true;

    if (scan.Peek() == ':') {  // Absorb identifier after the colon
      scan.OneLiteral(":");
      if (scan.Peek() == '0') {
        scan.OneLiteral("0");  // :0
      } else {
        scan.Many(Scanner::DIGIT);  // :[1-9][0-9]*
      }
    } else {
      // Absorb another name/namespace, starting with a '>'
      scan.One(Scanner::RANGLE)
          .One(Scanner::LETTER_DIGIT_DOT)
          .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
    }
  }
}

// Returns true iff `sp` is '^' followed by a valid node name (same grammar as
// IsValidNodeName).
bool IsValidControlInputName(StringPiece sp) {
  Scanner scan(sp);
  scan.OneLiteral("^")
      .One(Scanner::LETTER_DIGIT_DOT)
      .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);

  while (true) {
    if (!scan.GetResult())  // Some error in previous iteration.
      return false;
    if (scan.empty())  // No error, but nothing left, good.
      return true;

    // Absorb another name/namespace, starting with a '>'
    scan.One(Scanner::RANGLE)
        .One(Scanner::LETTER_DIGIT_DOT)
        .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
  }
}

// StringPiece view of kColocationGroupPrefix.
const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);

}  // namespace
807
ValidateOpInput(const string & input_name,bool * is_control_input)808 Status ValidateOpInput(const string& input_name, bool* is_control_input) {
809 *is_control_input = false;
810 if (IsValidDataInputName(input_name)) {
811 return Status::OK();
812 } else if (IsValidControlInputName(input_name)) {
813 *is_control_input = true;
814 return Status::OK();
815 } else {
816 return errors::InvalidArgument("Illegal op input name '", input_name, "'");
817 }
818 }
819
ValidateNodeName(const string & node_name)820 Status ValidateNodeName(const string& node_name) {
821 if (IsValidNodeName(node_name)) {
822 return Status::OK();
823 } else {
824 return errors::InvalidArgument("Illegal op name '", node_name, "'");
825 }
826 }
827
ValidateExternalNodeDefSyntax(const NodeDef & node_def)828 Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) {
829 Status s = ValidateNodeName(node_def.name());
830 if (!s.ok()) {
831 return AttachDef(s, node_def);
832 }
833 bool in_control_inputs = false;
834 for (const string& input_name : node_def.input()) {
835 bool is_control_input;
836 s = ValidateOpInput(input_name, &is_control_input);
837 if (!s.ok()) {
838 return AttachDef(s, node_def);
839 }
840
841 if (in_control_inputs && !is_control_input) {
842 return AttachDef(errors::InvalidArgument(
843 "All control inputs must follow all data inputs"),
844 node_def);
845 }
846 in_control_inputs = is_control_input;
847 }
848 return Status::OK();
849 }
850
// Returns a copy of `status` with " [[<node>]]" appended to its message.
//
// Normally <node> is FormatNodeDefForError(node_def). If the message already
// contains a "{{node " marker and `allow_multiple_formatted_node` is false,
// only the bare node name is appended, so the message does not carry a second
// formatted node reference.
Status AttachDef(const Status& status, const NodeDef& node_def,
                 bool allow_multiple_formatted_node) {
  const bool use_plain_name =
      !allow_multiple_formatted_node &&
      status.error_message().find("{{node ") != string::npos;
  const string node_error =
      use_plain_name ? node_def.name() : FormatNodeDefForError(node_def);

  Status annotated = status;
  errors::AppendToMessage(&annotated,
                          strings::StrCat(" [[", node_error, "]]"));
  return annotated;
}
864
AddNodeAttr(StringPiece name,const AttrValue & value,NodeDef * node_def)865 void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) {
866 node_def->mutable_attr()->insert(
867 AttrValueMap::value_type(string(name), value));
868 }
869
AddNodeAttr(StringPiece name,AttrValue && value,NodeDef * node_def)870 void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def) {
871 (*node_def->mutable_attr())[string(name)] = std::move(value);
872 }
873
// Stamps out one AddNodeAttr overload per supported C++ value type T. Each
// generated overload converts `value` into a local AttrValue with
// SetAttrValue. It then passes that lvalue to AddNodeAttr, which selects the
// `const AttrValue&` overload (insert semantics, so an existing attribute is
// not replaced).
// The instantiation list below is the exact exported overload set; type
// spellings (e.g. int32_t vs. int32 in ArraySlice) are deliberate.
#define ADD_NODE_ATTR(T)                                           \
  void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \
    AttrValue attr_value;                                          \
    SetAttrValue(value, &attr_value);                              \
    AddNodeAttr(name, attr_value, node_def);                       \
  }
ADD_NODE_ATTR(StringPiece)
ADD_NODE_ATTR(const char*)
ADD_NODE_ATTR(int32_t)
ADD_NODE_ATTR(int64_t)
ADD_NODE_ATTR(float)
ADD_NODE_ATTR(double)
ADD_NODE_ATTR(bool)
ADD_NODE_ATTR(DataType)
ADD_NODE_ATTR(const PartialTensorShape&)
ADD_NODE_ATTR(const Tensor&)
ADD_NODE_ATTR(const TensorProto&)
ADD_NODE_ATTR(const NameAttrList&)
ADD_NODE_ATTR(gtl::ArraySlice<StringPiece>)
ADD_NODE_ATTR(gtl::ArraySlice<const char*>)
ADD_NODE_ATTR(gtl::ArraySlice<string>)
ADD_NODE_ATTR(gtl::ArraySlice<int32>)
ADD_NODE_ATTR(gtl::ArraySlice<int64>)
ADD_NODE_ATTR(gtl::ArraySlice<float>)
ADD_NODE_ATTR(gtl::ArraySlice<bool>)
ADD_NODE_ATTR(const std::vector<bool>&)
ADD_NODE_ATTR(gtl::ArraySlice<DataType>)
ADD_NODE_ATTR(gtl::ArraySlice<TensorShape>)
ADD_NODE_ATTR(gtl::ArraySlice<PartialTensorShape>)
ADD_NODE_ATTR(gtl::ArraySlice<TensorShapeProto>)
ADD_NODE_ATTR(gtl::ArraySlice<Tensor>)
ADD_NODE_ATTR(gtl::ArraySlice<NameAttrList>)
#undef ADD_NODE_ATTR
907
908 void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) {
909 map->insert(AttrValueMap::value_type(string(name), value));
910 }
911
// Stamps out AddAttr overloads for plain C++ values. Each converts `value`
// with SetAttrValue and then passes the resulting AttrValue lvalue to
// AddAttr, which selects the `const AttrValue&` overload (insert semantics).
// Only `bool` is instantiated here.
#define ADD_ATTR(T)                                            \
  void AddAttr(StringPiece name, T value, AttrValueMap* map) { \
    AttrValue attr_value;                                      \
    SetAttrValue(value, &attr_value);                          \
    AddAttr(name, attr_value, map);                            \
  }
ADD_ATTR(bool)
#undef ADD_ATTR
920
921 Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
922 NodeDef* node_def, bool uniquify_frame_name) {
923 node_def->set_name(strings::StrCat(prefix, node_def->name(), suffix));
924
925 // Update frame name to avoid multiple LoopCond nodes in one frame.
926 if (uniquify_frame_name &&
927 (node_def->op() == "Enter" || node_def->op() == "RefEnter")) {
928 string frame_name;
929 TF_RETURN_IF_ERROR(GetNodeAttr(*node_def, "frame_name", &frame_name));
930 AttrValue& attr = (*node_def->mutable_attr())["frame_name"];
931 frame_name = strings::StrCat(prefix, frame_name, suffix);
932 attr.set_s(frame_name);
933 }
934
935 return Status::OK();
936 }
937
MaybeAddPrefixToColocationConstraints(const std::unordered_set<string> & match,StringPiece prefix,NodeDef * node_def)938 Status MaybeAddPrefixToColocationConstraints(
939 const std::unordered_set<string>& match, StringPiece prefix,
940 NodeDef* node_def) {
941 auto attr = node_def->mutable_attr()->find(kColocationAttrName);
942 if (attr == node_def->mutable_attr()->end()) {
943 return Status::OK();
944 }
945 auto constraints_list = attr->second.mutable_list();
946 auto constraints_size = constraints_list->s_size();
947 for (size_t i = 0; i < constraints_size; ++i) {
948 StringPiece original(constraints_list->s(i));
949 if (absl::ConsumePrefix(&original, kColocationGroupPrefixStringPiece)) {
950 if (match.find(string(original)) != match.end()) {
951 (*constraints_list->mutable_s(i)) =
952 strings::StrCat(kColocationGroupPrefix, prefix, original);
953 }
954 }
955 }
956 return Status::OK();
957 }
958
MaybeUpdateColocationConstraintsWithMap(const std::map<absl::string_view,absl::string_view> & node_name_map,NodeDef * node_def)959 Status MaybeUpdateColocationConstraintsWithMap(
960 const std::map<absl::string_view, absl::string_view>& node_name_map,
961 NodeDef* node_def) {
962 auto attr = node_def->mutable_attr()->find(kColocationAttrName);
963 if (attr == node_def->mutable_attr()->end()) {
964 return Status::OK();
965 }
966 auto constraints_list = attr->second.mutable_list();
967 auto constraints_size = constraints_list->s_size();
968 for (size_t i = 0; i < constraints_size; ++i) {
969 StringPiece original(constraints_list->s(i));
970 if (absl::ConsumePrefix(&original, kColocationGroupPrefixStringPiece)) {
971 if (node_name_map.find(original) != node_name_map.end()) {
972 (*constraints_list->mutable_s(i)) =
973 strings::StrCat(kColocationGroupPrefix, node_name_map.at(original));
974 }
975 }
976 }
977 return Status::OK();
978 }
979
980 } // namespace tensorflow
981