• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/node_def_util.h"
17 
18 #include "tensorflow/core/framework/attr_value.pb.h"  // NOLINT
19 #include "tensorflow/core/framework/fake_input.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/framework/op_def_builder.h"
22 #include "tensorflow/core/framework/op_def_util.h"
23 #include "tensorflow/core/graph/graph.h"
24 #include "tensorflow/core/graph/node_builder.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/platform/protobuf.h"
29 #include "tensorflow/core/platform/test.h"
30 
31 namespace tensorflow {
32 namespace {
33 
ToOpDef(const OpDefBuilder & builder)34 OpDef ToOpDef(const OpDefBuilder& builder) {
35   OpRegistrationData op_reg_data;
36   TF_EXPECT_OK(builder.Finalize(&op_reg_data));
37   return op_reg_data.op_def;
38 }
39 
ToNodeDef(const string & text)40 NodeDef ToNodeDef(const string& text) {
41   NodeDef node_def;
42   EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
43   return node_def;
44 }
45 
ToNodeDef(NodeDefBuilder && builder)46 NodeDef ToNodeDef(NodeDefBuilder&& builder) {
47   NodeDef node_def;
48   TF_EXPECT_OK(builder.Finalize(&node_def));
49   return node_def;
50 }
51 
ExpectSuccess(const NodeDef & good,const OpDef & op_def)52 void ExpectSuccess(const NodeDef& good, const OpDef& op_def) {
53   EXPECT_EQ(Status::OK(), ValidateNodeDef(good, op_def))
54       << "NodeDef: " << SummarizeNodeDef(good)
55       << "; OpDef: " << SummarizeOpDef(op_def);
56 }
57 
ExpectFailure(const NodeDef & bad,const OpDef & op_def,const string & message)58 void ExpectFailure(const NodeDef& bad, const OpDef& op_def,
59                    const string& message) {
60   Status status = ValidateNodeDef(bad, op_def);
61 
62   EXPECT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad)
63                             << "; OpDef: " << SummarizeOpDef(op_def);
64   if (status.ok()) return;
65 
66   EXPECT_TRUE(errors::IsInvalidArgument(status))
67       << status << "; NodeDef: " << SummarizeNodeDef(bad)
68       << "; OpDef: " << SummarizeOpDef(op_def);
69 
70   LOG(INFO) << "Message: " << status.error_message();
71   EXPECT_TRUE(absl::StrContains(status.ToString(), message))
72       << "NodeDef: " << SummarizeNodeDef(bad)
73       << "; OpDef: " << SummarizeOpDef(op_def) << "\nActual error: " << status
74       << "\nDoes not contain: " << message;
75 }
76 
// Validates NodeDefs against a single-input op "In" whose input type comes from
// attr `T`. Covers op-name mismatch, missing/extra/mistyped attrs, wrong input
// counts, and the placement/format rules for control inputs.
TEST(NodeDefUtilTest, In) {
  const OpDef op = ToOpDef(OpDefBuilder("In").Input("i: T").Attr("T: type"));
  const NodeDef node_def = ToNodeDef(R"proto(
    name:'n' op:'In' input:'a' attr { key:'T' value { type:DT_FLOAT } }
    )proto");
  ExpectSuccess(node_def, op);

  EXPECT_EQ("{{node n}} = In[T=DT_FLOAT](a)", SummarizeNodeDef(node_def));

  // Mismatching Op names.
  NodeDef bad = node_def;
  bad.set_op("Wrong");
  ExpectFailure(bad, op, "NodeDef op 'Wrong' does not match Op<name=In;");

  // Missing attr
  bad = node_def;
  bad.clear_attr();
  ExpectFailure(bad, op, "NodeDef missing attr 'T' from Op<name=In;");

  // Extra attr
  bad = node_def;
  AddNodeAttr("EXTRA", 17, &bad);
  ExpectFailure(bad, op, "NodeDef mentions attr 'EXTRA' not in Op<name=In;");

  // Attr has wrong type
  bad = node_def;
  bad.clear_attr();
  AddNodeAttr("T", 17, &bad);
  ExpectFailure(
      bad, op,
      "AttrValue had value with type 'int' when 'type' expected\n\t for attr "
      "'T'\n\t; NodeDef: ");

  // Wrong number of inputs
  bad = node_def;
  bad.add_input("b");
  ExpectFailure(
      bad, op,
      "NodeDef expected inputs 'float' do not match 2 inputs specified;");

  bad = node_def;
  bad.clear_input();
  ExpectFailure(
      bad, op,
      "NodeDef expected inputs 'float' do not match 0 inputs specified;");

  // Control inputs must appear after data inputs. A control input that
  // follows the data input is valid, so validate the modified copy `good`
  // (not the unmodified `node_def`, which would leave this case untested).
  NodeDef good = node_def;
  good.add_input("^b");
  ExpectSuccess(good, op);

  bad = node_def;
  bad.clear_input();
  bad.add_input("^b");
  bad.add_input("a");
  ExpectFailure(bad, op,
                "Invalid argument: Non-control input 'a' after control input "
                "in NodeDef:");

  // A control input must not carry an output port.
  bad = node_def;
  bad.add_input("^b:0");
  ExpectFailure(bad, op, "Control input '^b:0' must not have ':' in NodeDef:");
}
140 
// An op with one output typed by a `numbertype` attr: DT_INT32 is accepted,
// DT_STRING is rejected with the list of allowed numeric types.
TEST(NodeDefUtilTest, Out) {
  const OpDef op =
      ToOpDef(OpDefBuilder("Out").Output("o: T").Attr("T: numbertype"));
  const NodeDef node_def = ToNodeDef(R"proto(
    name:'n' op:'Out' attr { key:'T' value { type:DT_INT32 } }
    )proto");
  ExpectSuccess(node_def, op);
  EXPECT_EQ("{{node n}} = Out[T=DT_INT32]()", SummarizeNodeDef(node_def));

  // Replace T with a non-numeric type.
  NodeDef string_typed = node_def;
  string_typed.clear_attr();
  AddNodeAttr("T", DT_STRING, &string_typed);
  const string expected_error =
      "Value for attr 'T' of string is not in the list of allowed "
      "values: float, double, int32, uint8, int16, int8, complex64, "
      "int64, qint8, quint8, qint32, bfloat16, uint16, complex128, "
      "half, uint32, uint64";
  ExpectFailure(string_typed, op, expected_error);
}
161 
// A string attr restricted to {'apple','orange'}: both listed values validate,
// anything else is rejected with the allowed list.
TEST(NodeDefUtilTest, Enum) {
  const OpDef op = ToOpDef(OpDefBuilder("Enum").Attr("e: {'apple','orange'}"));
  const NodeDef node_def = ToNodeDef(R"proto(
    name:'n' op:'Enum' attr { key:'e' value { s:'apple' } }
    )proto");
  ExpectSuccess(node_def, op);
  EXPECT_EQ("{{node n}} = Enum[e=\"apple\"]()", SummarizeNodeDef(node_def));

  // Returns a copy of `node_def` whose only attr is e=`value`.
  const auto with_enum_value = [&node_def](const char* value) {
    NodeDef copy = node_def;
    copy.clear_attr();
    AddNodeAttr("e", value, &copy);
    return copy;
  };

  ExpectSuccess(with_enum_value("orange"), op);

  // Non-allowed value.
  ExpectFailure(with_enum_value("foo"), op,
                "Value for attr 'e' of \"foo\" is not in the list of allowed "
                "values: \"apple\", \"orange\"");
}
184 
// An op taking N inputs of one type T (N >= 2, T in {float, double}).
TEST(NodeDefUtilTest, SameIn) {
  const OpDef op = ToOpDef(OpDefBuilder("SameIn")
                               .Input("i: N * T")
                               .Attr("N: int >= 2")
                               .Attr("T: {float,double}"));
  const NodeDef node_def = ToNodeDef(R"proto(
    name:'n' op:'SameIn' input:'a' input:'b'
    attr { key:'N' value { i:2 } } attr { key:'T' value { type:DT_DOUBLE } }
    )proto");
  ExpectSuccess(node_def, op);
  EXPECT_EQ("{{node n}} = SameIn[N=2, T=DT_DOUBLE](a, b)",
            SummarizeNodeDef(node_def));

  // T outside the allowed set.
  const NodeDef string_typed = ToNodeDef(R"proto(
    name:'n' op:'SameIn' input:'a' input:'b'
    attr { key:'N' value { i:2 } } attr { key:'T' value { type:DT_STRING } }
    )proto");
  ExpectFailure(string_typed, op,
                "Value for attr 'T' of string is not in the list of allowed "
                "values: float, double");

  // N below its declared minimum of 2.
  const NodeDef n_too_small = ToNodeDef(R"proto(
    name:'n' op:'SameIn' input:'a' input:'b'
    attr { key:'N' value { i:1 } } attr { key:'T' value { type:DT_FLOAT } }
    )proto");
  ExpectFailure(n_too_small, op,
                "Value for attr 'N' of 1 must be at least minimum 2");
}
215 
// An op whose inputs are typed by a non-empty list(type) attr.
TEST(NodeDefUtilTest, AnyIn) {
  const OpDef op =
      ToOpDef(OpDefBuilder("AnyIn").Input("i: T").Attr("T: list(type) >= 1"));

  const NodeDef node_def = ToNodeDef(R"proto(
    name:'n' op:'AnyIn' input:'a' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
    )proto");
  ExpectSuccess(node_def, op);
  EXPECT_EQ("{{node n}} = AnyIn[T=[DT_INT32, DT_STRING]](a, b)",
            SummarizeNodeDef(node_def));

  const string expected_error =
      "Length for attr 'T' of 0 must be at least minimum 1";

  // An explicit empty list violates the ">= 1" length constraint.
  const NodeDef empty_list = ToNodeDef(R"proto(
    name:'n' op:'AnyIn' input:'a' attr { key:'T' value { list { } } }
    )proto");
  ExpectFailure(empty_list, op, expected_error);

  // Under proto3 semantics an empty `value {}` cannot be told apart from a
  // value holding an empty list, so it yields the same empty-list error.
  const NodeDef empty_value = ToNodeDef(R"proto(
    name:'n' op:'AnyIn' input:'a' attr { key:'T' value { } }
    )proto");
  ExpectFailure(empty_value, op, expected_error);
}
243 
// The requested device is summarized as a `_device` attr, listed after any
// regular attrs.
TEST(NodeDefUtilTest, Device) {
  // Op without attrs: `_device` is the only bracketed entry.
  const OpDef none_op = ToOpDef(OpDefBuilder("None"));
  const NodeDef on_cpu17 =
      ToNodeDef(std::move(NodeDefBuilder("d", &none_op).Device("/cpu:17")));
  ExpectSuccess(on_cpu17, none_op);
  EXPECT_EQ("{{node d}} = None[_device=\"/cpu:17\"]()",
            SummarizeNodeDef(on_cpu17));

  // Op with an int attr: `_device` follows `v`.
  const OpDef with_attr_op = ToOpDef(OpDefBuilder("WithAttr").Attr("v: int"));
  const NodeDef on_cpu5 = ToNodeDef(std::move(
      NodeDefBuilder("d", &with_attr_op).Attr("v", 7).Device("/cpu:5")));
  ExpectSuccess(on_cpu5, with_attr_op);
  EXPECT_EQ("{{node d}} = WithAttr[v=7, _device=\"/cpu:5\"]()",
            SummarizeNodeDef(on_cpu5));
}
259 
ExpectValidSyntax(const NodeDef & good)260 void ExpectValidSyntax(const NodeDef& good) {
261   EXPECT_EQ(Status::OK(), ValidateExternalNodeDefSyntax(good))
262       << "NodeDef: " << SummarizeNodeDef(good);
263 }
264 
ExpectInvalidSyntax(const NodeDef & bad,const string & message)265 void ExpectInvalidSyntax(const NodeDef& bad, const string& message) {
266   Status status = ValidateExternalNodeDefSyntax(bad);
267 
268   ASSERT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad);
269 
270   EXPECT_TRUE(errors::IsInvalidArgument(status))
271       << status << "; NodeDef: " << SummarizeNodeDef(bad);
272 
273   EXPECT_TRUE(absl::StrContains(StringPiece(status.ToString()), message))
274       << "NodeDef: " << SummarizeNodeDef(bad) << ", " << status << ", "
275       << message;
276 }
277 
TEST(NodeDefUtilTest,ValidSyntax)278 TEST(NodeDefUtilTest, ValidSyntax) {
279   const NodeDef node_def = ToNodeDef(R"proto(
280     name:'n' op:'AnyIn' input:'a' input:'b'
281     attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
282     )proto");
283   ExpectValidSyntax(node_def);
284 
285   const NodeDef node_def_namespace = ToNodeDef(R"proto(
286     name: 'n'
287     op: 'Project>AnyIn'
288     input: 'a'
289     input: 'b'
290     attr {
291       key: 'T'
292       value { list { type: [ DT_INT32, DT_STRING ] } }
293     }
294   )proto");
295   ExpectValidSyntax(node_def_namespace);
296 
297   const NodeDef node_def_explicit_inputs = ToNodeDef(R"proto(
298     name: 'n'
299     op: 'AnyIn'
300     input: 'a:0'
301     input: 'b:123'
302     attr {
303       key: 'T'
304       value { list { type: [ DT_INT32, DT_STRING ] } }
305     }
306   )proto");
307   ExpectValidSyntax(node_def_explicit_inputs);
308 
309   EXPECT_EQ("{{node n}} = AnyIn[T=[DT_INT32, DT_STRING]](a:0, b:123)",
310             SummarizeNodeDef(node_def_explicit_inputs));
311 
312   const NodeDef node_def_explicit_inputs_namespace = ToNodeDef(R"proto(
313     name: 'Project>n'
314     op: 'Project>AnyIn'
315     input: 'Project>a:0'
316     input: 'Project>b:123'
317     input: '^Project>c'
318     attr {
319       key: 'T'
320       value { list { type: [ DT_INT32, DT_STRING ] } }
321     }
322   )proto");
323   ExpectValidSyntax(node_def_explicit_inputs_namespace);
324 
325   EXPECT_EQ(
326       "{{node Project>n}} = Project>AnyIn[T=[DT_INT32, DT_STRING]]"
327       "(Project>a:0, Project>b:123, ^Project>c)",
328       SummarizeNodeDef(node_def_explicit_inputs_namespace));
329 
330   const NodeDef node_def_partial_shape = ToNodeDef(R"proto(
331     name:'n' op:'AnyIn'
332     attr { key:'shp' value { shape { dim { size: -1 } dim { size: 0 } } } }
333     )proto");
334   ExpectValidSyntax(node_def_partial_shape);
335 
336   const NodeDef node_def_control_input = ToNodeDef(R"proto(
337     name:'n-' op:'AnyIn' input:'a' input:'^b'
338     attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
339     )proto");
340   ExpectValidSyntax(node_def_control_input);
341 
342   const NodeDef node_def_invalid_name = ToNodeDef(R"proto(
343     name:'n:0' op:'AnyIn' input:'a' input:'b'
344     attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
345     )proto");
346   ExpectInvalidSyntax(node_def_invalid_name, "Illegal op name 'n:0'");
347 
348   const NodeDef node_def_internal_name = ToNodeDef(R"proto(
349     name:'_n' op:'AnyIn' input:'a' input:'b'
350     attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
351     )proto");
352   ExpectInvalidSyntax(node_def_internal_name, "Illegal op name '_n'");
353 
354   const NodeDef node_def_slash_in_name = ToNodeDef(R"proto(
355     name:'n\\' op:'AnyIn' input:'a' input:'b'
356     attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
357     )proto");
358   ExpectInvalidSyntax(node_def_slash_in_name, "Illegal op name 'n\\'");
359 
360   const NodeDef node_def_internal_input_name = ToNodeDef(R"proto(
361     name:'n' op:'AnyIn' input:'_a' input:'b'
362     attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
363     )proto");
364   ExpectInvalidSyntax(node_def_internal_input_name,
365                       "Illegal op input name '_a'");
366 
367   const NodeDef node_def_input_name_slash = ToNodeDef(R"proto(
368     name:'n' op:'AnyIn' input:'a\\' input:'b'
369     attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
370     )proto");
371   ExpectInvalidSyntax(node_def_input_name_slash, "Illegal op input name 'a\\'");
372 
373   const NodeDef node_def_invalid_control_input_name = ToNodeDef(R"proto(
374     name:'n' op:'AnyIn' input:'a' input:'^b:0'
375     attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
376     )proto");
377   ExpectInvalidSyntax(node_def_invalid_control_input_name,
378                       "Illegal op input name '^b:0'");
379 
380   const NodeDef node_def_control_input_name_slash = ToNodeDef(R"proto(
381     name:'n' op:'AnyIn' input:'a' input:'^b\\'
382     attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
383     )proto");
384   ExpectInvalidSyntax(node_def_control_input_name_slash,
385                       "Illegal op input name '^b\\'");
386 
387   const NodeDef node_def_data_input_after_control = ToNodeDef(R"proto(
388     name:'n' op:'AnyIn' input:'^a' input:'b'
389     attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
390     )proto");
391   ExpectInvalidSyntax(node_def_data_input_after_control,
392                       "All control inputs must follow all data inputs");
393 
394   const NodeDef node_def_data_input_invalid_port = ToNodeDef(R"proto(
395     name:'n' op:'AnyIn' input:'a:b' input:'b'
396     attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
397     )proto");
398   ExpectInvalidSyntax(node_def_data_input_invalid_port,
399                       "Illegal op input name 'a:b");
400 
401   const NodeDef node_def_data_input_invalid_port2 = ToNodeDef(R"proto(
402     name:'n' op:'AnyIn' input:'a:00' input:'b'
403     attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
404     )proto");
405   ExpectInvalidSyntax(node_def_data_input_invalid_port2,
406                       "Illegal op input name 'a:00");
407 }
408 
// InputTypesForNode() / InputTypeForNode() on an op with two fixed-type inputs.
TEST(InputTypesForNode, Simple) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
                                   .Input("a: float")
                                   .Input("b: int32")
                                   .Output("c: string")
                                   .Output("d: bool"));
  const NodeDef node_def = ToNodeDef(std::move(
      NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())));

  DataTypeVector types;
  // Fatal checks: if the call fails or yields the wrong count, indexing
  // `types` below would read out of bounds.
  TF_ASSERT_OK(InputTypesForNode(node_def, op_def, &types));
  ASSERT_EQ(types.size(), size_t{2});
  EXPECT_EQ(types[0], DT_FLOAT);
  EXPECT_EQ(types[1], DT_INT32);

  DataType type = DT_INVALID;
  TF_EXPECT_OK(InputTypeForNode(node_def, op_def, 0, &type));
  EXPECT_EQ(type, DT_FLOAT);
  TF_EXPECT_OK(InputTypeForNode(node_def, op_def, 1, &type));
  EXPECT_EQ(type, DT_INT32);
  // Port 2 does not exist on a two-input op.
  EXPECT_FALSE(InputTypeForNode(node_def, op_def, 2, &type).ok());
}
429 
// OutputTypesForNode() / OutputTypeForNode() on an op with two fixed-type
// outputs.
TEST(OutputTypesForNode, Simple) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
                                   .Input("a: float")
                                   .Input("b: int32")
                                   .Output("c: string")
                                   .Output("d: bool"));
  const NodeDef node_def = ToNodeDef(std::move(
      NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())));

  DataTypeVector types;
  // Fatal checks: if the call fails or yields the wrong count, indexing
  // `types` below would read out of bounds.
  TF_ASSERT_OK(OutputTypesForNode(node_def, op_def, &types));
  ASSERT_EQ(types.size(), size_t{2});
  EXPECT_EQ(types[0], DT_STRING);
  EXPECT_EQ(types[1], DT_BOOL);

  DataType type = DT_INVALID;
  TF_EXPECT_OK(OutputTypeForNode(node_def, op_def, 0, &type));
  EXPECT_EQ(type, DT_STRING);
  TF_EXPECT_OK(OutputTypeForNode(node_def, op_def, 1, &type));
  EXPECT_EQ(type, DT_BOOL);
  // Port 2 does not exist on a two-output op.
  EXPECT_FALSE(OutputTypeForNode(node_def, op_def, 2, &type).ok());
}
450 
// OutputTypesForNode() must report an error (not succeed) when the output list
// `num_split * int64` would expand to 10^12 entries. Presumably this trips an
// internal size limit; the exact limit is not visible here.
TEST(OutputTypesForNode, LargeOutput) {
  const OpDef op_def = ToOpDef(OpDefBuilder("TestSplitOp")
                                   .Input("value: int64")
                                   .Output("output: num_split * int64")
                                   .Attr("num_split: int >= 1"));
  const int64 huge_num_split = 1000000000000;

  NodeDefBuilder builder("test_split_op", &op_def);
  builder.Input(FakeInput()).Attr("num_split", huge_num_split);
  const NodeDef node_def = ToNodeDef(std::move(builder));

  DataTypeVector types;
  const Status status = OutputTypesForNode(node_def, op_def, &types);
  EXPECT_FALSE(status.ok());
}
464 
// The AttrSlice overload of OutputTypesForNode() on an op with two fixed-type
// outputs.
TEST(OutputTypesForNode_AttrSliceOverload, Simple) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
                                   .Input("a: float")
                                   .Input("b: int32")
                                   .Output("c: string")
                                   .Output("d: bool"));
  // AttrSlice is a non-owning view of a NodeDef's attrs, so the NodeDef must
  // outlive it. Building the slice directly from the temporary returned by
  // ToNodeDef() would leave it dangling once that expression ends.
  const NodeDef node_def = ToNodeDef(std::move(
      NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())));
  const AttrSlice attr_slice(node_def);

  DataTypeVector types;
  // Fatal checks guard the indexing below.
  TF_ASSERT_OK(OutputTypesForNode(attr_slice, op_def, &types));
  ASSERT_EQ(types.size(), size_t{2});
  EXPECT_EQ(types[0], DT_STRING);
  EXPECT_EQ(types[1], DT_BOOL);
}
480 
// NameRangesForNode() on fixed-type args: each arg occupies a single slot.
TEST(NameRangesForNodeTest, Simple) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
                                   .Input("a: float")
                                   .Input("b: int32")
                                   .Output("c: string")
                                   .Output("d: bool"));
  const NodeDef node_def = ToNodeDef(std::move(
      NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())));

  NameRangeMap inputs;
  NameRangeMap outputs;
  TF_EXPECT_OK(NameRangesForNode(node_def, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 2}}}), outputs);

  EXPECT_EQ("{{node simple}} = Simple[](a, b)", SummarizeNodeDef(node_def));

  // Stripping the type from input `a` makes the OpDef unusable.
  OpDef untyped_input_op = op_def;
  untyped_input_op.mutable_input_arg(0)->clear_type();
  const Status untyped_status =
      NameRangesForNode(node_def, untyped_input_op, &inputs, &outputs);
  EXPECT_FALSE(untyped_status.ok());
}
500 
// NameRangesForNode() on an op whose args share one type attr T: the ranges
// do not depend on which type T is bound to.
TEST(NameRangesForNodeTest, Polymorphic) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Polymorphic")
                                   .Input("a: T")
                                   .Input("b: T")
                                   .Output("c: T")
                                   .Attr("T: type"));

  // Builds a "poly" node with both inputs of type `dtype`.
  const auto make_poly_node = [&op_def](DataType dtype) {
    return ToNodeDef(std::move(NodeDefBuilder("poly", &op_def)
                                   .Input(FakeInput(dtype))
                                   .Input(FakeInput(dtype))));
  };

  NameRangeMap inputs;
  NameRangeMap outputs;

  const NodeDef int_node = make_poly_node(DT_INT32);
  TF_EXPECT_OK(NameRangesForNode(int_node, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}}), outputs);
  EXPECT_EQ("{{node poly}} = Polymorphic[T=DT_INT32](a, b)",
            SummarizeNodeDef(int_node));

  const NodeDef bool_node = make_poly_node(DT_BOOL);
  TF_EXPECT_OK(NameRangesForNode(bool_node, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}}), outputs);
  EXPECT_EQ("{{node poly}} = Polymorphic[T=DT_BOOL](a, b)",
            SummarizeNodeDef(bool_node));
}
528 
// NameRangesForNode() on args repeated N or M times: each arg spans as many
// slots as its repeat count.
TEST(NameRangesForNodeTest, NRepeats) {
  const OpDef op_def = ToOpDef(OpDefBuilder("NRepeats")
                                   .Input("a: N * int32")
                                   .Input("b: N * T")
                                   .Output("c: T")
                                   .Output("d: N * string")
                                   .Output("e: M * bool")
                                   .Attr("N: int")
                                   .Attr("M: int")
                                   .Attr("T: type"));
  NameRangeMap inputs;
  NameRangeMap outputs;

  // N=4, M=3.
  const NodeDef n4_m3 = ToNodeDef(std::move(NodeDefBuilder("nr", &op_def)
                                                .Input(FakeInput(4, DT_INT32))
                                                .Input(FakeInput(4, DT_FLOAT))
                                                .Attr("M", 3)));
  TF_EXPECT_OK(NameRangesForNode(n4_m3, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 4}}, {"b", {4, 8}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 5}}, {"e", {5, 8}}}),
            outputs);
  EXPECT_EQ(
      "{{node nr}} = NRepeats[M=3, N=4, T=DT_FLOAT](a, a:1, a:2, a:3, b, b:1, "
      "b:2, b:3)",
      SummarizeNodeDef(n4_m3));

  // N=2, M=7.
  const NodeDef n2_m7 = ToNodeDef(std::move(NodeDefBuilder("nr", &op_def)
                                                .Input(FakeInput(2, DT_INT32))
                                                .Input(FakeInput(2, DT_DOUBLE))
                                                .Attr("M", 7)));
  TF_EXPECT_OK(NameRangesForNode(n2_m7, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 4}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}),
            outputs);
  EXPECT_EQ("{{node nr}} = NRepeats[M=7, N=2, T=DT_DOUBLE](a, a:1, b, b:1)",
            SummarizeNodeDef(n2_m7));

  // Without N, M and T the ranges cannot be computed.
  NodeDef attrless = n2_m7;
  attrless.clear_attr();
  EXPECT_FALSE(NameRangesForNode(attrless, op_def, &inputs, &outputs).ok());
}
570 
// NameRangesForNode() on args typed by list(type) attrs: each arg spans as
// many slots as its type list has entries.
TEST(NameRangesForNodeTest, TypeList) {
  const OpDef op_def = ToOpDef(OpDefBuilder("TypeList")
                                   .Input("a: T1")
                                   .Input("b: T2")
                                   .Output("c: T2")
                                   .Output("d: T3")
                                   .Output("e: T1")
                                   .Attr("T1: list(type)")
                                   .Attr("T2: list(type)")
                                   .Attr("T3: list(type)"));
  NameRangeMap inputs;
  NameRangeMap outputs;

  // |T1|=2, |T2|=4, |T3|=3.
  const NodeDef sizes_2_4_3 =
      ToNodeDef(std::move(NodeDefBuilder("tl", &op_def)
                              .Input(FakeInput({DT_BOOL, DT_FLOAT}))
                              .Input(FakeInput(4, DT_FLOAT))
                              .Attr("T3", {DT_INT32, DT_DOUBLE, DT_STRING})));
  TF_EXPECT_OK(NameRangesForNode(sizes_2_4_3, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 6}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 4}}, {"d", {4, 7}}, {"e", {7, 9}}}),
            outputs);
  EXPECT_EQ(
      "{{node tl}} = TypeList[T1=[DT_BOOL, DT_FLOAT],"
      " T2=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT],"
      " T3=[DT_INT32, DT_DOUBLE, DT_STRING]](a, a:1, b, b:1, b:2, b:3)",
      SummarizeNodeDef(sizes_2_4_3));

  // |T1|=7, |T2|=1, |T3|=2.
  const NodeDef sizes_7_1_2 =
      ToNodeDef(std::move(NodeDefBuilder("tl", &op_def)
                              .Input(FakeInput(7, DT_INT32))
                              .Input(FakeInput({DT_DOUBLE}))
                              .Attr("T3", {DT_DOUBLE, DT_STRING})));
  TF_EXPECT_OK(NameRangesForNode(sizes_7_1_2, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 7}}, {"b", {7, 8}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}),
            outputs);
  EXPECT_EQ(
      "{{node tl}} = TypeList[T1=[DT_INT32, DT_INT32, DT_INT32, DT_INT32, "
      "DT_INT32,"
      " DT_INT32, DT_INT32], T2=[DT_DOUBLE], T3=[DT_DOUBLE, DT_STRING]]"
      "(a, a:1, a:2, a:3, a:4, a:5, a:6, b)",
      SummarizeNodeDef(sizes_7_1_2));

  // Without the type-list attrs the ranges cannot be computed.
  NodeDef attrless = sizes_7_1_2;
  attrless.clear_attr();
  EXPECT_FALSE(NameRangesForNode(attrless, op_def, &inputs, &outputs).ok());
}
617 
// On an Enter node, AddPrefixAndSuffixToNode() wraps both the node name and
// the `frame_name` attr with the given prefix and suffix.
TEST(AddPrefixAndSuffixToNode, Enter) {
  const string prefix = "prefix/";
  const string suffix = "/suffix";

  NodeDef enter;
  enter.set_name("enter");
  enter.set_op("Enter");
  AddNodeAttr("frame_name", "test_frame", &enter);

  TF_ASSERT_OK(AddPrefixAndSuffixToNode(prefix, suffix, &enter));
  EXPECT_EQ("prefix/enter/suffix", enter.name());

  string frame_name;
  TF_ASSERT_OK(GetNodeAttr(enter, "frame_name", &frame_name));
  EXPECT_EQ("prefix/test_frame/suffix", frame_name);
}
631 
// Only colocation constraints naming a node in `match` receive the prefix;
// the others are left unchanged.
TEST(MaybeAddPrefixToColocationConstraints, Basic) {
  NodeDef identity;
  identity.set_name("Identity");
  identity.set_op("Identity");
  AddNodeAttr(kColocationAttrName,
              {strings::StrCat(kColocationGroupPrefix, "Node1"),
               strings::StrCat(kColocationGroupPrefix, "Node2"),
               strings::StrCat(kColocationGroupPrefix, "Node3")},
              &identity);

  const std::unordered_set<string> match = {"Node1", "Node3"};
  TF_ASSERT_OK(MaybeAddPrefixToColocationConstraints(match, "fn/", &identity));

  std::vector<string> coloc_constraints;
  TF_ASSERT_OK(GetNodeAttr(identity, kColocationAttrName, &coloc_constraints));
  const std::vector<string> expected = {"loc:@fn/Node1", "loc:@Node2",
                                        "loc:@fn/Node3"};
  EXPECT_EQ(coloc_constraints, expected);
}
652 
// A node without a colocation attr is accepted and still has none afterwards.
TEST(MaybeAddPrefixToColocationConstraints, NoConstraints) {
  NodeDef identity;
  identity.set_name("Identity");
  identity.set_op("Identity");

  const std::unordered_set<string> match = {"Node1", "Node3"};
  TF_ASSERT_OK(MaybeAddPrefixToColocationConstraints(match, "fn/", &identity));
  EXPECT_FALSE(HasNodeAttr(identity, kColocationAttrName));
}
664 
// FormatNodeForError() renders a graph Node as "{{node <name>}}".
TEST(FormatNodeForErrorTest, Node) {
  Graph g(OpRegistry::Global());
  Node* node = nullptr;
  // TF_ASSERT_OK fails (and stops) only this test; TF_CHECK_OK would abort the
  // entire test binary, hiding the results of every other test.
  TF_ASSERT_OK(NodeBuilder("enter", "NoOp").Finalize(&g, &node));
  EXPECT_EQ("{{node enter}}", FormatNodeForError(*node));
}
671 
// FormatNodeDefForError() renders a NodeDef as "{{node <name>}}"; the op and
// attrs do not appear.
TEST(FormatNodeForErrorTest, NodeDef) {
  NodeDef enter;
  enter.set_name("enter");
  enter.set_op("Enter");
  AddNodeAttr("frame_name", "test_frame", &enter);

  const string formatted = FormatNodeDefForError(enter);
  EXPECT_EQ("{{node enter}}", formatted);
}
679 
// With allow_multiple_formatted_node=true, every AttachDef() call appends a
// formatted "{{node ...}}" reference.
TEST(AttachDef, AllowMultipleFormattedNode) {
  NodeDef first;
  first.set_name("a");
  NodeDef second;
  second.set_name("b");
  const Status cancelled(error::CANCELLED, "Error");

  const Status once = AttachDef(cancelled, first, true);
  EXPECT_EQ("Error\n\t [[{{node a}}]]", once.error_message());

  const Status twice = AttachDef(once, second, true);
  EXPECT_EQ("Error\n\t [[{{node a}}]]\n\t [[{{node b}}]]",
            twice.error_message());
}
691 
// With allow_multiple_formatted_node=false, only the first attached node is
// formatted as "{{node ...}}"; a later one appears as a plain "[[b]]".
TEST(AttachDef, DisallowMultipleFormattedNode) {
  NodeDef first;
  first.set_name("a");
  NodeDef second;
  second.set_name("b");
  const Status cancelled(error::CANCELLED, "Error");

  const Status once = AttachDef(cancelled, first, false);
  EXPECT_EQ("Error\n\t [[{{node a}}]]", once.error_message());

  const Status twice = AttachDef(once, second, false);
  EXPECT_EQ("Error\n\t [[{{node a}}]]\n\t [[b]]", twice.error_message());
}
703 
704 }  // namespace
705 }  // namespace tensorflow
706