• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/node_def_util.h"
17 
18 #include "tensorflow/core/framework/attr_value.pb.h"  // NOLINT
19 #include "tensorflow/core/framework/fake_input.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/framework/op_def_builder.h"
22 #include "tensorflow/core/framework/op_def_util.h"
23 #include "tensorflow/core/graph/graph.h"
24 #include "tensorflow/core/graph/node_builder.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/platform/protobuf.h"
29 #include "tensorflow/core/platform/test.h"
30 
31 namespace tensorflow {
32 namespace {
33 
ToOpDef(const OpDefBuilder & builder)34 OpDef ToOpDef(const OpDefBuilder& builder) {
35   OpRegistrationData op_reg_data;
36   TF_EXPECT_OK(builder.Finalize(&op_reg_data));
37   return op_reg_data.op_def;
38 }
39 
ToNodeDef(const string & text)40 NodeDef ToNodeDef(const string& text) {
41   NodeDef node_def;
42   EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
43   return node_def;
44 }
45 
ToNodeDef(NodeDefBuilder && builder)46 NodeDef ToNodeDef(NodeDefBuilder&& builder) {
47   NodeDef node_def;
48   TF_EXPECT_OK(builder.Finalize(&node_def));
49   return node_def;
50 }
51 
ExpectSuccess(const NodeDef & good,const OpDef & op_def)52 void ExpectSuccess(const NodeDef& good, const OpDef& op_def) {
53   EXPECT_EQ(OkStatus(), ValidateNodeDef(good, op_def))
54       << "NodeDef: " << SummarizeNodeDef(good)
55       << "; OpDef: " << SummarizeOpDef(op_def);
56 }
57 
ExpectFailure(const NodeDef & bad,const OpDef & op_def,const string & message)58 void ExpectFailure(const NodeDef& bad, const OpDef& op_def,
59                    const string& message) {
60   Status status = ValidateNodeDef(bad, op_def);
61 
62   EXPECT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad)
63                             << "; OpDef: " << SummarizeOpDef(op_def);
64   if (status.ok()) return;
65 
66   EXPECT_TRUE(errors::IsInvalidArgument(status))
67       << status << "; NodeDef: " << SummarizeNodeDef(bad)
68       << "; OpDef: " << SummarizeOpDef(op_def);
69 
70   LOG(INFO) << "Message: " << status.error_message();
71   EXPECT_TRUE(absl::StrContains(status.ToString(), message))
72       << "NodeDef: " << SummarizeNodeDef(bad)
73       << "; OpDef: " << SummarizeOpDef(op_def) << "\nActual error: " << status
74       << "\nDoes not contain: " << message;
75 }
76 
// Validation of a single-input op whose input type is given by attr T.
TEST(NodeDefUtilTest, In) {
  const OpDef op = ToOpDef(OpDefBuilder("In").Input("i: T").Attr("T: type"));
  const NodeDef node_def = ToNodeDef(R"pb(
    name: 'n'
    op: 'In'
    input: 'a'
    attr {
      key: 'T'
      value { type: DT_FLOAT }
    }
  )pb");
  ExpectSuccess(node_def, op);

  EXPECT_EQ("{{node n}} = In[T=DT_FLOAT](a)", SummarizeNodeDef(node_def));

  // Mismatching Op names.
  NodeDef bad = node_def;
  bad.set_op("Wrong");
  ExpectFailure(bad, op, "NodeDef op 'Wrong' does not match Op<name=In;");

  // Missing attr.
  bad = node_def;
  bad.clear_attr();
  ExpectFailure(bad, op, "NodeDef missing attr 'T' from Op<name=In;");

  // Attr has the wrong type (int instead of type).
  bad = node_def;
  bad.clear_attr();
  AddNodeAttr("T", 17, &bad);
  ExpectFailure(
      bad, op,
      "AttrValue had value with type 'int' when 'type' expected\n\t for attr "
      "'T'\n\t; NodeDef: ");

  // Wrong number of inputs: too many.
  bad = node_def;
  bad.add_input("b");
  ExpectFailure(
      bad, op,
      "NodeDef expected inputs 'float' do not match 2 inputs specified;");

  // Wrong number of inputs: none.
  bad = node_def;
  bad.clear_input();
  ExpectFailure(
      bad, op,
      "NodeDef expected inputs 'float' do not match 0 inputs specified;");

  // A control input placed after the data inputs is accepted. Validate the
  // augmented `good` itself; re-validating `node_def` would leave this case
  // untested.
  NodeDef good = node_def;
  good.add_input("^b");
  ExpectSuccess(good, op);

  // A data input after a control input is rejected.
  bad = node_def;
  bad.clear_input();
  bad.add_input("^b");
  bad.add_input("a");
  ExpectFailure(bad, op,
                "Non-control input 'a' after control input "
                "in NodeDef:");

  // Control inputs may not name an output port.
  bad = node_def;
  bad.add_input("^b:0");
  ExpectFailure(bad, op, "Control input '^b:0' must not have ':' in NodeDef:");
}
141 
// Validation of an output-only op whose attr T is restricted to numbertype.
TEST(NodeDefUtilTest, Out) {
  const OpDef out_op =
      ToOpDef(OpDefBuilder("Out").Output("o: T").Attr("T: numbertype"));
  const NodeDef int_node = ToNodeDef(R"pb(
    name: 'n' op: 'Out' attr { key: 'T' value { type: DT_INT32 } }
  )pb");
  ExpectSuccess(int_node, out_op);
  EXPECT_EQ("{{node n}} = Out[T=DT_INT32]()", SummarizeNodeDef(int_node));

  // DT_STRING is not a numeric type, so validation must reject it and list
  // the allowed values.
  NodeDef string_node = int_node;
  string_node.clear_attr();
  AddNodeAttr("T", DT_STRING, &string_node);
  ExpectFailure(string_node, out_op,
                "Value for attr 'T' of string is not in the list of allowed "
                "values: float, double, int32, uint8, int16, int8, complex64, "
                "int64, qint8, quint8, qint32, bfloat16, uint16, complex128, "
                "half, uint32, uint64");
}
167 
// Validation of a string attr restricted to the set {'apple', 'orange'}.
TEST(NodeDefUtilTest, Enum) {
  const OpDef enum_op =
      ToOpDef(OpDefBuilder("Enum").Attr("e: {'apple','orange'}"));
  const NodeDef apple = ToNodeDef(R"pb(
    name: 'n' op: 'Enum' attr { key: 'e' value { s: 'apple' } }
  )pb");
  ExpectSuccess(apple, enum_op);
  EXPECT_EQ("{{node n}} = Enum[e=\"apple\"]()", SummarizeNodeDef(apple));

  // Returns a copy of `apple` whose only attr is e=`value`.
  const auto with_e = [&apple](const char* value) {
    NodeDef copy = apple;
    copy.clear_attr();
    AddNodeAttr("e", value, &copy);
    return copy;
  };

  // The other member of the set is accepted as well.
  ExpectSuccess(with_e("orange"), enum_op);

  // A value outside the set is rejected.
  ExpectFailure(with_e("foo"), enum_op,
                "Value for attr 'e' of \"foo\" is not in the list of allowed "
                "values: \"apple\", \"orange\"");
}
195 
// Validation of an op taking N inputs that all share type T.
TEST(NodeDefUtilTest, SameIn) {
  const OpDef same_in_op = ToOpDef(OpDefBuilder("SameIn")
                                       .Input("i: N * T")
                                       .Attr("N: int >= 2")
                                       .Attr("T: {float,double}"));

  // Builds node 'n' with data inputs a and b and the given N and T attrs.
  const auto make_node = [](int n, DataType t) {
    NodeDef node;
    node.set_name("n");
    node.set_op("SameIn");
    node.add_input("a");
    node.add_input("b");
    AddNodeAttr("N", n, &node);
    AddNodeAttr("T", t, &node);
    return node;
  };

  const NodeDef valid = make_node(2, DT_DOUBLE);
  ExpectSuccess(valid, same_in_op);
  EXPECT_EQ("{{node n}} = SameIn[N=2, T=DT_DOUBLE](a, b)",
            SummarizeNodeDef(valid));

  // T outside the allowed {float, double} set.
  ExpectFailure(make_node(2, DT_STRING), same_in_op,
                "Value for attr 'T' of string is not in the list of allowed "
                "values: float, double");

  // N below its declared minimum of 2, even though two inputs are listed.
  ExpectFailure(make_node(1, DT_FLOAT), same_in_op,
                "Value for attr 'N' of 1 must be at least minimum 2");
}
256 
// Validation of an op whose inputs are typed by a non-empty type list T.
TEST(NodeDefUtilTest, AnyIn) {
  const OpDef any_in_op =
      ToOpDef(OpDefBuilder("AnyIn").Input("i: T").Attr("T: list(type) >= 1"));

  const NodeDef two_inputs = ToNodeDef(R"pb(
    name: 'n'
    op: 'AnyIn'
    input: 'a'
    input: 'b'
    attr {
      key: 'T'
      value { list { type: [ DT_INT32, DT_STRING ] } }
    }
  )pb");
  ExpectSuccess(two_inputs, any_in_op);
  EXPECT_EQ("{{node n}} = AnyIn[T=[DT_INT32, DT_STRING]](a, b)",
            SummarizeNodeDef(two_inputs));

  // T must contain at least one type. The first NodeDef spells out an empty
  // list. The second uses a bare `value {}`, which under proto3 semantics is
  // indistinguishable from an empty list, so both yield the same complaint.
  const char* const kEmptyTypeListProtos[] = {
      R"pb(
        name: 'n'
        op: 'AnyIn'
        input: 'a'
        attr {
          key: 'T'
          value { list {} }
        }
      )pb",
      R"pb(
        name: 'n'
        op: 'AnyIn'
        input: 'a'
        attr {
          key: 'T'
          value {}
        }
      )pb",
  };
  for (const char* text : kEmptyTypeListProtos) {
    ExpectFailure(ToNodeDef(text), any_in_op,
                  "Length for attr 'T' of 0 must be at least minimum 1");
  }
}
302 
// The device set via NodeDefBuilder::Device shows up as _device in summaries.
TEST(NodeDefUtilTest, Device) {
  // Op without attrs: _device is the only bracketed entry.
  const OpDef no_attr_op = ToOpDef(OpDefBuilder("None"));
  NodeDefBuilder no_attr_builder("d", &no_attr_op);
  no_attr_builder.Device("/cpu:17");
  const NodeDef no_attr_node = ToNodeDef(std::move(no_attr_builder));
  ExpectSuccess(no_attr_node, no_attr_op);
  EXPECT_EQ("{{node d}} = None[_device=\"/cpu:17\"]()",
            SummarizeNodeDef(no_attr_node));

  // Op with a regular attr: _device is listed after it.
  const OpDef int_attr_op = ToOpDef(OpDefBuilder("WithAttr").Attr("v: int"));
  NodeDefBuilder int_attr_builder("d", &int_attr_op);
  int_attr_builder.Attr("v", 7).Device("/cpu:5");
  const NodeDef int_attr_node = ToNodeDef(std::move(int_attr_builder));
  ExpectSuccess(int_attr_node, int_attr_op);
  EXPECT_EQ("{{node d}} = WithAttr[v=7, _device=\"/cpu:5\"]()",
            SummarizeNodeDef(int_attr_node));
}
318 
ExpectValidSyntax(const NodeDef & good)319 void ExpectValidSyntax(const NodeDef& good) {
320   EXPECT_EQ(OkStatus(), ValidateExternalNodeDefSyntax(good))
321       << "NodeDef: " << SummarizeNodeDef(good);
322 }
323 
ExpectInvalidSyntax(const NodeDef & bad,const string & message)324 void ExpectInvalidSyntax(const NodeDef& bad, const string& message) {
325   Status status = ValidateExternalNodeDefSyntax(bad);
326 
327   ASSERT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad);
328 
329   EXPECT_TRUE(errors::IsInvalidArgument(status))
330       << status << "; NodeDef: " << SummarizeNodeDef(bad);
331 
332   EXPECT_TRUE(absl::StrContains(StringPiece(status.ToString()), message))
333       << "NodeDef: " << SummarizeNodeDef(bad) << ", " << status << ", "
334       << message;
335 }
336 
TEST(NodeDefUtilTest,ValidSyntax)337 TEST(NodeDefUtilTest, ValidSyntax) {
338   const NodeDef node_def = ToNodeDef(R"pb(
339     name: 'n'
340     op: 'AnyIn'
341     input: 'a'
342     input: 'b'
343     attr {
344       key: 'T'
345       value { list { type: [ DT_INT32, DT_STRING ] } }
346     }
347   )pb");
348   ExpectValidSyntax(node_def);
349 
350   const NodeDef node_def_namespace = ToNodeDef(R"pb(
351     name: 'n'
352     op: 'Project>AnyIn'
353     input: 'a'
354     input: 'b'
355     attr {
356       key: 'T'
357       value { list { type: [ DT_INT32, DT_STRING ] } }
358     }
359   )pb");
360   ExpectValidSyntax(node_def_namespace);
361 
362   const NodeDef node_def_explicit_inputs = ToNodeDef(R"pb(
363     name: 'n'
364     op: 'AnyIn'
365     input: 'a:0'
366     input: 'b:123'
367     attr {
368       key: 'T'
369       value { list { type: [ DT_INT32, DT_STRING ] } }
370     }
371   )pb");
372   ExpectValidSyntax(node_def_explicit_inputs);
373 
374   EXPECT_EQ("{{node n}} = AnyIn[T=[DT_INT32, DT_STRING]](a:0, b:123)",
375             SummarizeNodeDef(node_def_explicit_inputs));
376 
377   const NodeDef node_def_explicit_inputs_namespace = ToNodeDef(R"pb(
378     name: 'Project>n'
379     op: 'Project>AnyIn'
380     input: 'Project>a:0'
381     input: 'Project>b:123'
382     input: '^Project>c'
383     attr {
384       key: 'T'
385       value { list { type: [ DT_INT32, DT_STRING ] } }
386     }
387   )pb");
388   ExpectValidSyntax(node_def_explicit_inputs_namespace);
389 
390   EXPECT_EQ(
391       "{{node Project>n}} = Project>AnyIn[T=[DT_INT32, DT_STRING]]"
392       "(Project>a:0, Project>b:123, ^Project>c)",
393       SummarizeNodeDef(node_def_explicit_inputs_namespace));
394 
395   const NodeDef node_def_partial_shape = ToNodeDef(R"pb(
396     name: 'n'
397     op: 'AnyIn'
398     attr {
399       key: 'shp'
400       value {
401         shape {
402           dim { size: -1 }
403           dim { size: 0 }
404         }
405       }
406     }
407   )pb");
408   ExpectValidSyntax(node_def_partial_shape);
409 
410   const NodeDef node_def_control_input = ToNodeDef(R"pb(
411     name: 'n-'
412     op: 'AnyIn'
413     input: 'a'
414     input: '^b'
415     attr {
416       key: 'T'
417       value { list { type: [ DT_INT32, DT_STRING ] } }
418     }
419   )pb");
420   ExpectValidSyntax(node_def_control_input);
421 
422   const NodeDef node_def_invalid_name = ToNodeDef(R"pb(
423     name: 'n:0'
424     op: 'AnyIn'
425     input: 'a'
426     input: 'b'
427     attr {
428       key: 'T'
429       value { list { type: [ DT_INT32, DT_STRING ] } }
430     }
431   )pb");
432   ExpectInvalidSyntax(node_def_invalid_name, "Illegal op name 'n:0'");
433 
434   const NodeDef node_def_internal_name = ToNodeDef(R"pb(
435     name: '_n'
436     op: 'AnyIn'
437     input: 'a'
438     input: 'b'
439     attr {
440       key: 'T'
441       value { list { type: [ DT_INT32, DT_STRING ] } }
442     }
443   )pb");
444   ExpectInvalidSyntax(node_def_internal_name, "Illegal op name '_n'");
445 
446   const NodeDef node_def_slash_in_name = ToNodeDef(R"pb(
447     name: 'n\\'
448     op: 'AnyIn'
449     input: 'a'
450     input: 'b'
451     attr {
452       key: 'T'
453       value { list { type: [ DT_INT32, DT_STRING ] } }
454     }
455   )pb");
456   ExpectInvalidSyntax(node_def_slash_in_name, "Illegal op name 'n\\'");
457 
458   const NodeDef node_def_internal_input_name = ToNodeDef(R"pb(
459     name: 'n'
460     op: 'AnyIn'
461     input: '_a'
462     input: 'b'
463     attr {
464       key: 'T'
465       value { list { type: [ DT_INT32, DT_STRING ] } }
466     }
467   )pb");
468   ExpectInvalidSyntax(node_def_internal_input_name,
469                       "Illegal op input name '_a'");
470 
471   const NodeDef node_def_input_name_slash = ToNodeDef(R"pb(
472     name: 'n'
473     op: 'AnyIn'
474     input: 'a\\'
475     input: 'b'
476     attr {
477       key: 'T'
478       value { list { type: [ DT_INT32, DT_STRING ] } }
479     }
480   )pb");
481   ExpectInvalidSyntax(node_def_input_name_slash, "Illegal op input name 'a\\'");
482 
483   const NodeDef node_def_invalid_control_input_name = ToNodeDef(R"pb(
484     name: 'n'
485     op: 'AnyIn'
486     input: 'a'
487     input: '^b:0'
488     attr {
489       key: 'T'
490       value { list { type: [ DT_INT32, DT_STRING ] } }
491     }
492   )pb");
493   ExpectInvalidSyntax(node_def_invalid_control_input_name,
494                       "Illegal op input name '^b:0'");
495 
496   const NodeDef node_def_control_input_name_slash = ToNodeDef(R"pb(
497     name: 'n'
498     op: 'AnyIn'
499     input: 'a'
500     input: '^b\\'
501     attr {
502       key: 'T'
503       value { list { type: [ DT_INT32, DT_STRING ] } }
504     }
505   )pb");
506   ExpectInvalidSyntax(node_def_control_input_name_slash,
507                       "Illegal op input name '^b\\'");
508 
509   const NodeDef node_def_data_input_after_control = ToNodeDef(R"pb(
510     name: 'n'
511     op: 'AnyIn'
512     input: '^a'
513     input: 'b'
514     attr {
515       key: 'T'
516       value { list { type: [ DT_INT32, DT_STRING ] } }
517     }
518   )pb");
519   ExpectInvalidSyntax(node_def_data_input_after_control,
520                       "All control inputs must follow all data inputs");
521 
522   const NodeDef node_def_data_input_invalid_port = ToNodeDef(R"pb(
523     name: 'n'
524     op: 'AnyIn'
525     input: 'a:b'
526     input: 'b'
527     attr {
528       key: 'T'
529       value { list { type: [ DT_INT32, DT_STRING ] } }
530     }
531   )pb");
532   ExpectInvalidSyntax(node_def_data_input_invalid_port,
533                       "Illegal op input name 'a:b");
534 
535   const NodeDef node_def_data_input_invalid_port2 = ToNodeDef(R"pb(
536     name: 'n'
537     op: 'AnyIn'
538     input: 'a:00'
539     input: 'b'
540     attr {
541       key: 'T'
542       value { list { type: [ DT_INT32, DT_STRING ] } }
543     }
544   )pb");
545   ExpectInvalidSyntax(node_def_data_input_invalid_port2,
546                       "Illegal op input name 'a:00");
547 }
548 
// InputTypesForNode / InputTypeForNode on an op with fixed input types.
TEST(InputTypesForNode, Simple) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
                                   .Input("a: float")
                                   .Input("b: int32")
                                   .Output("c: string")
                                   .Output("d: bool"));
  const NodeDef node_def = ToNodeDef(std::move(
      NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())));

  DataTypeVector types;
  // Stop on failure: the indexing below would read past the end of `types`
  // if the call failed or returned too few entries.
  TF_ASSERT_OK(InputTypesForNode(node_def, op_def, &types));
  ASSERT_EQ(types.size(), size_t{2});
  EXPECT_EQ(types[0], DT_FLOAT);
  EXPECT_EQ(types[1], DT_INT32);

  // Initialized so a failed lookup never leaves `type` indeterminate.
  DataType type = DT_INVALID;
  TF_EXPECT_OK(InputTypeForNode(node_def, op_def, 0, &type));
  EXPECT_EQ(type, DT_FLOAT);
  TF_EXPECT_OK(InputTypeForNode(node_def, op_def, 1, &type));
  EXPECT_EQ(type, DT_INT32);
  // Index 2 is out of range for an op with two inputs.
  EXPECT_FALSE(InputTypeForNode(node_def, op_def, 2, &type).ok());
}
569 
// OutputTypesForNode / OutputTypeForNode on an op with fixed output types.
TEST(OutputTypesForNode, Simple) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
                                   .Input("a: float")
                                   .Input("b: int32")
                                   .Output("c: string")
                                   .Output("d: bool"));
  const NodeDef node_def = ToNodeDef(std::move(
      NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())));

  DataTypeVector types;
  // Stop on failure: the indexing below would read past the end of `types`
  // if the call failed or returned too few entries.
  TF_ASSERT_OK(OutputTypesForNode(node_def, op_def, &types));
  ASSERT_EQ(types.size(), size_t{2});
  EXPECT_EQ(types[0], DT_STRING);
  EXPECT_EQ(types[1], DT_BOOL);

  // Initialized so a failed lookup never leaves `type` indeterminate.
  DataType type = DT_INVALID;
  TF_EXPECT_OK(OutputTypeForNode(node_def, op_def, 0, &type));
  EXPECT_EQ(type, DT_STRING);
  TF_EXPECT_OK(OutputTypeForNode(node_def, op_def, 1, &type));
  EXPECT_EQ(type, DT_BOOL);
  // Index 2 is out of range for an op with two outputs.
  EXPECT_FALSE(OutputTypeForNode(node_def, op_def, 2, &type).ok());
}
590 
// An absurdly large num_split must make OutputTypesForNode report an error
// (presumably rather than trying to build a trillion-entry type list).
TEST(OutputTypesForNode, LargeOutput) {
  const OpDef split_op = ToOpDef(OpDefBuilder("TestSplitOp")
                                     .Input("value: int64")
                                     .Output("output: num_split * int64")
                                     .Attr("num_split: int >= 1"));
  const int64_t kHugeNumSplit = 1000000000000;
  NodeDefBuilder split_builder("test_split_op", &split_op);
  split_builder.Input(FakeInput()).Attr("num_split", kHugeNumSplit);
  const NodeDef split_node = ToNodeDef(std::move(split_builder));

  DataTypeVector output_types;
  EXPECT_FALSE(OutputTypesForNode(split_node, split_op, &output_types).ok());
}
604 
// The AttrSlice overload of OutputTypesForNode yields the declared outputs.
TEST(OutputTypesForNode_AttrSliceOverload, Simple) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
                                   .Input("a: float")
                                   .Input("b: int32")
                                   .Output("c: string")
                                   .Output("d: bool"));
  // AttrSlice is a non-owning view, so the NodeDef it is built from must
  // outlive it. Keep the NodeDef in a named local; building the slice from a
  // temporary would leave `attr_slice` referring to a destroyed object.
  const NodeDef node_def = ToNodeDef(std::move(
      NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())));
  const AttrSlice attr_slice(node_def);

  DataTypeVector types;
  // Stop on failure so the indexing below never reads past the end.
  TF_ASSERT_OK(OutputTypesForNode(attr_slice, op_def, &types));
  ASSERT_EQ(types.size(), size_t{2});
  EXPECT_EQ(types[0], DT_STRING);
  EXPECT_EQ(types[1], DT_BOOL);
}
620 
// With fixed-type, single-tensor args, every input and output range spans one
// slot.
TEST(NameRangesForNodeTest, Simple) {
  const OpDef simple_op = ToOpDef(OpDefBuilder("Simple")
                                      .Input("a: float")
                                      .Input("b: int32")
                                      .Output("c: string")
                                      .Output("d: bool"));
  NameRangeMap input_ranges;
  NameRangeMap output_ranges;
  const NodeDef simple_node = ToNodeDef(std::move(NodeDefBuilder("simple", &simple_op)
                                                      .Input(FakeInput())
                                                      .Input(FakeInput())));
  TF_EXPECT_OK(
      NameRangesForNode(simple_node, simple_op, &input_ranges, &output_ranges));
  EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), input_ranges);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 2}}}), output_ranges);

  EXPECT_EQ("{{node simple}} = Simple[](a, b)", SummarizeNodeDef(simple_node));

  // An input arg with no type information cannot be resolved.
  OpDef untyped_op = simple_op;
  untyped_op.mutable_input_arg(0)->clear_type();
  EXPECT_FALSE(
      NameRangesForNode(simple_node, untyped_op, &input_ranges, &output_ranges)
          .ok());
}
640 
// Name ranges for a type-polymorphic op do not depend on the chosen T; only
// the node summary does.
TEST(NameRangesForNodeTest, Polymorphic) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Polymorphic")
                                   .Input("a: T")
                                   .Input("b: T")
                                   .Output("c: T")
                                   .Attr("T: type"));
  // Reused across cases, exactly as successive calls would reuse them.
  NameRangeMap inputs, outputs;

  const std::vector<std::pair<DataType, string>> cases = {
      {DT_INT32, "{{node poly}} = Polymorphic[T=DT_INT32](a, b)"},
      {DT_BOOL, "{{node poly}} = Polymorphic[T=DT_BOOL](a, b)"},
  };
  for (const auto& test_case : cases) {
    const DataType dtype = test_case.first;
    const NodeDef node = ToNodeDef(std::move(NodeDefBuilder("poly", &op_def)
                                                 .Input(FakeInput(dtype))
                                                 .Input(FakeInput(dtype))));
    TF_EXPECT_OK(NameRangesForNode(node, op_def, &inputs, &outputs));
    EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs);
    EXPECT_EQ(NameRangeMap({{"c", {0, 1}}}), outputs);
    EXPECT_EQ(test_case.second, SummarizeNodeDef(node));
  }
}
668 
// Ranges for args repeated N or M times scale with those attrs.
TEST(NameRangesForNodeTest, NRepeats) {
  const OpDef op_def = ToOpDef(OpDefBuilder("NRepeats")
                                   .Input("a: N * int32")
                                   .Input("b: N * T")
                                   .Output("c: T")
                                   .Output("d: N * string")
                                   .Output("e: M * bool")
                                   .Attr("N: int")
                                   .Attr("M: int")
                                   .Attr("T: type"));
  NameRangeMap input_ranges;
  NameRangeMap output_ranges;

  // N=4 (inferred from the inputs), M=3, T=float.
  const NodeDef n4_m3 = ToNodeDef(std::move(NodeDefBuilder("nr", &op_def)
                                                .Input(FakeInput(4, DT_INT32))
                                                .Input(FakeInput(4, DT_FLOAT))
                                                .Attr("M", 3)));
  TF_EXPECT_OK(NameRangesForNode(n4_m3, op_def, &input_ranges, &output_ranges));
  EXPECT_EQ(NameRangeMap({{"a", {0, 4}}, {"b", {4, 8}}}), input_ranges);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 5}}, {"e", {5, 8}}}),
            output_ranges);
  EXPECT_EQ(
      "{{node nr}} = NRepeats[M=3, N=4, T=DT_FLOAT](a, a:1, a:2, a:3, b, b:1, "
      "b:2, b:3)",
      SummarizeNodeDef(n4_m3));

  // N=2, M=7, T=double.
  const NodeDef n2_m7 = ToNodeDef(std::move(NodeDefBuilder("nr", &op_def)
                                                .Input(FakeInput(2, DT_INT32))
                                                .Input(FakeInput(2, DT_DOUBLE))
                                                .Attr("M", 7)));
  TF_EXPECT_OK(NameRangesForNode(n2_m7, op_def, &input_ranges, &output_ranges));
  EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 4}}}), input_ranges);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}),
            output_ranges);
  EXPECT_EQ("{{node nr}} = NRepeats[M=7, N=2, T=DT_DOUBLE](a, a:1, b, b:1)",
            SummarizeNodeDef(n2_m7));

  // Without N, M and T the ranges cannot be computed.
  NodeDef attrless = n2_m7;
  attrless.clear_attr();
  EXPECT_FALSE(
      NameRangesForNode(attrless, op_def, &input_ranges, &output_ranges).ok());
}
710 
// Ranges for args typed by list(type) attrs follow the lengths of those lists.
TEST(NameRangesForNodeTest, TypeList) {
  const OpDef op_def = ToOpDef(OpDefBuilder("TypeList")
                                   .Input("a: T1")
                                   .Input("b: T2")
                                   .Output("c: T2")
                                   .Output("d: T3")
                                   .Output("e: T1")
                                   .Attr("T1: list(type)")
                                   .Attr("T2: list(type)")
                                   .Attr("T3: list(type)"));
  NameRangeMap input_ranges;
  NameRangeMap output_ranges;

  // |T1| = 2, |T2| = 4, |T3| = 3.
  const NodeDef first = ToNodeDef(std::move(
      NodeDefBuilder("tl", &op_def)
          .Input(FakeInput({DT_BOOL, DT_FLOAT}))
          .Input(FakeInput(4, DT_FLOAT))
          .Attr("T3", {DT_INT32, DT_DOUBLE, DT_STRING})));
  TF_EXPECT_OK(NameRangesForNode(first, op_def, &input_ranges, &output_ranges));
  EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 6}}}), input_ranges);
  EXPECT_EQ(NameRangeMap({{"c", {0, 4}}, {"d", {4, 7}}, {"e", {7, 9}}}),
            output_ranges);
  EXPECT_EQ(
      "{{node tl}} = TypeList[T1=[DT_BOOL, DT_FLOAT],"
      " T2=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT],"
      " T3=[DT_INT32, DT_DOUBLE, DT_STRING]](a, a:1, b, b:1, b:2, b:3)",
      SummarizeNodeDef(first));

  // |T1| = 7, |T2| = 1, |T3| = 2.
  const NodeDef second =
      ToNodeDef(std::move(NodeDefBuilder("tl", &op_def)
                              .Input(FakeInput(7, DT_INT32))
                              .Input(FakeInput({DT_DOUBLE}))
                              .Attr("T3", {DT_DOUBLE, DT_STRING})));
  TF_EXPECT_OK(NameRangesForNode(second, op_def, &input_ranges, &output_ranges));
  EXPECT_EQ(NameRangeMap({{"a", {0, 7}}, {"b", {7, 8}}}), input_ranges);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}),
            output_ranges);
  EXPECT_EQ(
      "{{node tl}} = TypeList[T1=[DT_INT32, DT_INT32, DT_INT32, DT_INT32, "
      "DT_INT32, DT_INT32, DT_INT32], T2=[DT_DOUBLE], "
      "T3=[DT_DOUBLE, DT_STRING]](a, a:1, a:2, a:3, a:4, a:5, a:6, b)",
      SummarizeNodeDef(second));

  // Without the list attrs the ranges cannot be computed.
  NodeDef attrless = second;
  attrless.clear_attr();
  EXPECT_FALSE(
      NameRangesForNode(attrless, op_def, &input_ranges, &output_ranges).ok());
}
757 
// For an Enter node, both the node name and its frame_name attr are wrapped
// with the prefix and suffix.
TEST(AddPrefixAndSuffixToNode, Enter) {
  NodeDef enter;
  enter.set_name("enter");
  enter.set_op("Enter");
  AddNodeAttr("frame_name", "test_frame", &enter);

  const string prefix = "prefix/";
  const string suffix = "/suffix";
  TF_ASSERT_OK(AddPrefixAndSuffixToNode(prefix, suffix, &enter));
  EXPECT_EQ("prefix/enter/suffix", enter.name());

  string rewritten_frame;
  TF_ASSERT_OK(GetNodeAttr(enter, "frame_name", &rewritten_frame));
  EXPECT_EQ("prefix/test_frame/suffix", rewritten_frame);
}
771 
// Only colocation constraints naming a node in the match set get the prefix.
TEST(MaybeAddPrefixToColocationConstraints, Basic) {
  NodeDef identity;
  identity.set_name("Identity");
  identity.set_op("Identity");
  AddNodeAttr(kColocationAttrName,
              {strings::StrCat(kColocationGroupPrefix, "Node1"),
               strings::StrCat(kColocationGroupPrefix, "Node2"),
               strings::StrCat(kColocationGroupPrefix, "Node3")},
              &identity);

  std::unordered_set<string> to_prefix = {"Node1", "Node3"};
  TF_ASSERT_OK(
      MaybeAddPrefixToColocationConstraints(to_prefix, "fn/", &identity));

  std::vector<string> constraints;
  TF_ASSERT_OK(GetNodeAttr(identity, kColocationAttrName, &constraints));
  const std::vector<string> expected = {"loc:@fn/Node1", "loc:@Node2",
                                        "loc:@fn/Node3"};
  EXPECT_EQ(constraints, expected);
}
792 
// A node with no colocation attr succeeds and does not gain one.
TEST(MaybeAddPrefixToColocationConstraints, NoConstraints) {
  NodeDef identity;
  identity.set_name("Identity");
  identity.set_op("Identity");

  std::unordered_set<string> to_prefix = {"Node1", "Node3"};
  TF_ASSERT_OK(
      MaybeAddPrefixToColocationConstraints(to_prefix, "fn/", &identity));
  EXPECT_FALSE(HasNodeAttr(identity, kColocationAttrName));
}
804 
// Constraints whose node appears in the map are renamed. Other constraints
// and unused map entries have no effect.
TEST(MaybeUpdateColocationConstraintsWithMap, Basic) {
  NodeDef identity;
  identity.set_name("Identity");
  identity.set_op("Identity");
  AddNodeAttr(kColocationAttrName,
              {strings::StrCat(kColocationGroupPrefix, "Node1"),
               strings::StrCat(kColocationGroupPrefix, "Node2"),
               strings::StrCat(kColocationGroupPrefix, "Node3")},
              &identity);

  std::map<absl::string_view, absl::string_view> renames = {
      {"Node1", "Node4"}, {"Invalid", "Node5"}};
  TF_ASSERT_OK(MaybeUpdateColocationConstraintsWithMap(renames, &identity));

  std::vector<string> constraints;
  TF_ASSERT_OK(GetNodeAttr(identity, kColocationAttrName, &constraints));
  const std::vector<string> expected = {"loc:@Node4", "loc:@Node2",
                                        "loc:@Node3"};
  EXPECT_EQ(constraints, expected);
}
824 
// A node with no colocation attr succeeds and does not gain one.
TEST(MaybeUpdateColocationConstraintsWithMap, NoConstraints) {
  NodeDef identity;
  identity.set_name("Identity");
  identity.set_op("Identity");

  std::map<absl::string_view, absl::string_view> renames = {
      {"Node1", "Node4"}, {"Invalid", "Node5"}};
  TF_ASSERT_OK(MaybeUpdateColocationConstraintsWithMap(renames, &identity));
  EXPECT_FALSE(HasNodeAttr(identity, kColocationAttrName));
}
836 
// FormatNodeForError renders a graph Node as "{{node <name>}}".
TEST(FormatNodeForErrorTest, Node) {
  Graph g(OpRegistry::Global());
  Node* node = nullptr;
  // TF_ASSERT_OK records a test failure and leaves this test. TF_CHECK_OK
  // would instead abort the entire test binary if node construction failed.
  TF_ASSERT_OK(NodeBuilder("enter", "NoOp").Finalize(&g, &node));
  EXPECT_EQ("{{node enter}}", FormatNodeForError(*node));
}
843 
// Without debug info, FormatNodeDefForError shows only the NodeDef's own name.
TEST(FormatNodeForErrorTest, NodeDef) {
  NodeDef enter_node;
  enter_node.set_name("enter");
  enter_node.set_op("Enter");
  AddNodeAttr("frame_name", "test_frame", &enter_node);

  const string formatted = FormatNodeDefForError(enter_node);
  EXPECT_EQ("{{node enter}}", formatted);
}
851 
// When experimental_debug_info records original names, those names are shown
// (as function/node pairs) instead of the NodeDef's own name.
TEST(FormatNodeForErrorTest, NodeDefWithOriginalNames) {
  NodeDef node_def;
  node_def.set_name("enter");
  node_def.set_op("Enter");
  AddNodeAttr("frame_name", "test_frame", &node_def);
  auto* debug_info = node_def.mutable_experimental_debug_info();

  // One recorded origin.
  debug_info->add_original_node_names("node_name");
  debug_info->add_original_func_names("func_name");
  EXPECT_EQ("{{function_node func_name}}{{node node_name}}",
            FormatNodeDefForError(node_def));

  // A second origin is appended after ", ".
  debug_info->add_original_node_names("node_name2");
  debug_info->add_original_func_names("func_name2");
  EXPECT_EQ(
      "{{function_node func_name}}{{node node_name}}, "
      "{{function_node func_name2}}{{node node_name2}}",
      FormatNodeDefForError(node_def));
}
872 
// With the flag set to true, each attached node is rendered as {{node ...}},
// even when the message already contains one.
TEST(AttachDef, AllowMultipleFormattedNode) {
  NodeDef first;
  first.set_name("a");
  NodeDef second;
  second.set_name("b");
  const Status base_status(error::CANCELLED, "Error");

  const Status with_first = AttachDef(base_status, first, true);
  EXPECT_EQ("Error\n\t [[{{node a}}]]", with_first.error_message());

  const Status with_both = AttachDef(with_first, second, true);
  EXPECT_EQ("Error\n\t [[{{node a}}]]\n\t [[{{node b}}]]",
            with_both.error_message());
}
884 
// With the flag set to false, only the first attached node uses the
// {{node ...}} form. Later ones appear as a bare name.
TEST(AttachDef, DisallowMultipleFormattedNode) {
  NodeDef first;
  first.set_name("a");
  NodeDef second;
  second.set_name("b");
  const Status base_status(error::CANCELLED, "Error");

  const Status with_first = AttachDef(base_status, first, false);
  EXPECT_EQ("Error\n\t [[{{node a}}]]", with_first.error_message());

  const Status with_both = AttachDef(with_first, second, false);
  EXPECT_EQ("Error\n\t [[{{node a}}]]\n\t [[b]]", with_both.error_message());
}
896 
897 }  // namespace
898 }  // namespace tensorflow
899