1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/node_def_util.h"
17
18 #include "tensorflow/core/framework/attr_value.pb.h" // NOLINT
19 #include "tensorflow/core/framework/fake_input.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/framework/op_def_builder.h"
22 #include "tensorflow/core/framework/op_def_util.h"
23 #include "tensorflow/core/graph/graph.h"
24 #include "tensorflow/core/graph/node_builder.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/platform/protobuf.h"
29 #include "tensorflow/core/platform/test.h"
30
31 namespace tensorflow {
32 namespace {
33
ToOpDef(const OpDefBuilder & builder)34 OpDef ToOpDef(const OpDefBuilder& builder) {
35 OpRegistrationData op_reg_data;
36 TF_EXPECT_OK(builder.Finalize(&op_reg_data));
37 return op_reg_data.op_def;
38 }
39
ToNodeDef(const string & text)40 NodeDef ToNodeDef(const string& text) {
41 NodeDef node_def;
42 EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
43 return node_def;
44 }
45
ToNodeDef(NodeDefBuilder && builder)46 NodeDef ToNodeDef(NodeDefBuilder&& builder) {
47 NodeDef node_def;
48 TF_EXPECT_OK(builder.Finalize(&node_def));
49 return node_def;
50 }
51
ExpectSuccess(const NodeDef & good,const OpDef & op_def)52 void ExpectSuccess(const NodeDef& good, const OpDef& op_def) {
53 EXPECT_EQ(Status::OK(), ValidateNodeDef(good, op_def))
54 << "NodeDef: " << SummarizeNodeDef(good)
55 << "; OpDef: " << SummarizeOpDef(op_def);
56 }
57
ExpectFailure(const NodeDef & bad,const OpDef & op_def,const string & message)58 void ExpectFailure(const NodeDef& bad, const OpDef& op_def,
59 const string& message) {
60 Status status = ValidateNodeDef(bad, op_def);
61
62 EXPECT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad)
63 << "; OpDef: " << SummarizeOpDef(op_def);
64 if (status.ok()) return;
65
66 EXPECT_TRUE(errors::IsInvalidArgument(status))
67 << status << "; NodeDef: " << SummarizeNodeDef(bad)
68 << "; OpDef: " << SummarizeOpDef(op_def);
69
70 LOG(INFO) << "Message: " << status.error_message();
71 EXPECT_TRUE(absl::StrContains(status.ToString(), message))
72 << "NodeDef: " << SummarizeNodeDef(bad)
73 << "; OpDef: " << SummarizeOpDef(op_def) << "\nActual error: " << status
74 << "\nDoes not contain: " << message;
75 }
76
// Validation of a single-input op whose input type comes from attr `T`:
// op-name mismatch, missing/extra/mistyped attrs, input counts, and
// control-input placement.
TEST(NodeDefUtilTest, In) {
  const OpDef op = ToOpDef(OpDefBuilder("In").Input("i: T").Attr("T: type"));
  const NodeDef node_def = ToNodeDef(R"proto(
    name:'n' op:'In' input:'a' attr { key:'T' value { type:DT_FLOAT } }
  )proto");
  ExpectSuccess(node_def, op);

  EXPECT_EQ("{{node n}} = In[T=DT_FLOAT](a)", SummarizeNodeDef(node_def));

  // Mismatching Op names.
  NodeDef bad = node_def;
  bad.set_op("Wrong");
  ExpectFailure(bad, op, "NodeDef op 'Wrong' does not match Op<name=In;");

  // Missing attr.
  bad = node_def;
  bad.clear_attr();
  ExpectFailure(bad, op, "NodeDef missing attr 'T' from Op<name=In;");

  // Extra attr.
  bad = node_def;
  AddNodeAttr("EXTRA", 17, &bad);
  ExpectFailure(bad, op, "NodeDef mentions attr 'EXTRA' not in Op<name=In;");

  // Attr has wrong type.
  bad = node_def;
  bad.clear_attr();
  AddNodeAttr("T", 17, &bad);
  ExpectFailure(
      bad, op,
      "AttrValue had value with type 'int' when 'type' expected\n\t for attr "
      "'T'\n\t; NodeDef: ");

  // Wrong number of inputs.
  bad = node_def;
  bad.add_input("b");
  ExpectFailure(
      bad, op,
      "NodeDef expected inputs 'float' do not match 2 inputs specified;");

  bad = node_def;
  bad.clear_input();
  ExpectFailure(
      bad, op,
      "NodeDef expected inputs 'float' do not match 0 inputs specified;");

  // A control input that follows all data inputs is accepted. Validate
  // `good` itself (not `node_def`) so the control input is actually
  // exercised.
  NodeDef good = node_def;
  good.add_input("^b");
  ExpectSuccess(good, op);

  // Control inputs must appear after data inputs.
  bad = node_def;
  bad.clear_input();
  bad.add_input("^b");
  bad.add_input("a");
  ExpectFailure(bad, op,
                "Invalid argument: Non-control input 'a' after control input "
                "in NodeDef:");

  // Control inputs may not name an output port.
  bad = node_def;
  bad.add_input("^b:0");
  ExpectFailure(bad, op, "Control input '^b:0' must not have ':' in NodeDef:");
}
140
// An output typed by a `numbertype` attr accepts numeric types only.
TEST(NodeDefUtilTest, Out) {
  const OpDef op =
      ToOpDef(OpDefBuilder("Out").Output("o: T").Attr("T: numbertype"));
  const NodeDef node_def = ToNodeDef(R"proto(
    name:'n' op:'Out' attr { key:'T' value { type:DT_INT32 } }
  )proto");
  ExpectSuccess(node_def, op);
  EXPECT_EQ("{{node n}} = Out[T=DT_INT32]()", SummarizeNodeDef(node_def));

  // DT_STRING is not numeric; the error lists every permitted type.
  NodeDef string_typed(node_def);
  string_typed.clear_attr();
  AddNodeAttr("T", DT_STRING, &string_typed);
  const string expected_error =
      "Value for attr 'T' of string is not in the list of allowed values: "
      "float, double, int32, uint8, int16, int8, complex64, int64, qint8, "
      "quint8, qint32, bfloat16, uint16, complex128, half, uint32, uint64";
  ExpectFailure(string_typed, op, expected_error);
}
161
// A string attr restricted to an enumerated set of values.
TEST(NodeDefUtilTest, Enum) {
  const OpDef op = ToOpDef(OpDefBuilder("Enum").Attr("e: {'apple','orange'}"));
  const NodeDef node_def = ToNodeDef(R"proto(
    name:'n' op:'Enum' attr { key:'e' value { s:'apple' } }
  )proto");
  ExpectSuccess(node_def, op);
  EXPECT_EQ("{{node n}} = Enum[e=\"apple\"]()", SummarizeNodeDef(node_def));

  // Returns a copy of `node_def` whose only attr is `e` = `value`.
  const auto with_value = [&node_def](const char* value) {
    NodeDef copy = node_def;
    copy.clear_attr();
    AddNodeAttr("e", value, &copy);
    return copy;
  };

  // The other enumerated value is accepted too.
  ExpectSuccess(with_value("orange"), op);

  // A value outside the enumeration is rejected.
  ExpectFailure(with_value("foo"), op,
                "Value for attr 'e' of \"foo\" is not in the list of allowed "
                "values: \"apple\", \"orange\"");
}
184
// A homogeneous list input `N * T`, with a bounded N and a restricted T.
TEST(NodeDefUtilTest, SameIn) {
  const OpDef op = ToOpDef(OpDefBuilder("SameIn")
                               .Input("i: N * T")
                               .Attr("N: int >= 2")
                               .Attr("T: {float,double}"));

  // N = 2 meets the lower bound and double is an allowed T.
  const NodeDef node_def = ToNodeDef(R"proto(
    name:'n' op:'SameIn' input:'a' input:'b'
    attr { key:'N' value { i:2 } } attr { key:'T' value { type:DT_DOUBLE } }
  )proto");
  ExpectSuccess(node_def, op);
  EXPECT_EQ("{{node n}} = SameIn[N=2, T=DT_DOUBLE](a, b)",
            SummarizeNodeDef(node_def));

  // T outside {float, double}.
  const NodeDef illegal_type = ToNodeDef(R"proto(
    name:'n' op:'SameIn' input:'a' input:'b'
    attr { key:'N' value { i:2 } } attr { key:'T' value { type:DT_STRING } }
  )proto");
  ExpectFailure(illegal_type, op,
                "Value for attr 'T' of string is not in the list of allowed "
                "values: float, double");

  // N below its minimum of 2.
  const NodeDef too_few_inputs = ToNodeDef(R"proto(
    name:'n' op:'SameIn' input:'a' input:'b'
    attr { key:'N' value { i:1 } } attr { key:'T' value { type:DT_FLOAT } }
  )proto");
  ExpectFailure(too_few_inputs, op,
                "Value for attr 'N' of 1 must be at least minimum 2");
}
215
// A heterogeneous list input typed by a `list(type) >= 1` attr.
TEST(NodeDefUtilTest, AnyIn) {
  const OpDef op =
      ToOpDef(OpDefBuilder("AnyIn").Input("i: T").Attr("T: list(type) >= 1"));

  const NodeDef node_def = ToNodeDef(R"proto(
    name:'n' op:'AnyIn' input:'a' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
  )proto");
  ExpectSuccess(node_def, op);
  EXPECT_EQ("{{node n}} = AnyIn[T=[DT_INT32, DT_STRING]](a, b)",
            SummarizeNodeDef(node_def));

  // An empty type list violates the `>= 1` length bound. With proto3
  // semantics an empty `value {}` is indistinguishable from a value holding
  // an empty list, so both spellings must produce the same complaint.
  const char* const kEmptyListTexts[] = {
      R"proto(
        name:'n' op:'AnyIn' input:'a' attr { key:'T' value { list { } } }
      )proto",
      R"proto(
        name:'n' op:'AnyIn' input:'a' attr { key:'T' value { } }
      )proto",
  };
  for (const char* text : kEmptyListTexts) {
    ExpectFailure(ToNodeDef(text), op,
                  "Length for attr 'T' of 0 must be at least minimum 1");
  }
}
243
// A device set through NodeDefBuilder::Device() shows up as `_device` in the
// NodeDef summary.
TEST(NodeDefUtilTest, Device) {
  // Op without attrs: `_device` is the only entry in the attr list.
  const OpDef no_attr_op = ToOpDef(OpDefBuilder("None"));
  NodeDefBuilder no_attr_builder("d", &no_attr_op);
  no_attr_builder.Device("/cpu:17");
  const NodeDef no_attr_node = ToNodeDef(std::move(no_attr_builder));
  ExpectSuccess(no_attr_node, no_attr_op);
  EXPECT_EQ("{{node d}} = None[_device=\"/cpu:17\"]()",
            SummarizeNodeDef(no_attr_node));

  // Op with a regular attr: `_device` is listed after it.
  const OpDef int_attr_op = ToOpDef(OpDefBuilder("WithAttr").Attr("v: int"));
  NodeDefBuilder int_attr_builder("d", &int_attr_op);
  int_attr_builder.Attr("v", 7).Device("/cpu:5");
  const NodeDef int_attr_node = ToNodeDef(std::move(int_attr_builder));
  ExpectSuccess(int_attr_node, int_attr_op);
  EXPECT_EQ("{{node d}} = WithAttr[v=7, _device=\"/cpu:5\"]()",
            SummarizeNodeDef(int_attr_node));
}
259
ExpectValidSyntax(const NodeDef & good)260 void ExpectValidSyntax(const NodeDef& good) {
261 EXPECT_EQ(Status::OK(), ValidateExternalNodeDefSyntax(good))
262 << "NodeDef: " << SummarizeNodeDef(good);
263 }
264
ExpectInvalidSyntax(const NodeDef & bad,const string & message)265 void ExpectInvalidSyntax(const NodeDef& bad, const string& message) {
266 Status status = ValidateExternalNodeDefSyntax(bad);
267
268 ASSERT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad);
269
270 EXPECT_TRUE(errors::IsInvalidArgument(status))
271 << status << "; NodeDef: " << SummarizeNodeDef(bad);
272
273 EXPECT_TRUE(absl::StrContains(StringPiece(status.ToString()), message))
274 << "NodeDef: " << SummarizeNodeDef(bad) << ", " << status << ", "
275 << message;
276 }
277
// ValidateExternalNodeDefSyntax(): accepted and rejected spellings of node
// names, op names and input names.
TEST(NodeDefUtilTest, ValidSyntax) {
  // ---- Accepted NodeDefs. ----
  const NodeDef node_def = ToNodeDef(R"proto(
    name:'n' op:'AnyIn' input:'a' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
  )proto");
  ExpectValidSyntax(node_def);

  // Op name with a 'Project>' namespace prefix.
  const NodeDef node_def_namespace = ToNodeDef(R"proto(
    name: 'n'
    op: 'Project>AnyIn'
    input: 'a'
    input: 'b'
    attr {
      key: 'T'
      value { list { type: [ DT_INT32, DT_STRING ] } }
    }
  )proto");
  ExpectValidSyntax(node_def_namespace);

  // Data inputs naming explicit output ports.
  const NodeDef node_def_explicit_inputs = ToNodeDef(R"proto(
    name: 'n'
    op: 'AnyIn'
    input: 'a:0'
    input: 'b:123'
    attr {
      key: 'T'
      value { list { type: [ DT_INT32, DT_STRING ] } }
    }
  )proto");
  ExpectValidSyntax(node_def_explicit_inputs);
  EXPECT_EQ("{{node n}} = AnyIn[T=[DT_INT32, DT_STRING]](a:0, b:123)",
            SummarizeNodeDef(node_def_explicit_inputs));

  // Namespaced node, op, data inputs and control input together.
  const NodeDef node_def_explicit_inputs_namespace = ToNodeDef(R"proto(
    name: 'Project>n'
    op: 'Project>AnyIn'
    input: 'Project>a:0'
    input: 'Project>b:123'
    input: '^Project>c'
    attr {
      key: 'T'
      value { list { type: [ DT_INT32, DT_STRING ] } }
    }
  )proto");
  ExpectValidSyntax(node_def_explicit_inputs_namespace);
  EXPECT_EQ(
      "{{node Project>n}} = Project>AnyIn[T=[DT_INT32, DT_STRING]]"
      "(Project>a:0, Project>b:123, ^Project>c)",
      SummarizeNodeDef(node_def_explicit_inputs_namespace));

  // A shape attr with an unknown (-1) dimension.
  const NodeDef node_def_partial_shape = ToNodeDef(R"proto(
    name:'n' op:'AnyIn'
    attr { key:'shp' value { shape { dim { size: -1 } dim { size: 0 } } } }
  )proto");
  ExpectValidSyntax(node_def_partial_shape);

  // A '-' in the node name, plus a control input after the data input.
  const NodeDef node_def_control_input = ToNodeDef(R"proto(
    name:'n-' op:'AnyIn' input:'a' input:'^b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
  )proto");
  ExpectValidSyntax(node_def_control_input);

  // ---- Rejected NodeDefs, each with a substring of the expected error. ----
  struct InvalidCase {
    const char* text;
    const char* message;
  };
  const InvalidCase kInvalidCases[] = {
      // ':' in the node name.
      {R"proto(
         name:'n:0' op:'AnyIn' input:'a' input:'b'
         attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
       )proto",
       "Illegal op name 'n:0'"},
      // Leading underscore in the node name.
      {R"proto(
         name:'_n' op:'AnyIn' input:'a' input:'b'
         attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
       )proto",
       "Illegal op name '_n'"},
      // Backslash in the node name.
      {R"proto(
         name:'n\\' op:'AnyIn' input:'a' input:'b'
         attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
       )proto",
       "Illegal op name 'n\\'"},
      // Leading underscore in an input name.
      {R"proto(
         name:'n' op:'AnyIn' input:'_a' input:'b'
         attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
       )proto",
       "Illegal op input name '_a'"},
      // Backslash in an input name.
      {R"proto(
         name:'n' op:'AnyIn' input:'a\\' input:'b'
         attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
       )proto",
       "Illegal op input name 'a\\'"},
      // Control input naming a port.
      {R"proto(
         name:'n' op:'AnyIn' input:'a' input:'^b:0'
         attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
       )proto",
       "Illegal op input name '^b:0'"},
      // Backslash in a control input name.
      {R"proto(
         name:'n' op:'AnyIn' input:'a' input:'^b\\'
         attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
       )proto",
       "Illegal op input name '^b\\'"},
      // Data input after a control input.
      {R"proto(
         name:'n' op:'AnyIn' input:'^a' input:'b'
         attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
       )proto",
       "All control inputs must follow all data inputs"},
      // Non-numeric port.
      {R"proto(
         name:'n' op:'AnyIn' input:'a:b' input:'b'
         attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
       )proto",
       "Illegal op input name 'a:b"},
      // Port with a leading zero.
      {R"proto(
         name:'n' op:'AnyIn' input:'a:00' input:'b'
         attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
       )proto",
       "Illegal op input name 'a:00"},
  };
  for (const InvalidCase& invalid : kInvalidCases) {
    ExpectInvalidSyntax(ToNodeDef(invalid.text), invalid.message);
  }
}
408
// InputTypesForNode / InputTypeForNode on an op with fixed input types.
TEST(InputTypesForNode, Simple) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
                                   .Input("a: float")
                                   .Input("b: int32")
                                   .Output("c: string")
                                   .Output("d: bool"));
  const NodeDef node_def = ToNodeDef(std::move(
      NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())));

  // Fatal checks before indexing: if the call failed or produced the wrong
  // number of types, types[0] / types[1] would read out of bounds.
  DataTypeVector types;
  TF_ASSERT_OK(InputTypesForNode(node_def, op_def, &types));
  ASSERT_EQ(2u, types.size());
  EXPECT_EQ(DT_FLOAT, types[0]);
  EXPECT_EQ(DT_INT32, types[1]);

  // Start from a known value so a failed lookup never leaves `type`
  // uninitialized when it is compared.
  DataType type = DT_INVALID;
  TF_EXPECT_OK(InputTypeForNode(node_def, op_def, 0, &type));
  EXPECT_EQ(DT_FLOAT, type);
  TF_EXPECT_OK(InputTypeForNode(node_def, op_def, 1, &type));
  EXPECT_EQ(DT_INT32, type);
  // The op has only two inputs, so index 2 is out of range.
  EXPECT_FALSE(InputTypeForNode(node_def, op_def, 2, &type).ok());
}
429
// OutputTypesForNode / OutputTypeForNode on an op with fixed output types.
TEST(OutputTypesForNode, Simple) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
                                   .Input("a: float")
                                   .Input("b: int32")
                                   .Output("c: string")
                                   .Output("d: bool"));
  const NodeDef node_def = ToNodeDef(std::move(
      NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())));

  // Fatal checks before indexing: if the call failed or produced the wrong
  // number of types, types[0] / types[1] would read out of bounds.
  DataTypeVector types;
  TF_ASSERT_OK(OutputTypesForNode(node_def, op_def, &types));
  ASSERT_EQ(2u, types.size());
  EXPECT_EQ(DT_STRING, types[0]);
  EXPECT_EQ(DT_BOOL, types[1]);

  // Start from a known value so a failed lookup never leaves `type`
  // uninitialized when it is compared.
  DataType type = DT_INVALID;
  TF_EXPECT_OK(OutputTypeForNode(node_def, op_def, 0, &type));
  EXPECT_EQ(DT_STRING, type);
  TF_EXPECT_OK(OutputTypeForNode(node_def, op_def, 1, &type));
  EXPECT_EQ(DT_BOOL, type);
  // The op has only two outputs, so index 2 is out of range.
  EXPECT_FALSE(OutputTypeForNode(node_def, op_def, 2, &type).ok());
}
450
// An op whose output count comes from `num_split`: with num_split = 10^12,
// OutputTypesForNode is expected to return an error.
TEST(OutputTypesForNode, LargeOutput) {
  const OpDef op_def = ToOpDef(OpDefBuilder("TestSplitOp")
                                   .Input("value: int64")
                                   .Output("output: num_split * int64")
                                   .Attr("num_split: int >= 1"));
  const int64 huge_split_count = 1000000000000;
  NodeDefBuilder builder("test_split_op", &op_def);
  builder.Input(FakeInput()).Attr("num_split", huge_split_count);
  const NodeDef node_def = ToNodeDef(std::move(builder));

  DataTypeVector types;
  const Status status = OutputTypesForNode(node_def, op_def, &types);
  EXPECT_FALSE(status.ok());
}
464
// The AttrSlice overload of OutputTypesForNode on an op with fixed output
// types.
TEST(OutputTypesForNode_AttrSliceOverload, Simple) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
                                   .Input("a: float")
                                   .Input("b: int32")
                                   .Output("c: string")
                                   .Output("d: bool"));
  // Keep the NodeDef in a named local. AttrSlice refers to the NodeDef it is
  // built from instead of copying it, so building it from the temporary
  // returned by ToNodeDef() would leave `attr_slice` dangling once that
  // temporary is destroyed.
  const NodeDef node_def = ToNodeDef(std::move(
      NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())));
  const AttrSlice attr_slice(node_def);

  // Fatal checks before indexing into `types`.
  DataTypeVector types;
  TF_ASSERT_OK(OutputTypesForNode(attr_slice, op_def, &types));
  ASSERT_EQ(2u, types.size());
  EXPECT_EQ(DT_STRING, types[0]);
  EXPECT_EQ(DT_BOOL, types[1]);
}
480
// Name ranges for an op whose inputs and outputs are all single tensors.
TEST(NameRangesForNodeTest, Simple) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
                                   .Input("a: float")
                                   .Input("b: int32")
                                   .Output("c: string")
                                   .Output("d: bool"));
  const NodeDef node_def = ToNodeDef(std::move(
      NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())));

  NameRangeMap inputs;
  NameRangeMap outputs;
  TF_EXPECT_OK(NameRangesForNode(node_def, op_def, &inputs, &outputs));
  // Each argument occupies exactly one slot, given as {start, end}.
  const NameRangeMap expected_inputs({{"a", {0, 1}}, {"b", {1, 2}}});
  const NameRangeMap expected_outputs({{"c", {0, 1}}, {"d", {1, 2}}});
  EXPECT_EQ(expected_inputs, inputs);
  EXPECT_EQ(expected_outputs, outputs);

  EXPECT_EQ("{{node simple}} = Simple[](a, b)", SummarizeNodeDef(node_def));

  // Clearing the type of input 'a' must make the range computation fail.
  OpDef typeless_input_op = op_def;
  typeless_input_op.mutable_input_arg(0)->clear_type();
  EXPECT_FALSE(
      NameRangesForNode(node_def, typeless_input_op, &inputs, &outputs).ok());
}
500
// Name ranges for an op whose single-tensor args share a type attr `T`. The
// ranges are the same for every concrete T; only the summary changes.
TEST(NameRangesForNodeTest, Polymorphic) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Polymorphic")
                                   .Input("a: T")
                                   .Input("b: T")
                                   .Output("c: T")
                                   .Attr("T: type"));

  struct Case {
    DataType type;
    const char* summary;
  };
  const Case kCases[] = {
      {DT_INT32, "{{node poly}} = Polymorphic[T=DT_INT32](a, b)"},
      {DT_BOOL, "{{node poly}} = Polymorphic[T=DT_BOOL](a, b)"},
  };

  NameRangeMap inputs;
  NameRangeMap outputs;
  for (const Case& c : kCases) {
    const NodeDef node_def =
        ToNodeDef(std::move(NodeDefBuilder("poly", &op_def)
                                .Input(FakeInput(c.type))
                                .Input(FakeInput(c.type))));
    TF_EXPECT_OK(NameRangesForNode(node_def, op_def, &inputs, &outputs));
    EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs);
    EXPECT_EQ(NameRangeMap({{"c", {0, 1}}}), outputs);
    EXPECT_EQ(c.summary, SummarizeNodeDef(node_def));
  }
}
528
// Name ranges for list args whose lengths come from the int attrs N and M.
TEST(NameRangesForNodeTest, NRepeats) {
  const OpDef op_def = ToOpDef(OpDefBuilder("NRepeats")
                                   .Input("a: N * int32")
                                   .Input("b: N * T")
                                   .Output("c: T")
                                   .Output("d: N * string")
                                   .Output("e: M * bool")
                                   .Attr("N: int")
                                   .Attr("M: int")
                                   .Attr("T: type"));

  // Builds an "nr" node with `n` int32 inputs for `a`, `n` inputs of type `t`
  // for `b`, and M = `m`. N and T are filled in from the fake inputs.
  const auto make_node = [&op_def](int n, DataType t, int m) {
    return ToNodeDef(std::move(NodeDefBuilder("nr", &op_def)
                                   .Input(FakeInput(n, DT_INT32))
                                   .Input(FakeInput(n, t))
                                   .Attr("M", m)));
  };

  NameRangeMap inputs;
  NameRangeMap outputs;

  // N = 4, T = float, M = 3.
  const NodeDef node_def1 = make_node(4, DT_FLOAT, 3);
  TF_EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 4}}, {"b", {4, 8}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 5}}, {"e", {5, 8}}}),
            outputs);
  EXPECT_EQ(
      "{{node nr}} = NRepeats[M=3, N=4, T=DT_FLOAT](a, a:1, a:2, a:3, b, b:1, "
      "b:2, b:3)",
      SummarizeNodeDef(node_def1));

  // N = 2, T = double, M = 7.
  const NodeDef node_def2 = make_node(2, DT_DOUBLE, 7);
  TF_EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 4}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}),
            outputs);
  EXPECT_EQ("{{node nr}} = NRepeats[M=7, N=2, T=DT_DOUBLE](a, a:1, b, b:1)",
            SummarizeNodeDef(node_def2));

  // Without its attrs the list lengths are unknown, so computing the ranges
  // must fail.
  NodeDef missing_attrs = node_def2;
  missing_attrs.clear_attr();
  EXPECT_FALSE(
      NameRangesForNode(missing_attrs, op_def, &inputs, &outputs).ok());
}
570
// Name ranges for args typed by `list(type)` attrs T1, T2 and T3.
TEST(NameRangesForNodeTest, TypeList) {
  const OpDef op_def = ToOpDef(OpDefBuilder("TypeList")
                                   .Input("a: T1")
                                   .Input("b: T2")
                                   .Output("c: T2")
                                   .Output("d: T3")
                                   .Output("e: T1")
                                   .Attr("T1: list(type)")
                                   .Attr("T2: list(type)")
                                   .Attr("T3: list(type)"));
  NameRangeMap inputs;
  NameRangeMap outputs;

  // T1 = [bool, float], T2 = four floats, T3 = [int32, double, string].
  NodeDefBuilder builder1("tl", &op_def);
  builder1.Input(FakeInput({DT_BOOL, DT_FLOAT}))
      .Input(FakeInput(4, DT_FLOAT))
      .Attr("T3", {DT_INT32, DT_DOUBLE, DT_STRING});
  const NodeDef node_def1 = ToNodeDef(std::move(builder1));
  TF_EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 6}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 4}}, {"d", {4, 7}}, {"e", {7, 9}}}),
            outputs);
  EXPECT_EQ(
      "{{node tl}} = TypeList[T1=[DT_BOOL, DT_FLOAT],"
      " T2=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT],"
      " T3=[DT_INT32, DT_DOUBLE, DT_STRING]](a, a:1, b, b:1, b:2, b:3)",
      SummarizeNodeDef(node_def1));

  // T1 = seven int32s, T2 = [double], T3 = [double, string].
  NodeDefBuilder builder2("tl", &op_def);
  builder2.Input(FakeInput(7, DT_INT32))
      .Input(FakeInput({DT_DOUBLE}))
      .Attr("T3", {DT_DOUBLE, DT_STRING});
  const NodeDef node_def2 = ToNodeDef(std::move(builder2));
  TF_EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 7}}, {"b", {7, 8}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}),
            outputs);
  EXPECT_EQ(
      "{{node tl}} = TypeList[T1=[DT_INT32, DT_INT32, DT_INT32, DT_INT32, "
      "DT_INT32,"
      " DT_INT32, DT_INT32], T2=[DT_DOUBLE], T3=[DT_DOUBLE, DT_STRING]]"
      "(a, a:1, a:2, a:3, a:4, a:5, a:6, b)",
      SummarizeNodeDef(node_def2));

  // Without the type-list attrs the ranges cannot be computed.
  NodeDef missing_attrs = node_def2;
  missing_attrs.clear_attr();
  EXPECT_FALSE(
      NameRangesForNode(missing_attrs, op_def, &inputs, &outputs).ok());
}
617
// For an Enter node, AddPrefixAndSuffixToNode wraps both the node name and
// its `frame_name` attr with the given prefix and suffix.
TEST(AddPrefixAndSuffixToNode, Enter) {
  NodeDef enter_node;
  enter_node.set_name("enter");
  enter_node.set_op("Enter");
  AddNodeAttr("frame_name", "test_frame", &enter_node);

  const string prefix = "prefix/";
  const string suffix = "/suffix";
  TF_ASSERT_OK(AddPrefixAndSuffixToNode(prefix, suffix, &enter_node));
  EXPECT_EQ("prefix/enter/suffix", enter_node.name());

  string rewritten_frame;
  TF_ASSERT_OK(GetNodeAttr(enter_node, "frame_name", &rewritten_frame));
  EXPECT_EQ("prefix/test_frame/suffix", rewritten_frame);
}
631
// Only colocation constraints that name a node in `match` get the prefix.
TEST(MaybeAddPrefixToColocationConstraints, Basic) {
  NodeDef node_def;
  node_def.set_name("Identity");
  node_def.set_op("Identity");
  const string loc1 = strings::StrCat(kColocationGroupPrefix, "Node1");
  const string loc2 = strings::StrCat(kColocationGroupPrefix, "Node2");
  const string loc3 = strings::StrCat(kColocationGroupPrefix, "Node3");
  AddNodeAttr(kColocationAttrName, {loc1, loc2, loc3}, &node_def);

  // Node2 is not in `match`, so its constraint stays untouched.
  std::unordered_set<string> match = {"Node1", "Node3"};
  TF_ASSERT_OK(MaybeAddPrefixToColocationConstraints(match, "fn/", &node_def));

  std::vector<string> coloc_constraints;
  TF_ASSERT_OK(GetNodeAttr(node_def, kColocationAttrName, &coloc_constraints));
  const std::vector<string> expected = {"loc:@fn/Node1", "loc:@Node2",
                                        "loc:@fn/Node3"};
  EXPECT_EQ(coloc_constraints, expected);
}
652
// A node without a colocation attr is left without one: the call succeeds
// and does not add the attr.
TEST(MaybeAddPrefixToColocationConstraints, NoConstraints) {
  NodeDef node_def;
  node_def.set_name("Identity");
  node_def.set_op("Identity");

  std::unordered_set<string> match = {"Node1", "Node3"};
  TF_ASSERT_OK(MaybeAddPrefixToColocationConstraints(match, "fn/", &node_def));
  EXPECT_FALSE(HasNodeAttr(node_def, kColocationAttrName));
}
664
// FormatNodeForError() renders a graph Node as "{{node <name>}}".
TEST(FormatNodeForErrorTest, Node) {
  Graph g(OpRegistry::Global());
  Node* node = nullptr;
  // TF_ASSERT_OK fails only this test on error. TF_CHECK_OK would abort the
  // entire test binary, and the dereference below must not run if node
  // construction failed.
  TF_ASSERT_OK(NodeBuilder("enter", "NoOp").Finalize(&g, &node));
  EXPECT_EQ("{{node enter}}", FormatNodeForError(*node));
}
671
// FormatNodeDefForError() renders only the node name, as "{{node <name>}}".
// The op and attrs set here do not appear.
TEST(FormatNodeForErrorTest, NodeDef) {
  NodeDef enter;
  enter.set_name("enter");
  enter.set_op("Enter");
  AddNodeAttr("frame_name", "test_frame", &enter);
  const string formatted = FormatNodeDefForError(enter);
  EXPECT_EQ("{{node enter}}", formatted);
}
679
// When AttachDef's third argument is true, every attached node is rendered
// in the "{{node <name>}}" form.
TEST(AttachDef, AllowMultipleFormattedNode) {
  NodeDef first;
  first.set_name("a");
  NodeDef second;
  second.set_name("b");

  const Status base(error::CANCELLED, "Error");
  const Status with_first = AttachDef(base, first, true);
  EXPECT_EQ("Error\n\t [[{{node a}}]]", with_first.error_message());

  const Status with_both = AttachDef(with_first, second, true);
  EXPECT_EQ("Error\n\t [[{{node a}}]]\n\t [[{{node b}}]]",
            with_both.error_message());
}
691
// When AttachDef's third argument is false, a node attached after an
// already-formatted one is rendered as a plain "[[<name>]]".
TEST(AttachDef, DisallowMultipleFormattedNode) {
  NodeDef first;
  first.set_name("a");
  NodeDef second;
  second.set_name("b");

  const Status base(error::CANCELLED, "Error");
  const Status with_first = AttachDef(base, first, false);
  EXPECT_EQ("Error\n\t [[{{node a}}]]", with_first.error_message());

  const Status with_both = AttachDef(with_first, second, false);
  EXPECT_EQ("Error\n\t [[{{node a}}]]\n\t [[b]]", with_both.error_message());
}
703
704 } // namespace
705 } // namespace tensorflow
706