• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op_def_builder.h"
17 
18 #include "tensorflow/core/framework/attr_value.pb.h"
19 #include "tensorflow/core/framework/op_def.pb.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/lib/core/status_test_util.h"
22 #include "tensorflow/core/lib/core/stringpiece.h"
23 #include "tensorflow/core/lib/strings/str_util.h"
24 #include "tensorflow/core/lib/strings/strcat.h"
25 #include "tensorflow/core/platform/protobuf.h"
26 #include "tensorflow/core/platform/test.h"
27 
28 namespace tensorflow {
29 namespace {
30 
CanonicalizeAttrTypeListOrder(OpDef * def)31 static void CanonicalizeAttrTypeListOrder(OpDef* def) {
32   for (int i = 0; i < def->attr_size(); i++) {
33     AttrValue* a = def->mutable_attr(i)->mutable_allowed_values();
34     std::sort(a->mutable_list()->mutable_type()->begin(),
35               a->mutable_list()->mutable_type()->end());
36   }
37 }
38 
39 class OpDefBuilderTest : public ::testing::Test {
40  protected:
b()41   OpDefBuilder b() { return OpDefBuilder("Test"); }
42 
ExpectSuccess(const OpDefBuilder & builder,StringPiece proto,OpShapeInferenceFn * shape_fn_out=nullptr)43   void ExpectSuccess(const OpDefBuilder& builder, StringPiece proto,
44                      OpShapeInferenceFn* shape_fn_out = nullptr) {
45     OpRegistrationData op_reg_data;
46     Status status = builder.Finalize(&op_reg_data);
47     TF_EXPECT_OK(status);
48     OpDef& op_def = op_reg_data.op_def;
49     if (status.ok()) {
50       OpDef expected;
51       protobuf::TextFormat::ParseFromString(
52           strings::StrCat("name: 'Test' ", proto), &expected);
53       // Allow different orderings
54       CanonicalizeAttrTypeListOrder(&op_def);
55       CanonicalizeAttrTypeListOrder(&expected);
56       EXPECT_EQ(op_def.ShortDebugString(), expected.ShortDebugString());
57 
58       if (shape_fn_out) {
59         *shape_fn_out = op_reg_data.shape_inference_fn;
60       }
61     }
62   }
63 
ExpectOrdered(const OpDefBuilder & builder,StringPiece proto)64   void ExpectOrdered(const OpDefBuilder& builder, StringPiece proto) {
65     OpRegistrationData op_reg_data;
66     Status status = builder.Finalize(&op_reg_data);
67     TF_EXPECT_OK(status);
68     OpDef& op_def = op_reg_data.op_def;
69     if (status.ok()) {
70       OpDef expected;
71       protobuf::TextFormat::ParseFromString(
72           strings::StrCat("name: 'Test' ", proto), &expected);
73       EXPECT_EQ(op_def.ShortDebugString(), expected.ShortDebugString());
74     }
75   }
76 
ExpectFailure(const OpDefBuilder & builder,const string & error)77   void ExpectFailure(const OpDefBuilder& builder, const string& error) {
78     OpRegistrationData op_reg_data;
79     Status status = builder.Finalize(&op_reg_data);
80     EXPECT_FALSE(status.ok());
81     if (!status.ok()) {
82       EXPECT_EQ(status.error_message(), error);
83     }
84   }
85 };
86 
TEST_F(OpDefBuilderTest,Attr)87 TEST_F(OpDefBuilderTest, Attr) {
88   ExpectSuccess(b().Attr("a:string"), "attr: { name: 'a' type: 'string' }");
89   ExpectSuccess(b().Attr("A: int"), "attr: { name: 'A' type: 'int' }");
90   ExpectSuccess(b().Attr("a1 :float"), "attr: { name: 'a1' type: 'float' }");
91   ExpectSuccess(b().Attr("a_a : bool"), "attr: { name: 'a_a' type: 'bool' }");
92   ExpectSuccess(b().Attr("aB  :  type"), "attr: { name: 'aB' type: 'type' }");
93   ExpectSuccess(b().Attr("aB_3\t: shape"),
94                 "attr: { name: 'aB_3' type: 'shape' }");
95   ExpectSuccess(b().Attr("t: tensor"), "attr: { name: 't' type: 'tensor' }");
96   ExpectSuccess(b().Attr("XYZ\t:\tlist(type)"),
97                 "attr: { name: 'XYZ' type: 'list(type)' }");
98   ExpectSuccess(b().Attr("f: func"), "attr { name: 'f' type: 'func'}");
99 }
100 
TEST_F(OpDefBuilderTest,AttrFailure)101 TEST_F(OpDefBuilderTest, AttrFailure) {
102   ExpectFailure(
103       b().Attr("_:string"),
104       "Trouble parsing '<name>:' from Attr(\"_:string\") for Op Test");
105   ExpectFailure(
106       b().Attr("9:string"),
107       "Trouble parsing '<name>:' from Attr(\"9:string\") for Op Test");
108   ExpectFailure(b().Attr(":string"),
109                 "Trouble parsing '<name>:' from Attr(\":string\") for Op Test");
110   ExpectFailure(b().Attr("string"),
111                 "Trouble parsing '<name>:' from Attr(\"string\") for Op Test");
112   ExpectFailure(b().Attr("a:invalid"),
113                 "Trouble parsing type string at 'invalid' from "
114                 "Attr(\"a:invalid\") for Op Test");
115   ExpectFailure(
116       b().Attr("b:"),
117       "Trouble parsing type string at '' from Attr(\"b:\") for Op Test");
118 }
119 
TEST_F(OpDefBuilderTest,AttrWithRestrictions)120 TEST_F(OpDefBuilderTest, AttrWithRestrictions) {
121   // Types with restrictions.
122   ExpectSuccess(
123       b().Attr("a:numbertype"),
124       "attr: { name: 'a' type: 'type' allowed_values { list { type: "
125       "[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
126       "DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, "
127       "DT_QINT32, DT_UINT32, DT_UINT64, DT_BFLOAT16] } } }");
128   ExpectSuccess(
129       b().Attr("a:{numbertype, variant}"),
130       "attr: { name: 'a' type: 'type' allowed_values { list { type: "
131       "[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
132       "DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, "
133       "DT_QINT32, DT_UINT32, DT_UINT64, DT_BFLOAT16, DT_VARIANT] } } }");
134   ExpectSuccess(b().Attr("a:realnumbertype"),
135                 "attr: { name: 'a' type: 'type' allowed_values { list { type: "
136                 "[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, "
137                 "DT_INT16, DT_UINT16, DT_INT8, DT_UINT32, DT_UINT64, "
138                 "DT_BFLOAT16] } } }");
139   ExpectSuccess(b().Attr("a:{realnumbertype,  variant , string, }"),
140                 "attr: { name: 'a' type: 'type' allowed_values { list { type: "
141                 "[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, "
142                 "DT_INT16, DT_UINT16, DT_INT8, DT_UINT32, DT_UINT64, "
143                 "DT_BFLOAT16, DT_VARIANT, DT_STRING] } } }");
144   ExpectSuccess(b().Attr("a:quantizedtype"),
145                 "attr: { name: 'a' type: 'type' allowed_values { list { type: "
146                 "[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16]} } }");
147   ExpectSuccess(b().Attr("a:{quantizedtype  ,string}"),
148                 "attr: { name: 'a' type: 'type' allowed_values { list { type: "
149                 "[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16, "
150                 "DT_STRING]} } }");
151   ExpectSuccess(b().Attr("a:{string,int32}"),
152                 "attr: { name: 'a' type: 'type' allowed_values { list { type: "
153                 "[DT_STRING, DT_INT32] } } }");
154   ExpectSuccess(b().Attr("a: { float , complex64 } "),
155                 "attr: { name: 'a' type: 'type' allowed_values { list { type: "
156                 "[DT_FLOAT, DT_COMPLEX64] } } }");
157   ExpectSuccess(b().Attr("a: {float, complex64,} "),
158                 "attr: { name: 'a' type: 'type' allowed_values { list { type: "
159                 "[DT_FLOAT, DT_COMPLEX64] } }");
160   ExpectSuccess(b().Attr(R"(a: { "X", "yz" })"),
161                 "attr: { name: 'a' type: 'string' allowed_values { list { s: "
162                 "['X', 'yz'] } } }");
163   ExpectSuccess(b().Attr(R"(a: { "X", "yz", })"),
164                 "attr: { name: 'a' type: 'string' allowed_values { list { s: "
165                 "['X', 'yz'] } } }");
166   ExpectSuccess(
167       b().Attr("i: int >= -5"),
168       "attr: { name: 'i' type: 'int' has_minimum: true minimum: -5 }");
169   ExpectSuccess(b().Attr("i: int >= 9223372036854775807"),
170                 ("attr: { name: 'i' type: 'int' has_minimum: true "
171                  "minimum: 9223372036854775807 }"));
172   ExpectSuccess(b().Attr("i: int >= -9223372036854775808"),
173                 ("attr: { name: 'i' type: 'int' has_minimum: true "
174                  "minimum: -9223372036854775808 }"));
175 }
176 
TEST_F(OpDefBuilderTest, AttrRestrictionFailure) {
  // Every error in this test has the form
  //   <what> from Attr("<spec>") for Op Test
  // so only the varying <what> part is spelled out per case.
  const auto expect_attr_error = [this](const char* spec, const char* what) {
    SCOPED_TRACE(spec);
    ExpectFailure(b().Attr(spec), strings::StrCat(what, " from Attr(\"", spec,
                                                  "\") for Op Test"));
  };

  // Empty, unknown, or doubled-comma entries in a {...} type restriction.
  expect_attr_error("a:{}", "Trouble parsing type string at '}'");
  expect_attr_error("a:{,}", "Trouble parsing type string at ',}'");
  expect_attr_error("a:{invalid}", "Unrecognized type string 'invalid'");
  expect_attr_error("a:{float,,string}",
                    "Trouble parsing type string at ',string}'");
  expect_attr_error("a:{float,,}", "Trouble parsing type string at ',}'");

  // Quoted strings and type names cannot be mixed, in either order.
  expect_attr_error("a:{\"str\", float}",
                    "Trouble parsing allowed string at 'float}'");
  expect_attr_error("a:{ float, \"str\" }",
                    "Trouble parsing type string at '\"str\" }'");

  // The lower bound after ">=" must be an integer that fits in int64.
  expect_attr_error("i: int >= a",
                    "Could not parse integer lower limit after '>=', "
                    "found ' a' instead");
  expect_attr_error("i: int >= -a",
                    "Could not parse integer lower limit after '>=', "
                    "found ' -a' instead");
  expect_attr_error("i: int >= 9223372036854775808",
                    "Could not parse integer lower limit after '>=', "
                    "found ' 9223372036854775808' instead");
  expect_attr_error("i: int >= -9223372036854775809",
                    "Could not parse integer lower limit after '>=', "
                    "found ' -9223372036854775809' instead");
}
214 
TEST_F(OpDefBuilderTest, AttrListOfRestricted) {
  // list(...) attrs take the same type-set and string-set restrictions as
  // scalar attrs; the restriction is recorded in allowed_values.
  struct Case {
    const char* spec;
    const char* expected;
  };
  const Case kCases[] = {
      {"a:list(realnumbertype)",
       "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
       "[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
       "DT_UINT16, DT_INT8, DT_HALF, DT_BFLOAT16, DT_UINT32, DT_UINT64"
       "] } } }"},
      {"a:list({realnumbertype, variant})",
       "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
       "[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
       "DT_UINT16, DT_INT8, DT_HALF, DT_BFLOAT16, DT_UINT32, DT_UINT64, "
       "DT_VARIANT] } } }"},
      {"a:list(quantizedtype)",
       "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
       "[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16] } } }"},
      {"a: list({float, string, bool})",
       "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
       "[DT_FLOAT, DT_STRING, DT_BOOL] } } }"},
      // String sets: either quote style is accepted, and escapes are decoded.
      {R"(a: list({ "one fish", "two fish" }))",
       "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: "
       "['one fish', 'two fish'] } } }"},
      {R"(a: list({ 'red fish', 'blue fish' }))",
       "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: "
       "['red fish', 'blue fish'] } } }"},
      {R"(a: list({ "single' ", 'double"' }))",
       "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: "
       "[\"single' \", 'double\"'] } } }"},
      {R"(a: list({ 'escape\'\n', "from\\\"NY" }))",
       "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: "
       "[\"escape'\\n\", 'from\\\\\"NY'] } } }"},
  };
  for (const Case& c : kCases) {
    SCOPED_TRACE(c.spec);
    ExpectSuccess(b().Attr(c.spec), c.expected);
  }
}
253 
TEST_F(OpDefBuilderTest, AttrListWithMinLength) {
  // A ">= N" suffix on a list attr sets has_minimum and minimum.
  OpDefBuilder builder = b();
  builder.Attr("i: list(bool) >= 4");
  ExpectSuccess(builder,
                "attr: { name: 'i' type: 'list(bool)' has_minimum: true "
                "minimum: 4 }");
}
259 
TEST_F(OpDefBuilderTest, AttrWithDefaults) {
  // "= <value>" after the type becomes the attr's default_value; string
  // defaults may use either quote style.
  struct Case {
    const char* spec;
    const char* expected;
  };
  const Case kCases[] = {
      {R"(a:string="foo")",
       "attr: { name: 'a' type: 'string' default_value { s:'foo' } }"},
      {R"(a:string='foo')",
       "attr: { name: 'a' type: 'string' default_value { s:'foo' } }"},
      {"a:float = 1.25",
       "attr: { name: 'a' type: 'float' default_value { f: 1.25 } }"},
      {"a:tensor = { dtype: DT_INT32 int_val: 5 }",
       "attr: { name: 'a' type: 'tensor' default_value { tensor {"
       "    dtype: DT_INT32 int_val: 5 } } }"},
      {"a:shape = { dim { size: 3 } dim { size: 4 } }",
       "attr: { name: 'a' type: 'shape' default_value { shape {"
       "    dim { size: 3 } dim { size: 4 } } } }"},
      {"a:shape = { dim { size: -1 } dim { size: 4 } }",
       "attr: { name: 'a' type: 'shape' default_value { shape {"
       "    dim { size: -1 } dim { size: 4 } } } }"},
  };
  for (const Case& c : kCases) {
    SCOPED_TRACE(c.spec);
    ExpectSuccess(b().Attr(c.spec), c.expected);
  }
}
277 
TEST_F(OpDefBuilderTest, AttrFailedDefaults) {
  // A default that does not parse as the declared type is rejected; the
  // message quotes both the default text and the whole spec.
  const auto expect_bad_default = [this](const char* spec,
                                         const char* default_text) {
    SCOPED_TRACE(spec);
    ExpectFailure(b().Attr(spec),
                  strings::StrCat("Could not parse default value '",
                                  default_text, "' from Attr(\"", spec,
                                  "\") for Op Test"));
  };
  expect_bad_default(R"(a:int="foo")", R"("foo")");
  expect_bad_default("a:float = [1.25]", "[1.25]");
}
286 
TEST_F(OpDefBuilderTest, AttrListWithDefaults) {
  // A bracketed default on a list attr fills the matching repeated field of
  // default_value.list; an empty "[  ]" yields an empty list.
  struct Case {
    const char* spec;
    const char* expected;
  };
  const Case kCases[] = {
      {R"(a:list(string)=["foo", "bar"])",
       "attr: { name: 'a' type: 'list(string)' "
       "default_value { list { s: ['foo', 'bar'] } } }"},
      {"a:list(bool)=[true, false, true]",
       "attr: { name: 'a' type: 'list(bool)' "
       "default_value { list { b: [true, false, true] } } }"},
      {R"(a:list(int)=[0, -1, 2, -4, 8])",
       "attr: { name: 'a' type: 'list(int)' "
       "default_value { list { i: [0, -1, 2, -4, 8] } } }"},
      {R"(a:list(int)=[  ])",
       "attr: { name: 'a' type: 'list(int)' "
       "default_value { list { i: [] } } }"},
  };
  for (const Case& c : kCases) {
    SCOPED_TRACE(c.spec);
    ExpectSuccess(b().Attr(c.spec), c.expected);
  }
}
301 
TEST_F(OpDefBuilderTest, AttrFailedListDefaults) {
  // Each case pairs a spec with the default text the error message quotes.
  // The full message is
  //   Could not parse default value '<default>' from Attr("<spec>") for Op Test
  struct Case {
    const char* spec;
    const char* default_text;
  };
  const Case kCases[] = {
      {R"(a:list(int)=["foo"])", R"(["foo"])"},         // wrong element type
      {R"(a:list(int)=[7, "foo"])", R"([7, "foo"])"},   // one bad element
      {"a:list(float) = [[1.25]]", "[[1.25]]"},         // nested list
      {"a:list(float) = 1.25", "1.25"},                 // scalar, not a list
      {R"(a:list(string)='foo')", "'foo'"},             // scalar string
      {"a:list(float) = [", "["},                       // unterminated
      {"a:list(float) = ", ""},                         // missing entirely
  };
  for (const Case& c : kCases) {
    SCOPED_TRACE(c.spec);
    ExpectFailure(b().Attr(c.spec),
                  strings::StrCat("Could not parse default value '",
                                  c.default_text, "' from Attr(\"", c.spec,
                                  "\") for Op Test"));
  }
}
325 
TEST_F(OpDefBuilderTest, InputOutput) {
  // Fixed-type inputs; trailing whitespace is ignored.
  ExpectSuccess(b().Input("a: int32"),
                "input_arg: { name: 'a' type: DT_INT32 }");
  ExpectSuccess(b().Input("c: float  "),
                "input_arg: { name: 'c' type: DT_FLOAT }");

  // Fixed-type outputs; Ref(...) (with arbitrary inner spacing) sets is_ref.
  ExpectSuccess(b().Output("b: string"),
                "output_arg: { name: 'b' type: DT_STRING }");
  ExpectSuccess(b().Output("d: Ref  (  bool ) "),
                "output_arg: { name: 'd' type: DT_BOOL is_ref: true }");

  // Inputs and outputs may be declared interleaved. Each kind keeps its own
  // declaration order, which is checked exactly (no canonicalization).
  OpDefBuilder interleaved = b();
  interleaved.Input("a: bool")
      .Output("c: complex64")
      .Input("b: int64")
      .Output("d: string");
  ExpectOrdered(interleaved,
                "input_arg: { name: 'a' type: DT_BOOL } "
                "input_arg: { name: 'b' type: DT_INT64 } "
                "output_arg: { name: 'c' type: DT_COMPLEX64 } "
                "output_arg: { name: 'd' type: DT_STRING }");
}
344 
TEST_F(OpDefBuilderTest, PolymorphicInputOutput) {
  // An arg whose type is the name of a `type` attr records it in type_attr.
  OpDefBuilder unrestricted = b();
  unrestricted.Input("a: foo").Attr("foo: type");
  ExpectSuccess(unrestricted,
                "input_arg: { name: 'a' type_attr: 'foo' } "
                "attr: { name: 'foo' type: 'type' }");

  // A restricted type attr carries its allowed types in allowed_values.
  OpDefBuilder restricted = b();
  restricted.Output("a: foo").Attr("foo: { bool, int32 }");
  ExpectSuccess(restricted,
                "output_arg: { name: 'a' type_attr: 'foo' } "
                "attr: { name: 'foo' type: 'type' "
                "allowed_values: { list { type: [DT_BOOL, DT_INT32] } } }");
}
354 
TEST_F(OpDefBuilderTest, InputOutputListSameType) {
  // "n * T" declares a homogeneous list whose length is the int attr `n`.
  // The length attr gets an implicit minimum of 1.
  OpDefBuilder fixed_type = b();
  fixed_type.Input("a: n * int32").Attr("n: int");
  ExpectSuccess(fixed_type,
                "input_arg: { name: 'a' number_attr: 'n' type: DT_INT32 } "
                "attr: { name: 'n' type: 'int' has_minimum: true minimum: 1 }");

  // Polymorphic case: the element type is itself an attr.
  OpDefBuilder polymorphic = b();
  polymorphic.Output("b: n * foo").Attr("n: int").Attr("foo: type");
  ExpectSuccess(polymorphic,
                "output_arg: { name: 'b' number_attr: 'n' type_attr: 'foo' } "
                "attr: { name: 'n' type: 'int' has_minimum: true minimum: 1 } "
                "attr: { name: 'foo' type: 'type' }");
}
365 
TEST_F(OpDefBuilderTest, InputOutputListAnyType) {
  // An arg typed by a list(type) attr is recorded in type_list_attr, and the
  // attr gets an implicit minimum length of 1.
  OpDefBuilder any_types = b();
  any_types.Input("c: foo").Attr("foo: list(type)");
  ExpectSuccess(
      any_types,
      "input_arg: { name: 'c' type_list_attr: 'foo' } "
      "attr: { name: 'foo' type: 'list(type)' has_minimum: true minimum: 1 }");

  // The list attr may also restrict which types are allowed.
  OpDefBuilder restricted_types = b();
  restricted_types.Output("c: foo").Attr("foo: list({string, float})");
  ExpectSuccess(
      restricted_types,
      "output_arg: { name: 'c' type_list_attr: 'foo' } "
      "attr: { name: 'foo' type: 'list(type)' has_minimum: true minimum: 1 "
      "allowed_values: { list { type: [DT_STRING, DT_FLOAT] } } }");
}
377 
TEST_F(OpDefBuilderTest, InputOutputFailure) {
  // Every error in this test has the form
  //   <what> from <Method>("<spec>") for Op Test
  // where <Method> is "Input" or "Output".
  const auto error_from = [](const char* method, const char* spec,
                             const char* what) {
    return strings::StrCat(what, " from ", method, "(\"", spec,
                           "\") for Op Test");
  };

  // Malformed argument names: leading digit or underscore, uppercase,
  // empty, or no ':' at all.
  const char* const kBadInputNames[] = {"9: int32", ": int32", "CAPS: int32",
                                        "_underscore: int32", "0digit: int32"};
  for (const char* spec : kBadInputNames) {
    SCOPED_TRACE(spec);
    ExpectFailure(b().Input(spec),
                  error_from("Input", spec, "Trouble parsing 'name:'"));
  }
  const char* const kBadOutputNames[] = {"_: int32", "int32"};
  for (const char* spec : kBadOutputNames) {
    SCOPED_TRACE(spec);
    ExpectFailure(b().Output(spec),
                  error_from("Output", spec, "Trouble parsing 'name:'"));
  }

  // The type position must hold a type or an attr name.
  ExpectFailure(
      b().Input("a: _"),
      error_from("Input", "a: _",
                 "Trouble parsing either a type or an attr name at '_'"));
  ExpectFailure(
      b().Input("a: 9"),
      error_from("Input", "a: 9",
                 "Trouble parsing either a type or an attr name at '9'"));
  ExpectFailure(b().Input("a: 9 * int32"),
                error_from("Input", "a: 9 * int32",
                           "Trouble parsing either a type or an attr name at "
                           "'9 * int32'"));

  // Trailing garbage after an otherwise valid spec.
  ExpectFailure(
      b().Input("a: x * _").Attr("x: type"),
      error_from("Input", "a: x * _", "Extra '* _' unparsed at the end"));
  ExpectFailure(
      b().Input("a: x * y extra").Attr("x: int").Attr("y: type"),
      error_from("Input", "a: x * y extra", "Extra 'extra' unparsed at the end"));

  // Ref(...) must be closed; a bare "Ref" is treated as an attr name.
  ExpectFailure(b().Input("a: Ref(int32"),
                error_from("Input", "a: Ref(int32",
                           "Did not find closing ')' for 'Ref(', instead "
                           "found: ''"));
  ExpectFailure(b().Input("a: Ref"),
                error_from("Input", "a: Ref",
                           "Reference to unknown attr 'Ref'"));
  ExpectFailure(b().Input("a: Ref(x y").Attr("x: type"),
                error_from("Input", "a: Ref(x y",
                           "Did not find closing ')' for 'Ref(', instead "
                           "found: 'y'"));

  // Referenced attrs must exist and, in type position, be type/list(type).
  ExpectFailure(b().Input("a: x"),
                error_from("Input", "a: x", "Reference to unknown attr 'x'"));
  ExpectFailure(b().Input("a: x * y").Attr("x: int"),
                error_from("Input", "a: x * y",
                           "Reference to unknown attr 'y'"));
  ExpectFailure(b().Input("a: x").Attr("x: int"),
                error_from("Input", "a: x",
                           "Reference to attr 'x' with type int that isn't "
                           "type or list(type)"));
}
431 
TEST_F(OpDefBuilderTest, Set) {
  // Each Set*() call turns on the matching boolean field of the OpDef.
  OpDefBuilder stateful = b();
  stateful.SetIsStateful();
  ExpectSuccess(stateful, "is_stateful: true");

  OpDefBuilder commutative_aggregate = b();
  commutative_aggregate.SetIsCommutative().SetIsAggregate();
  ExpectSuccess(commutative_aggregate,
                "is_commutative: true is_aggregate: true");
}
437 
TEST_F(OpDefBuilderTest, DocUnpackSparseFeatures) {
  // The first doc line becomes `summary`, the following paragraph becomes
  // `description`, and each "name: text" entry documents the matching arg.
  // The indented continuation line is joined to its entry with '\n'.
  constexpr char kDoc[] = R"doc(
Converts a vector of strings with dist_belief::SparseFeatures to tensors.

Note that indices, ids and weights are vectors of the same size and have
one-to-one correspondence between their elements. ids and weights are each
obtained by sequentially concatenating sf[i].id and sf[i].weight, for i in
1...size(sf). Note that if sf[i].weight is not set, the default value for the
weight is assumed to be 1.0. Also for any j, if ids[j] and weights[j] were
extracted from sf[i], then index[j] is set to i.

sf: vector of string, where each element is the string encoding of
    SparseFeatures proto.
indices: vector of indices inside sf
ids: vector of id extracted from the SparseFeatures proto.
weights: vector of weight extracted from the SparseFeatures proto.
)doc";

  constexpr char kExpected[] = R"proto(
