1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/node_def_util.h"
17
18 #include "tensorflow/core/framework/attr_value.pb.h" // NOLINT
19 #include "tensorflow/core/framework/fake_input.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/framework/op_def_builder.h"
22 #include "tensorflow/core/framework/op_def_util.h"
23 #include "tensorflow/core/graph/graph.h"
24 #include "tensorflow/core/graph/node_builder.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/platform/protobuf.h"
29 #include "tensorflow/core/platform/test.h"
30
31 namespace tensorflow {
32 namespace {
33
ToOpDef(const OpDefBuilder & builder)34 OpDef ToOpDef(const OpDefBuilder& builder) {
35 OpRegistrationData op_reg_data;
36 TF_EXPECT_OK(builder.Finalize(&op_reg_data));
37 return op_reg_data.op_def;
38 }
39
ToNodeDef(const string & text)40 NodeDef ToNodeDef(const string& text) {
41 NodeDef node_def;
42 EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
43 return node_def;
44 }
45
ToNodeDef(NodeDefBuilder && builder)46 NodeDef ToNodeDef(NodeDefBuilder&& builder) {
47 NodeDef node_def;
48 TF_EXPECT_OK(builder.Finalize(&node_def));
49 return node_def;
50 }
51
ExpectSuccess(const NodeDef & good,const OpDef & op_def)52 void ExpectSuccess(const NodeDef& good, const OpDef& op_def) {
53 EXPECT_EQ(OkStatus(), ValidateNodeDef(good, op_def))
54 << "NodeDef: " << SummarizeNodeDef(good)
55 << "; OpDef: " << SummarizeOpDef(op_def);
56 }
57
ExpectFailure(const NodeDef & bad,const OpDef & op_def,const string & message)58 void ExpectFailure(const NodeDef& bad, const OpDef& op_def,
59 const string& message) {
60 Status status = ValidateNodeDef(bad, op_def);
61
62 EXPECT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad)
63 << "; OpDef: " << SummarizeOpDef(op_def);
64 if (status.ok()) return;
65
66 EXPECT_TRUE(errors::IsInvalidArgument(status))
67 << status << "; NodeDef: " << SummarizeNodeDef(bad)
68 << "; OpDef: " << SummarizeOpDef(op_def);
69
70 LOG(INFO) << "Message: " << status.error_message();
71 EXPECT_TRUE(absl::StrContains(status.ToString(), message))
72 << "NodeDef: " << SummarizeNodeDef(bad)
73 << "; OpDef: " << SummarizeOpDef(op_def) << "\nActual error: " << status
74 << "\nDoes not contain: " << message;
75 }
76
// Validates a single-input op whose input type comes from attr `T`, then
// checks the error reported for each kind of malformed NodeDef.
TEST(NodeDefUtilTest, In) {
  const OpDef op = ToOpDef(OpDefBuilder("In").Input("i: T").Attr("T: type"));
  const NodeDef node_def = ToNodeDef(R"pb(
    name: 'n'
    op: 'In'
    input: 'a'
    attr {
      key: 'T'
      value { type: DT_FLOAT }
    }
  )pb");
  ExpectSuccess(node_def, op);

  EXPECT_EQ("{{node n}} = In[T=DT_FLOAT](a)", SummarizeNodeDef(node_def));

  // Mismatching Op names.
  NodeDef bad = node_def;
  bad.set_op("Wrong");
  ExpectFailure(bad, op, "NodeDef op 'Wrong' does not match Op<name=In;");

  // Missing attr.
  bad = node_def;
  bad.clear_attr();
  ExpectFailure(bad, op, "NodeDef missing attr 'T' from Op<name=In;");

  // Attr has the wrong type (an int where a type is expected).
  bad = node_def;
  bad.clear_attr();
  AddNodeAttr("T", 17, &bad);
  ExpectFailure(
      bad, op,
      "AttrValue had value with type 'int' when 'type' expected\n\t for attr "
      "'T'\n\t; NodeDef: ");

  // Wrong number of inputs: one too many...
  bad = node_def;
  bad.add_input("b");
  ExpectFailure(
      bad, op,
      "NodeDef expected inputs 'float' do not match 2 inputs specified;");

  // ...and none at all.
  bad = node_def;
  bad.clear_input();
  ExpectFailure(
      bad, op,
      "NodeDef expected inputs 'float' do not match 0 inputs specified;");

  // A control input placed after the data inputs is accepted. Validate
  // `good` itself; re-validating `node_def` would leave this case untested.
  NodeDef good = node_def;
  good.add_input("^b");
  ExpectSuccess(good, op);

  // Control inputs must appear after data inputs.
  bad = node_def;
  bad.clear_input();
  bad.add_input("^b");
  bad.add_input("a");
  ExpectFailure(bad, op,
                "Non-control input 'a' after control input "
                "in NodeDef:");

  // Control inputs must not carry an output index.
  bad = node_def;
  bad.add_input("^b:0");
  ExpectFailure(bad, op, "Control input '^b:0' must not have ':' in NodeDef:");
}
141
// An output typed by a `numbertype` attr accepts numeric types only.
TEST(NodeDefUtilTest, Out) {
  const OpDef op_def =
      ToOpDef(OpDefBuilder("Out").Output("o: T").Attr("T: numbertype"));

  NodeDef node_def;
  node_def.set_name("n");
  node_def.set_op("Out");
  AddNodeAttr("T", DT_INT32, &node_def);
  ExpectSuccess(node_def, op_def);
  EXPECT_EQ("{{node n}} = Out[T=DT_INT32]()", SummarizeNodeDef(node_def));

  // DT_STRING is not one of the allowed `numbertype` values.
  NodeDef string_typed = node_def;
  string_typed.clear_attr();
  AddNodeAttr("T", DT_STRING, &string_typed);
  ExpectFailure(string_typed, op_def,
                "Value for attr 'T' of string is not in the list of allowed "
                "values: float, double, int32, uint8, int16, int8, complex64, "
                "int64, qint8, quint8, qint32, bfloat16, uint16, complex128, "
                "half, uint32, uint64");
}
167
// A string attr restricted to an enumeration accepts only the listed values.
TEST(NodeDefUtilTest, Enum) {
  const OpDef op_def =
      ToOpDef(OpDefBuilder("Enum").Attr("e: {'apple','orange'}"));

  // Returns node 'n' of op Enum with attr `e` set to `value`.
  auto make_node = [](const char* value) {
    NodeDef node;
    node.set_name("n");
    node.set_op("Enum");
    AddNodeAttr("e", value, &node);
    return node;
  };

  const NodeDef apple = make_node("apple");
  ExpectSuccess(apple, op_def);
  EXPECT_EQ("{{node n}} = Enum[e=\"apple\"]()", SummarizeNodeDef(apple));

  // The other enumerated value is accepted as well.
  ExpectSuccess(make_node("orange"), op_def);

  // Anything outside the enumeration is rejected.
  ExpectFailure(make_node("foo"), op_def,
                "Value for attr 'e' of \"foo\" is not in the list of allowed "
                "values: \"apple\", \"orange\"");
}
195
// An `N * T` input list checks both the allowed types of T and N's minimum.
TEST(NodeDefUtilTest, SameIn) {
  const OpDef op_def = ToOpDef(OpDefBuilder("SameIn")
                                   .Input("i: N * T")
                                   .Attr("N: int >= 2")
                                   .Attr("T: {float,double}"));

  // Returns node 'n' of op SameIn with inputs a, b and the given N and T.
  auto make_node = [](int64_t n, DataType type) {
    NodeDef node;
    node.set_name("n");
    node.set_op("SameIn");
    node.add_input("a");
    node.add_input("b");
    AddNodeAttr("N", n, &node);
    AddNodeAttr("T", type, &node);
    return node;
  };

  const NodeDef valid = make_node(2, DT_DOUBLE);
  ExpectSuccess(valid, op_def);
  EXPECT_EQ("{{node n}} = SameIn[N=2, T=DT_DOUBLE](a, b)",
            SummarizeNodeDef(valid));

  // T outside {float, double}.
  ExpectFailure(make_node(2, DT_STRING), op_def,
                "Value for attr 'T' of string is not in the list of allowed "
                "values: float, double");

  // N below its declared minimum of 2 (the node still lists two inputs).
  ExpectFailure(make_node(1, DT_FLOAT), op_def,
                "Value for attr 'N' of 1 must be at least minimum 2");
}
256
// A `list(type) >= 1` input accepts any non-empty list of types.
TEST(NodeDefUtilTest, AnyIn) {
  const OpDef op_def =
      ToOpDef(OpDefBuilder("AnyIn").Input("i: T").Attr("T: list(type) >= 1"));

  // Parses an AnyIn node: `inputs` holds its text-format input fields and
  // `t_value` is the text-format AttrValue for attr `T`.
  auto make_node = [](const string& inputs, const string& t_value) {
    return ToNodeDef(strings::StrCat("name: 'n' op: 'AnyIn' ", inputs,
                                     " attr { key: 'T' value ", t_value,
                                     " }"));
  };

  const NodeDef two_inputs = make_node(
      "input: 'a' input: 'b'", "{ list { type: [ DT_INT32, DT_STRING ] } }");
  ExpectSuccess(two_inputs, op_def);
  EXPECT_EQ("{{node n}} = AnyIn[T=[DT_INT32, DT_STRING]](a, b)",
            SummarizeNodeDef(two_inputs));

  // An explicitly empty type list is below the minimum length of 1.
  ExpectFailure(make_node("input: 'a'", "{ list {} }"), op_def,
                "Length for attr 'T' of 0 must be at least minimum 1");

  // Under proto3 semantics an empty `value {}` cannot be told apart from a
  // value holding an empty list, so it produces the same complaint.
  ExpectFailure(make_node("input: 'a'", "{}"), op_def,
                "Length for attr 'T' of 0 must be at least minimum 1");
}
302
// The requested device is summarized as a `_device` entry in the attr list.
TEST(NodeDefUtilTest, Device) {
  // Op without attrs: the device is the only bracketed entry.
  const OpDef none_op = ToOpDef(OpDefBuilder("None"));
  NodeDefBuilder none_builder("d", &none_op);
  none_builder.Device("/cpu:17");
  const NodeDef none_node = ToNodeDef(std::move(none_builder));
  ExpectSuccess(none_node, none_op);
  EXPECT_EQ("{{node d}} = None[_device=\"/cpu:17\"]()",
            SummarizeNodeDef(none_node));

  // Op with an attr: the device follows the regular attrs.
  const OpDef attr_op = ToOpDef(OpDefBuilder("WithAttr").Attr("v: int"));
  NodeDefBuilder attr_builder("d", &attr_op);
  attr_builder.Attr("v", 7).Device("/cpu:5");
  const NodeDef attr_node = ToNodeDef(std::move(attr_builder));
  ExpectSuccess(attr_node, attr_op);
  EXPECT_EQ("{{node d}} = WithAttr[v=7, _device=\"/cpu:5\"]()",
            SummarizeNodeDef(attr_node));
}
318
ExpectValidSyntax(const NodeDef & good)319 void ExpectValidSyntax(const NodeDef& good) {
320 EXPECT_EQ(OkStatus(), ValidateExternalNodeDefSyntax(good))
321 << "NodeDef: " << SummarizeNodeDef(good);
322 }
323
ExpectInvalidSyntax(const NodeDef & bad,const string & message)324 void ExpectInvalidSyntax(const NodeDef& bad, const string& message) {
325 Status status = ValidateExternalNodeDefSyntax(bad);
326
327 ASSERT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad);
328
329 EXPECT_TRUE(errors::IsInvalidArgument(status))
330 << status << "; NodeDef: " << SummarizeNodeDef(bad);
331
332 EXPECT_TRUE(absl::StrContains(StringPiece(status.ToString()), message))
333 << "NodeDef: " << SummarizeNodeDef(bad) << ", " << status << ", "
334 << message;
335 }
336
// Exercises ValidateExternalNodeDefSyntax on well-formed and malformed node
// names and input names.
TEST(NodeDefUtilTest, ValidSyntax) {
  // Text-format attr shared by every AnyIn node in this test.
  static constexpr char kTypeListAttr[] =
      " attr { key: 'T' value { list { type: [ DT_INT32, DT_STRING ] } } }";
  // Parses `header` (name / op / input fields) followed by the shared attr.
  auto parse = [](const char* header) {
    return ToNodeDef(strings::StrCat(header, kTypeListAttr));
  };

  // Plain names, a namespaced op, and a data input followed by a control
  // input are all well formed.
  for (const char* header :
       {"name: 'n' op: 'AnyIn' input: 'a' input: 'b'",
        "name: 'n' op: 'Project>AnyIn' input: 'a' input: 'b'",
        "name: 'n-' op: 'AnyIn' input: 'a' input: '^b'"}) {
    SCOPED_TRACE(header);
    ExpectValidSyntax(parse(header));
  }

  // Explicit output indices on data inputs.
  const NodeDef explicit_ports =
      parse("name: 'n' op: 'AnyIn' input: 'a:0' input: 'b:123'");
  ExpectValidSyntax(explicit_ports);
  EXPECT_EQ("{{node n}} = AnyIn[T=[DT_INT32, DT_STRING]](a:0, b:123)",
            SummarizeNodeDef(explicit_ports));

  // Namespaced node, op and inputs, including a control input.
  const NodeDef namespaced = parse(
      "name: 'Project>n' op: 'Project>AnyIn' input: 'Project>a:0' "
      "input: 'Project>b:123' input: '^Project>c'");
  ExpectValidSyntax(namespaced);
  EXPECT_EQ(
      "{{node Project>n}} = Project>AnyIn[T=[DT_INT32, DT_STRING]]"
      "(Project>a:0, Project>b:123, ^Project>c)",
      SummarizeNodeDef(namespaced));

  // A shape attr with an unknown (-1) dimension.
  ExpectValidSyntax(ToNodeDef(R"pb(
    name: 'n'
    op: 'AnyIn'
    attr {
      key: 'shp'
      value {
        shape {
          dim { size: -1 }
          dim { size: 0 }
        }
      }
    }
  )pb"));

  // Malformed node or input names, each paired with the expected error text.
  struct InvalidCase {
    const char* header;
    const char* expected_error;
  };
  const InvalidCase kInvalidCases[] = {
      {"name: 'n:0' op: 'AnyIn' input: 'a' input: 'b'",
       "Illegal op name 'n:0'"},
      {"name: '_n' op: 'AnyIn' input: 'a' input: 'b'",
       "Illegal op name '_n'"},
      {R"pb(name: 'n\\' op: 'AnyIn' input: 'a' input: 'b')pb",
       "Illegal op name 'n\\'"},
      {"name: 'n' op: 'AnyIn' input: '_a' input: 'b'",
       "Illegal op input name '_a'"},
      {R"pb(name: 'n' op: 'AnyIn' input: 'a\\' input: 'b')pb",
       "Illegal op input name 'a\\'"},
      {"name: 'n' op: 'AnyIn' input: 'a' input: '^b:0'",
       "Illegal op input name '^b:0'"},
      {R"pb(name: 'n' op: 'AnyIn' input: 'a' input: '^b\\')pb",
       "Illegal op input name '^b\\'"},
      {"name: 'n' op: 'AnyIn' input: '^a' input: 'b'",
       "All control inputs must follow all data inputs"},
      {"name: 'n' op: 'AnyIn' input: 'a:b' input: 'b'",
       "Illegal op input name 'a:b"},
      {"name: 'n' op: 'AnyIn' input: 'a:00' input: 'b'",
       "Illegal op input name 'a:00"},
  };
  for (const InvalidCase& invalid : kInvalidCases) {
    SCOPED_TRACE(invalid.header);
    ExpectInvalidSyntax(parse(invalid.header), invalid.expected_error);
  }
}
548
// Checks the per-input types reported for an op with fixed input types.
TEST(InputTypesForNode, Simple) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
                                   .Input("a: float")
                                   .Input("b: int32")
                                   .Output("c: string")
                                   .Output("d: bool"));
  const NodeDef node_def = ToNodeDef(std::move(
      NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())));
  DataTypeVector types;
  // Fatal checks: indexing `types` below is only safe once the call has
  // succeeded and produced one entry per input.
  TF_ASSERT_OK(InputTypesForNode(node_def, op_def, &types));
  ASSERT_EQ(2u, types.size());
  EXPECT_EQ(types[0], DT_FLOAT);
  EXPECT_EQ(types[1], DT_INT32);

  // Initialized so that a failed lookup never leaves it indeterminate.
  DataType type = DT_INVALID;
  TF_EXPECT_OK(InputTypeForNode(node_def, op_def, 0, &type));
  EXPECT_EQ(type, DT_FLOAT);
  TF_EXPECT_OK(InputTypeForNode(node_def, op_def, 1, &type));
  EXPECT_EQ(type, DT_INT32);
  // Index 2 is past the last input.
  EXPECT_FALSE(InputTypeForNode(node_def, op_def, 2, &type).ok());
}
569
// Checks the per-output types reported for an op with fixed output types.
TEST(OutputTypesForNode, Simple) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
                                   .Input("a: float")
                                   .Input("b: int32")
                                   .Output("c: string")
                                   .Output("d: bool"));
  const NodeDef node_def = ToNodeDef(std::move(
      NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())));
  DataTypeVector types;
  // Fatal checks: indexing `types` below is only safe once the call has
  // succeeded and produced one entry per output.
  TF_ASSERT_OK(OutputTypesForNode(node_def, op_def, &types));
  ASSERT_EQ(2u, types.size());
  EXPECT_EQ(types[0], DT_STRING);
  EXPECT_EQ(types[1], DT_BOOL);

  // Initialized so that a failed lookup never leaves it indeterminate.
  DataType type = DT_INVALID;
  TF_EXPECT_OK(OutputTypeForNode(node_def, op_def, 0, &type));
  EXPECT_EQ(type, DT_STRING);
  TF_EXPECT_OK(OutputTypeForNode(node_def, op_def, 1, &type));
  EXPECT_EQ(type, DT_BOOL);
  // Index 2 is past the last output.
  EXPECT_FALSE(OutputTypeForNode(node_def, op_def, 2, &type).ok());
}
590
// An output list sized by an enormous `num_split` is expected to be rejected
// by OutputTypesForNode rather than reported successfully.
TEST(OutputTypesForNode, LargeOutput) {
  const OpDef op_def = ToOpDef(OpDefBuilder("TestSplitOp")
                                   .Input("value: int64")
                                   .Output("output: num_split * int64")
                                   .Attr("num_split: int >= 1"));
  const int64_t huge_split = 1000000000000;
  NodeDefBuilder builder("test_split_op", &op_def);
  builder.Input(FakeInput()).Attr("num_split", huge_split);
  const NodeDef node_def = ToNodeDef(std::move(builder));

  DataTypeVector output_types;
  const Status status = OutputTypesForNode(node_def, op_def, &output_types);
  EXPECT_FALSE(status.ok());
}
604
// Checks the AttrSlice overload of OutputTypesForNode.
TEST(OutputTypesForNode_AttrSliceOverload, Simple) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
                                   .Input("a: float")
                                   .Input("b: int32")
                                   .Output("c: string")
                                   .Output("d: bool"));
  // AttrSlice is a non-owning view that points into the NodeDef it is built
  // from. Keep the NodeDef in a named local: building the slice from a
  // temporary would leave it dangling once that full-expression ends.
  const NodeDef node_def = ToNodeDef(std::move(
      NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())));
  const AttrSlice attr_slice(node_def);
  DataTypeVector types;
  // Fatal checks: indexing `types` below requires a successful call.
  TF_ASSERT_OK(OutputTypesForNode(attr_slice, op_def, &types));
  ASSERT_EQ(2u, types.size());
  EXPECT_EQ(types[0], DT_STRING);
  EXPECT_EQ(types[1], DT_BOOL);
}
620
// Each fixed-type arg occupies exactly one slot in the input/output ranges.
TEST(NameRangesForNodeTest, Simple) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
                                   .Input("a: float")
                                   .Input("b: int32")
                                   .Output("c: string")
                                   .Output("d: bool"));
  NodeDefBuilder builder("simple", &op_def);
  builder.Input(FakeInput()).Input(FakeInput());
  const NodeDef node_def = ToNodeDef(std::move(builder));

  NameRangeMap input_ranges, output_ranges;
  TF_EXPECT_OK(
      NameRangesForNode(node_def, op_def, &input_ranges, &output_ranges));
  const NameRangeMap expected_inputs({{"a", {0, 1}}, {"b", {1, 2}}});
  const NameRangeMap expected_outputs({{"c", {0, 1}}, {"d", {1, 2}}});
  EXPECT_EQ(expected_inputs, input_ranges);
  EXPECT_EQ(expected_outputs, output_ranges);

  EXPECT_EQ("{{node simple}} = Simple[](a, b)", SummarizeNodeDef(node_def));

  // Stripping the type from the first input arg makes the OpDef unusable.
  OpDef untyped_op_def = op_def;
  untyped_op_def.mutable_input_arg(0)->clear_type();
  const Status status =
      NameRangesForNode(node_def, untyped_op_def, &input_ranges, &output_ranges);
  EXPECT_FALSE(status.ok());
}
640
// The ranges are the same whichever type attr T is bound to.
TEST(NameRangesForNodeTest, Polymorphic) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Polymorphic")
                                   .Input("a: T")
                                   .Input("b: T")
                                   .Output("c: T")
                                   .Attr("T: type"));
  const NameRangeMap expected_inputs({{"a", {0, 1}}, {"b", {1, 2}}});
  const NameRangeMap expected_outputs({{"c", {0, 1}}});

  // Each case binds T to a type and gives the expected node summary.
  const struct {
    DataType type;
    const char* summary;
  } kCases[] = {
      {DT_INT32, "{{node poly}} = Polymorphic[T=DT_INT32](a, b)"},
      {DT_BOOL, "{{node poly}} = Polymorphic[T=DT_BOOL](a, b)"},
  };

  NameRangeMap inputs, outputs;
  for (const auto& test_case : kCases) {
    SCOPED_TRACE(test_case.summary);
    const NodeDef node_def =
        ToNodeDef(std::move(NodeDefBuilder("poly", &op_def)
                                .Input(FakeInput(test_case.type))
                                .Input(FakeInput(test_case.type))));
    TF_EXPECT_OK(NameRangesForNode(node_def, op_def, &inputs, &outputs));
    EXPECT_EQ(expected_inputs, inputs);
    EXPECT_EQ(expected_outputs, outputs);
    EXPECT_EQ(test_case.summary, SummarizeNodeDef(node_def));
  }
}
668
// Args repeated N or M times span N or M consecutive slots.
TEST(NameRangesForNodeTest, NRepeats) {
  const OpDef op_def = ToOpDef(OpDefBuilder("NRepeats")
                                   .Input("a: N * int32")
                                   .Input("b: N * T")
                                   .Output("c: T")
                                   .Output("d: N * string")
                                   .Output("e: M * bool")
                                   .Attr("N: int")
                                   .Attr("M: int")
                                   .Attr("T: type"));

  // Builds node "nr" with `n` int32 inputs for `a`, `n` inputs of `b_type`
  // for `b`, and attr M set to `m`.
  const auto make_node = [&op_def](int n, DataType b_type, int m) {
    return ToNodeDef(std::move(NodeDefBuilder("nr", &op_def)
                                   .Input(FakeInput(n, DT_INT32))
                                   .Input(FakeInput(n, b_type))
                                   .Attr("M", m)));
  };
  NameRangeMap inputs, outputs;

  const NodeDef four_by_three = make_node(4, DT_FLOAT, 3);
  TF_EXPECT_OK(NameRangesForNode(four_by_three, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 4}}, {"b", {4, 8}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 5}}, {"e", {5, 8}}}),
            outputs);
  EXPECT_EQ(
      "{{node nr}} = NRepeats[M=3, N=4, T=DT_FLOAT](a, a:1, a:2, a:3, b, b:1, "
      "b:2, b:3)",
      SummarizeNodeDef(four_by_three));

  const NodeDef two_by_seven = make_node(2, DT_DOUBLE, 7);
  TF_EXPECT_OK(NameRangesForNode(two_by_seven, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 4}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}),
            outputs);
  EXPECT_EQ("{{node nr}} = NRepeats[M=7, N=2, T=DT_DOUBLE](a, a:1, b, b:1)",
            SummarizeNodeDef(two_by_seven));

  // With its attrs stripped, the node's ranges can no longer be computed.
  NodeDef attrless = two_by_seven;
  attrless.clear_attr();
  EXPECT_FALSE(NameRangesForNode(attrless, op_def, &inputs, &outputs).ok());
}
710
// Args typed by a list(type) attr span as many slots as the list has entries.
TEST(NameRangesForNodeTest, TypeList) {
  const OpDef op_def = ToOpDef(OpDefBuilder("TypeList")
                                   .Input("a: T1")
                                   .Input("b: T2")
                                   .Output("c: T2")
                                   .Output("d: T3")
                                   .Output("e: T1")
                                   .Attr("T1: list(type)")
                                   .Attr("T2: list(type)")
                                   .Attr("T3: list(type)"));
  NameRangeMap inputs, outputs;

  // T1 has 2 entries, T2 has 4, and T3 is set explicitly to 3.
  NodeDefBuilder first_builder("tl", &op_def);
  first_builder.Input(FakeInput({DT_BOOL, DT_FLOAT}))
      .Input(FakeInput(4, DT_FLOAT))
      .Attr("T3", {DT_INT32, DT_DOUBLE, DT_STRING});
  const NodeDef first = ToNodeDef(std::move(first_builder));
  TF_EXPECT_OK(NameRangesForNode(first, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 6}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 4}}, {"d", {4, 7}}, {"e", {7, 9}}}),
            outputs);
  EXPECT_EQ(
      "{{node tl}} = TypeList[T1=[DT_BOOL, DT_FLOAT],"
      " T2=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT],"
      " T3=[DT_INT32, DT_DOUBLE, DT_STRING]](a, a:1, b, b:1, b:2, b:3)",
      SummarizeNodeDef(first));

  // T1 has 7 entries, T2 has 1, and T3 is set explicitly to 2.
  NodeDefBuilder second_builder("tl", &op_def);
  second_builder.Input(FakeInput(7, DT_INT32))
      .Input(FakeInput({DT_DOUBLE}))
      .Attr("T3", {DT_DOUBLE, DT_STRING});
  const NodeDef second = ToNodeDef(std::move(second_builder));
  TF_EXPECT_OK(NameRangesForNode(second, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 7}}, {"b", {7, 8}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}),
            outputs);
  EXPECT_EQ(
      "{{node tl}} = TypeList[T1=[DT_INT32, DT_INT32, DT_INT32, DT_INT32, "
      "DT_INT32, DT_INT32, DT_INT32], T2=[DT_DOUBLE], "
      "T3=[DT_DOUBLE, DT_STRING]](a, a:1, a:2, a:3, a:4, a:5, a:6, b)",
      SummarizeNodeDef(second));

  // With its attrs stripped, the node's ranges can no longer be computed.
  NodeDef attrless = second;
  attrless.clear_attr();
  EXPECT_FALSE(NameRangesForNode(attrless, op_def, &inputs, &outputs).ok());
}
757
// Both the node name and an Enter node's frame name get the prefix and suffix.
TEST(AddPrefixAndSuffixToNode, Enter) {
  NodeDef enter;
  enter.set_name("enter");
  enter.set_op("Enter");
  AddNodeAttr("frame_name", "test_frame", &enter);

  TF_ASSERT_OK(AddPrefixAndSuffixToNode("prefix/", "/suffix", &enter));
  EXPECT_EQ("prefix/enter/suffix", enter.name());

  string frame;
  TF_ASSERT_OK(GetNodeAttr(enter, "frame_name", &frame));
  EXPECT_EQ("prefix/test_frame/suffix", frame);
}
771
// Only colocation targets named in `match` receive the prefix.
TEST(MaybeAddPrefixToColocationConstraints, Basic) {
  NodeDef node_def;
  node_def.set_name("Identity");
  node_def.set_op("Identity");

  // Returns the colocation-group spelling for node `name`.
  auto group = [](absl::string_view name) {
    return strings::StrCat(kColocationGroupPrefix, name);
  };
  AddNodeAttr(kColocationAttrName,
              {group("Node1"), group("Node2"), group("Node3")}, &node_def);

  std::unordered_set<string> match = {"Node1", "Node3"};
  TF_ASSERT_OK(MaybeAddPrefixToColocationConstraints(match, "fn/", &node_def));

  std::vector<string> coloc_constraints;
  TF_ASSERT_OK(GetNodeAttr(node_def, kColocationAttrName, &coloc_constraints));
  const std::vector<string> expected = {"loc:@fn/Node1", "loc:@Node2",
                                        "loc:@fn/Node3"};
  EXPECT_EQ(coloc_constraints, expected);
}
792
// A node without colocation constraints is left without the attr.
TEST(MaybeAddPrefixToColocationConstraints, NoConstraints) {
  NodeDef identity;
  identity.set_name("Identity");
  identity.set_op("Identity");

  std::unordered_set<string> match = {"Node1", "Node3"};
  TF_ASSERT_OK(MaybeAddPrefixToColocationConstraints(match, "fn/", &identity));
  EXPECT_FALSE(HasNodeAttr(identity, kColocationAttrName));
}
804
// Colocation targets found in the map are renamed; others are kept as-is.
TEST(MaybeUpdateColocationConstraintsWithMap, Basic) {
  NodeDef node_def;
  node_def.set_name("Identity");
  node_def.set_op("Identity");

  // Returns the colocation-group spelling for node `name`.
  auto group = [](absl::string_view name) {
    return strings::StrCat(kColocationGroupPrefix, name);
  };
  AddNodeAttr(kColocationAttrName,
              {group("Node1"), group("Node2"), group("Node3")}, &node_def);

  // "Invalid" matches no constraint and must have no effect.
  std::map<absl::string_view, absl::string_view> node_map = {
      {"Node1", "Node4"}, {"Invalid", "Node5"}};
  TF_ASSERT_OK(MaybeUpdateColocationConstraintsWithMap(node_map, &node_def));

  std::vector<string> coloc_constraints;
  TF_ASSERT_OK(GetNodeAttr(node_def, kColocationAttrName, &coloc_constraints));
  const std::vector<string> expected = {"loc:@Node4", "loc:@Node2",
                                        "loc:@Node3"};
  EXPECT_EQ(coloc_constraints, expected);
}
824
// A node without colocation constraints is left without the attr.
TEST(MaybeUpdateColocationConstraintsWithMap, NoConstraints) {
  NodeDef identity;
  identity.set_name("Identity");
  identity.set_op("Identity");

  std::map<absl::string_view, absl::string_view> node_map = {
      {"Node1", "Node4"}, {"Invalid", "Node5"}};
  TF_ASSERT_OK(MaybeUpdateColocationConstraintsWithMap(node_map, &identity));
  EXPECT_FALSE(HasNodeAttr(identity, kColocationAttrName));
}
836
// A graph Node is formatted as "{{node <name>}}".
TEST(FormatNodeForErrorTest, Node) {
  Graph g(OpRegistry::Global());
  Node* node = nullptr;
  // TF_ASSERT_OK turns a failed NodeBuilder into a test failure and stops
  // this test; TF_CHECK_OK would abort the whole test binary instead.
  TF_ASSERT_OK(NodeBuilder("enter", "NoOp").Finalize(&g, &node));
  EXPECT_EQ("{{node enter}}", FormatNodeForError(*node));
}
843
// A NodeDef is formatted by its name alone; op and attrs are not shown.
TEST(FormatNodeForErrorTest, NodeDef) {
  NodeDef enter;
  enter.set_op("Enter");
  enter.set_name("enter");
  AddNodeAttr("frame_name", "test_frame", &enter);

  const string formatted = FormatNodeDefForError(enter);
  EXPECT_EQ("{{node enter}}", formatted);
}
851
// Original node/function names in the debug info replace the node's own name.
TEST(FormatNodeForErrorTest, NodeDefWithOriginalNames) {
  NodeDef node_def;
  node_def.set_name("enter");
  node_def.set_op("Enter");
  AddNodeAttr("frame_name", "test_frame", &node_def);

  // Appends one (original node name, original function name) pair.
  auto add_origin = [&node_def](const char* node_name,
                                const char* func_name) {
    auto* debug_info = node_def.mutable_experimental_debug_info();
    debug_info->add_original_node_names(node_name);
    debug_info->add_original_func_names(func_name);
  };

  add_origin("node_name", "func_name");
  EXPECT_EQ("{{function_node func_name}}{{node node_name}}",
            FormatNodeDefForError(node_def));

  // Several pairs are joined with ", ".
  add_origin("node_name2", "func_name2");
  EXPECT_EQ(
      "{{function_node func_name}}{{node node_name}}, "
      "{{function_node func_name2}}{{node node_name2}}",
      FormatNodeDefForError(node_def));
}
872
// When multiple formatted nodes are allowed, every attached node is rendered
// as a "{{node ...}}" reference.
TEST(AttachDef, AllowMultipleFormattedNode) {
  NodeDef first;
  first.set_name("a");
  NodeDef second;
  second.set_name("b");
  const Status base(error::CANCELLED, "Error");

  const Status with_first =
      AttachDef(base, first, /*allow_multiple_formatted_node=*/true);
  EXPECT_EQ("Error\n\t [[{{node a}}]]", with_first.error_message());

  const Status with_both =
      AttachDef(with_first, second, /*allow_multiple_formatted_node=*/true);
  EXPECT_EQ("Error\n\t [[{{node a}}]]\n\t [[{{node b}}]]",
            with_both.error_message());
}
884
// When multiple formatted nodes are disallowed, only the first attached node
// is rendered as "{{node ...}}"; later ones appear as plain names.
TEST(AttachDef, DisallowMultipleFormattedNode) {
  NodeDef first;
  first.set_name("a");
  NodeDef second;
  second.set_name("b");
  const Status base(error::CANCELLED, "Error");

  const Status with_first =
      AttachDef(base, first, /*allow_multiple_formatted_node=*/false);
  EXPECT_EQ("Error\n\t [[{{node a}}]]", with_first.error_message());

  const Status with_both =
      AttachDef(with_first, second, /*allow_multiple_formatted_node=*/false);
  EXPECT_EQ("Error\n\t [[{{node a}}]]\n\t [[b]]", with_both.error_message());
}
896
897 } // namespace
898 } // namespace tensorflow
899