input_arg {
  name: "sf"
  description: "vector of string, where each element is the string encoding of\nSparseFeatures proto."
  type: DT_STRING
}
output_arg {
  name: "indices"
  description: "vector of indices inside sf"
  type: DT_INT32
}
output_arg {
  name: "ids"
  description: "vector of id extracted from the SparseFeatures proto."
  type: DT_INT64
}
output_arg {
  name: "weights"
  description: "vector of weight extracted from the SparseFeatures proto."
  type: DT_FLOAT
}
summary: "Converts a vector of strings with dist_belief::SparseFeatures to tensors."
description: "Note that indices, ids and weights are vectors of the same size and have\none-to-one correspondence between their elements. ids and weights are each\nobtained by sequentially concatenating sf[i].id and sf[i].weight, for i in\n1...size(sf). Note that if sf[i].weight is not set, the default value for the\nweight is assumed to be 1.0. Also for any j, if ids[j] and weights[j] were\nextracted from sf[i], then index[j] is set to i."
)proto";

  OpDefBuilder builder = b();
  builder.Input("sf: string")
      .Output("indices: int32")
      .Output("ids: int64")
      .Output("weights: float")
      .Doc(kDoc);
  ExpectOrdered(builder, kExpected);
}
484 
TEST_F(OpDefBuilderTest, DocConcat) {
  // Arg docs with indented continuation lines are joined with '\n' after
  // the indentation is stripped. Attrs without doc entries get no
  // description.
  constexpr char kDoc[] = R"doc(
Concatenate N Tensors along one dimension.

concat_dim: The (scalar) dimension along which to concatenate.  Must be
  in the range [0, rank(values...)).
values: The N Tensors to concatenate. Their ranks and types must match,
  and their sizes must match in all dimensions except concat_dim.
output: A Tensor with the concatenation of values stacked along the
  concat_dim dimension.  This Tensor's shape matches the Tensors in
  values, except in concat_dim where it has the sum of the sizes.
)doc";

  constexpr char kExpected[] = R"proto(
input_arg {
  name: "concat_dim"
  description: "The (scalar) dimension along which to concatenate.  Must be\nin the range [0, rank(values...))."
  type: DT_INT32
}
input_arg {
  name: "values"
  description: "The N Tensors to concatenate. Their ranks and types must match,\nand their sizes must match in all dimensions except concat_dim."
  type_attr: "dtype"
  number_attr: "num_values"
}
output_arg {
  name: "output"
  description: "A Tensor with the concatenation of values stacked along the\nconcat_dim dimension.  This Tensor\'s shape matches the Tensors in\nvalues, except in concat_dim where it has the sum of the sizes."
  type_attr: "dtype"
}
summary: "Concatenate N Tensors along one dimension."
attr {
  name: "dtype"
  type: "type"
}
attr {
  name: "num_values"
  type: "int"
  has_minimum: true
  minimum: 2
}
)proto";

  OpDefBuilder builder = b();
  builder.Input("concat_dim: int32")
      .Input("values: num_values * dtype")
      .Output("output: dtype")
      .Attr("dtype: type")
      .Attr("num_values: int >= 2")
      .Doc(kDoc);
  ExpectOrdered(builder, kExpected);
}
532 
533 TEST_F(OpDefBuilderTest, DocAttr) {
534   ExpectOrdered(b().Attr("i: int").Doc(R"doc(
535 Summary
536 
537 i: How much to operate.
538 )doc"),
539                 R"proto(
540 summary: "Summary"
541 attr {
542   name: "i"
543   type: "int"
544   description: "How much to operate."
545 }
546 )proto");
547 }
548 
549 TEST_F(OpDefBuilderTest, DocCalledTwiceFailure) {
550   ExpectFailure(b().Doc("What's").Doc("up, doc?"),
551                 "Extra call to Doc() for Op Test");
552 }
553 
554 TEST_F(OpDefBuilderTest, DocFailureMissingName) {
555   ExpectFailure(
556       b().Input("a: int32").Doc(R"doc(
557 Summary
558 
559 a: Something for a.
560 b: b is not defined.
561 )doc"),
562       "No matching input/output/attr for name 'b' from Doc() for Op Test");
563 
564   ExpectFailure(
565       b().Input("a: int32").Doc(R"doc(
566 Summary
567 
568 b: b is not defined and by itself.
569 )doc"),
570       "No matching input/output/attr for name 'b' from Doc() for Op Test");
571 }
572 
573 TEST_F(OpDefBuilderTest, DefaultMinimum) {
574   ExpectSuccess(b().Input("values: num_values * dtype")
575                     .Output("output: anything")
576                     .Attr("anything: list(type)")
577                     .Attr("dtype: type")
578                     .Attr("num_values: int"),
579                 R"proto(
580 input_arg {
581   name: "values"
582   type_attr: "dtype"
583   number_attr: "num_values"
584 }
585 output_arg {
586   name: "output"
587   type_list_attr: "anything"
588 }
589 attr {
590   name: "anything"
591   type: "list(type)"
592   has_minimum: true
593   minimum: 1
594 }
595 attr {
596   name: "dtype"
597   type: "type"
598 }
599 attr {
600   name: "num_values"
601   type: "int"
602   has_minimum: true
603   minimum: 1
604 }
605 )proto");
606 }
607 
608 TEST_F(OpDefBuilderTest, SetShapeFn) {
609   auto fn = [](shape_inference::InferenceContext* c) {
610     return errors::Unknown("ShapeFn was called");
611   };
612   OpShapeInferenceFn fn_out;
613   ExpectSuccess(
614       b().SetShapeFn(fn).Attr("dtype: type"),
615       "attr { name: \"dtype\" type: \"type\" allowed_values { list { } } }",
616       &fn_out);
617   ASSERT_TRUE(fn_out != nullptr);
618   EXPECT_EQ("ShapeFn was called", fn_out(nullptr).error_message());
619 }
620 
621 TEST_F(OpDefBuilderTest, SetShapeFnCalledTwiceFailure) {
622   auto fn = [](shape_inference::InferenceContext* c) {
623     return errors::Unknown("ShapeFn was called");
624   };
625   ExpectFailure(b().SetShapeFn(fn).SetShapeFn(fn),
626                 "SetShapeFn called twice for Op Test");
627 }
628 
629 TEST_F(OpDefBuilderTest, ResourceIsStateful) {
630   OpRegistrationData op_reg_data;
631   TF_EXPECT_OK(b().Input("a: resource").Finalize(&op_reg_data));
632   const OpDef& op_def = op_reg_data.op_def;
633   EXPECT_TRUE(op_def.is_stateful());
634 }
635 
636 }  // namespace
637 }  // namespace tensorflow
638