1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/op_def_builder.h"
17
18 #include "tensorflow/core/framework/attr_value.pb.h"
19 #include "tensorflow/core/framework/op_def.pb.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/lib/core/status_test_util.h"
22 #include "tensorflow/core/lib/core/stringpiece.h"
23 #include "tensorflow/core/lib/strings/str_util.h"
24 #include "tensorflow/core/lib/strings/strcat.h"
25 #include "tensorflow/core/platform/protobuf.h"
26 #include "tensorflow/core/platform/test.h"
27
28 namespace tensorflow {
29 namespace {
30
CanonicalizeAttrTypeListOrder(OpDef * def)31 static void CanonicalizeAttrTypeListOrder(OpDef* def) {
32 for (int i = 0; i < def->attr_size(); i++) {
33 AttrValue* a = def->mutable_attr(i)->mutable_allowed_values();
34 std::sort(a->mutable_list()->mutable_type()->begin(),
35 a->mutable_list()->mutable_type()->end());
36 }
37 }
38
39 class OpDefBuilderTest : public ::testing::Test {
40 protected:
b()41 OpDefBuilder b() { return OpDefBuilder("Test"); }
42
ExpectSuccess(const OpDefBuilder & builder,StringPiece proto,OpShapeInferenceFn * shape_fn_out=nullptr)43 void ExpectSuccess(const OpDefBuilder& builder, StringPiece proto,
44 OpShapeInferenceFn* shape_fn_out = nullptr) {
45 OpRegistrationData op_reg_data;
46 Status status = builder.Finalize(&op_reg_data);
47 TF_EXPECT_OK(status);
48 OpDef& op_def = op_reg_data.op_def;
49 if (status.ok()) {
50 OpDef expected;
51 protobuf::TextFormat::ParseFromString(
52 strings::StrCat("name: 'Test' ", proto), &expected);
53 // Allow different orderings
54 CanonicalizeAttrTypeListOrder(&op_def);
55 CanonicalizeAttrTypeListOrder(&expected);
56 EXPECT_EQ(op_def.ShortDebugString(), expected.ShortDebugString());
57
58 if (shape_fn_out) {
59 *shape_fn_out = op_reg_data.shape_inference_fn;
60 }
61 }
62 }
63
ExpectOrdered(const OpDefBuilder & builder,StringPiece proto)64 void ExpectOrdered(const OpDefBuilder& builder, StringPiece proto) {
65 OpRegistrationData op_reg_data;
66 Status status = builder.Finalize(&op_reg_data);
67 TF_EXPECT_OK(status);
68 OpDef& op_def = op_reg_data.op_def;
69 if (status.ok()) {
70 OpDef expected;
71 protobuf::TextFormat::ParseFromString(
72 strings::StrCat("name: 'Test' ", proto), &expected);
73 EXPECT_EQ(op_def.ShortDebugString(), expected.ShortDebugString());
74 }
75 }
76
ExpectFailure(const OpDefBuilder & builder,const string & error)77 void ExpectFailure(const OpDefBuilder& builder, const string& error) {
78 OpRegistrationData op_reg_data;
79 Status status = builder.Finalize(&op_reg_data);
80 EXPECT_FALSE(status.ok());
81 if (!status.ok()) {
82 EXPECT_EQ(status.error_message(), error);
83 }
84 }
85 };
86
// Each basic attr type is accepted, with various whitespace (spaces, tabs)
// around the ':' separator.
TEST_F(OpDefBuilderTest, Attr) {
  struct Case {
    const char* spec;      // Argument passed to Attr().
    const char* expected;  // Expected OpDef fields in text format.
  };
  const Case kCases[] = {
      {"a:string", "attr: { name: 'a' type: 'string' }"},
      {"A: int", "attr: { name: 'A' type: 'int' }"},
      {"a1 :float", "attr: { name: 'a1' type: 'float' }"},
      {"a_a : bool", "attr: { name: 'a_a' type: 'bool' }"},
      {"aB : type", "attr: { name: 'aB' type: 'type' }"},
      {"aB_3\t: shape", "attr: { name: 'aB_3' type: 'shape' }"},
      {"t: tensor", "attr: { name: 't' type: 'tensor' }"},
      {"XYZ\t:\tlist(type)", "attr: { name: 'XYZ' type: 'list(type)' }"},
      {"f: func", "attr { name: 'f' type: 'func'}"},
  };
  for (const Case& c : kCases) {
    ExpectSuccess(b().Attr(c.spec), c.expected);
  }
}
100
// Malformed attr specs (bad or missing name, bad or missing type) make
// Finalize() fail with the listed error message.
TEST_F(OpDefBuilderTest, AttrFailure) {
  struct Case {
    const char* spec;   // Argument passed to Attr().
    const char* error;  // Exact error message expected from Finalize().
  };
  const Case kCases[] = {
      {"_:string",
       "Trouble parsing '<name>:' from Attr(\"_:string\") for Op Test"},
      {"9:string",
       "Trouble parsing '<name>:' from Attr(\"9:string\") for Op Test"},
      {":string",
       "Trouble parsing '<name>:' from Attr(\":string\") for Op Test"},
      {"string",
       "Trouble parsing '<name>:' from Attr(\"string\") for Op Test"},
      {"a:invalid",
       "Trouble parsing type string at 'invalid' from "
       "Attr(\"a:invalid\") for Op Test"},
      {"b:",
       "Trouble parsing type string at '' from Attr(\"b:\") for Op Test"},
  };
  for (const Case& c : kCases) {
    ExpectFailure(b().Attr(c.spec), c.error);
  }
}
119
// Attr specs that restrict the allowed values: named type groups, explicit
// type sets, explicit string sets, and integer lower bounds.
TEST_F(OpDefBuilderTest, AttrWithRestrictions) {
  // Types with restrictions.
  ExpectSuccess(
      b().Attr("a:numbertype"),
      "attr: { name: 'a' type: 'type' allowed_values { list { type: "
      "[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
      "DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, "
      "DT_QINT32, DT_UINT32, DT_UINT64, DT_BFLOAT16] } } }");
  ExpectSuccess(
      b().Attr("a:{numbertype, variant}"),
      "attr: { name: 'a' type: 'type' allowed_values { list { type: "
      "[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
      "DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, "
      "DT_QINT32, DT_UINT32, DT_UINT64, DT_BFLOAT16, DT_VARIANT] } } }");
  ExpectSuccess(b().Attr("a:realnumbertype"),
                "attr: { name: 'a' type: 'type' allowed_values { list { type: "
                "[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, "
                "DT_INT16, DT_UINT16, DT_INT8, DT_UINT32, DT_UINT64, "
                "DT_BFLOAT16] } } }");
  ExpectSuccess(b().Attr("a:{realnumbertype, variant , string, }"),
                "attr: { name: 'a' type: 'type' allowed_values { list { type: "
                "[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, "
                "DT_INT16, DT_UINT16, DT_INT8, DT_UINT32, DT_UINT64, "
                "DT_BFLOAT16, DT_VARIANT, DT_STRING] } } }");
  ExpectSuccess(b().Attr("a:quantizedtype"),
                "attr: { name: 'a' type: 'type' allowed_values { list { type: "
                "[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16]} } }");
  ExpectSuccess(b().Attr("a:{quantizedtype ,string}"),
                "attr: { name: 'a' type: 'type' allowed_values { list { type: "
                "[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16, "
                "DT_STRING]} } }");

  // Explicit type sets, with varying whitespace and an optional trailing
  // comma.
  ExpectSuccess(b().Attr("a:{string,int32}"),
                "attr: { name: 'a' type: 'type' allowed_values { list { type: "
                "[DT_STRING, DT_INT32] } } }");
  ExpectSuccess(b().Attr("a: { float , complex64 } "),
                "attr: { name: 'a' type: 'type' allowed_values { list { type: "
                "[DT_FLOAT, DT_COMPLEX64] } } }");
  // The expectation below previously lacked the brace that closes `attr`,
  // which made it invalid text format.
  ExpectSuccess(b().Attr("a: {float, complex64,} "),
                "attr: { name: 'a' type: 'type' allowed_values { list { type: "
                "[DT_FLOAT, DT_COMPLEX64] } } }");

  // Explicit string sets.
  ExpectSuccess(b().Attr(R"(a: { "X", "yz" })"),
                "attr: { name: 'a' type: 'string' allowed_values { list { s: "
                "['X', 'yz'] } } }");
  ExpectSuccess(b().Attr(R"(a: { "X", "yz", })"),
                "attr: { name: 'a' type: 'string' allowed_values { list { s: "
                "['X', 'yz'] } } }");

  // Integer lower bounds, including both extremes of the int64 range.
  ExpectSuccess(
      b().Attr("i: int >= -5"),
      "attr: { name: 'i' type: 'int' has_minimum: true minimum: -5 }");
  ExpectSuccess(b().Attr("i: int >= 9223372036854775807"),
                ("attr: { name: 'i' type: 'int' has_minimum: true "
                 "minimum: 9223372036854775807 }"));
  ExpectSuccess(b().Attr("i: int >= -9223372036854775808"),
                ("attr: { name: 'i' type: 'int' has_minimum: true "
                 "minimum: -9223372036854775808 }"));
}
176
// Malformed restrictions (empty or mixed sets, stray commas, unparsable or
// out-of-range integer bounds) make Finalize() fail with the listed message.
TEST_F(OpDefBuilderTest, AttrRestrictionFailure) {
  struct Case {
    const char* spec;   // Argument passed to Attr().
    const char* error;  // Exact error message expected from Finalize().
  };
  const Case kCases[] = {
      {"a:{}",
       "Trouble parsing type string at '}' from Attr(\"a:{}\") for Op Test"},
      {"a:{,}",
       "Trouble parsing type string at ',}' from Attr(\"a:{,}\") for Op Test"},
      {"a:{invalid}",
       "Unrecognized type string 'invalid' from Attr(\"a:{invalid}\") "
       "for Op Test"},
      {"a:{\"str\", float}",
       "Trouble parsing allowed string at 'float}' from "
       "Attr(\"a:{\"str\", float}\") for Op Test"},
      {"a:{ float, \"str\" }",
       "Trouble parsing type string at '\"str\" }' from Attr(\"a:{ "
       "float, \"str\" }\") for Op Test"},
      {"a:{float,,string}",
       "Trouble parsing type string at ',string}' from "
       "Attr(\"a:{float,,string}\") for Op Test"},
      {"a:{float,,}",
       "Trouble parsing type string at ',}' from "
       "Attr(\"a:{float,,}\") for Op Test"},
      {"i: int >= a",
       "Could not parse integer lower limit after '>=', "
       "found ' a' instead from Attr(\"i: int >= a\") for Op Test"},
      {"i: int >= -a",
       "Could not parse integer lower limit after '>=', found ' -a' "
       "instead from Attr(\"i: int >= -a\") for Op Test"},
      {"i: int >= 9223372036854775808",
       "Could not parse integer lower limit after '>=', found "
       "' 9223372036854775808' instead from "
       "Attr(\"i: int >= 9223372036854775808\") for Op Test"},
      {"i: int >= -9223372036854775809",
       "Could not parse integer lower limit after '>=', found "
       "' -9223372036854775809' instead from "
       "Attr(\"i: int >= -9223372036854775809\") for Op Test"},
  };
  for (const Case& c : kCases) {
    ExpectFailure(b().Attr(c.spec), c.error);
  }
}
214
// list(...) attrs whose element type is restricted to a type group, a type
// set, or a set of (variously quoted and escaped) strings.
TEST_F(OpDefBuilderTest, AttrListOfRestricted) {
  struct Case {
    const char* spec;      // Argument passed to Attr().
    const char* expected;  // Expected OpDef fields in text format.
  };
  const Case kCases[] = {
      {"a:list(realnumbertype)",
       "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
       "[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
       "DT_UINT16, DT_INT8, DT_HALF, DT_BFLOAT16, DT_UINT32, DT_UINT64"
       "] } } }"},
      {"a:list({realnumbertype, variant})",
       "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
       "[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
       "DT_UINT16, DT_INT8, DT_HALF, DT_BFLOAT16, DT_UINT32, DT_UINT64, "
       "DT_VARIANT] } } }"},
      {"a:list(quantizedtype)",
       "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
       "[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16] } } }"},
      {"a: list({float, string, bool})",
       "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
       "[DT_FLOAT, DT_STRING, DT_BOOL] } } }"},
      {R"(a: list({ "one fish", "two fish" }))",
       "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: "
       "['one fish', 'two fish'] } } }"},
      {R"(a: list({ 'red fish', 'blue fish' }))",
       "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: "
       "['red fish', 'blue fish'] } } }"},
      {R"(a: list({ "single' ", 'double"' }))",
       "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: "
       "[\"single' \", 'double\"'] } } }"},
      {R"(a: list({ 'escape\'\n', "from\\\"NY" }))",
       "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: "
       "[\"escape'\\n\", 'from\\\\\"NY'] } } }"},
  };
  for (const Case& c : kCases) {
    ExpectSuccess(b().Attr(c.spec), c.expected);
  }
}
253
// A ">= N" suffix on a list attr is recorded as has_minimum/minimum.
TEST_F(OpDefBuilderTest, AttrListWithMinLength) {
  const char* const attr_spec = "i: list(bool) >= 4";
  const char* const expected_proto =
      "attr: { name: 'i' type: 'list(bool)' has_minimum: true minimum: 4 }";
  ExpectSuccess(b().Attr(attr_spec), expected_proto);
}
259
// Scalar attrs with "= <value>" defaults: strings (either quote style),
// floats, tensors and shapes.
TEST_F(OpDefBuilderTest, AttrWithDefaults) {
  struct Case {
    const char* spec;      // Argument passed to Attr().
    const char* expected;  // Expected OpDef fields in text format.
  };
  const Case kCases[] = {
      {R"(a:string="foo")",
       "attr: { name: 'a' type: 'string' default_value { s:'foo' } }"},
      {R"(a:string='foo')",
       "attr: { name: 'a' type: 'string' default_value { s:'foo' } }"},
      {"a:float = 1.25",
       "attr: { name: 'a' type: 'float' default_value { f: 1.25 } }"},
      {"a:tensor = { dtype: DT_INT32 int_val: 5 }",
       "attr: { name: 'a' type: 'tensor' default_value { tensor {"
       " dtype: DT_INT32 int_val: 5 } } }"},
      {"a:shape = { dim { size: 3 } dim { size: 4 } }",
       "attr: { name: 'a' type: 'shape' default_value { shape {"
       " dim { size: 3 } dim { size: 4 } } } }"},
      {"a:shape = { dim { size: -1 } dim { size: 4 } }",
       "attr: { name: 'a' type: 'shape' default_value { shape {"
       " dim { size: -1 } dim { size: 4 } } } }"},
  };
  for (const Case& c : kCases) {
    ExpectSuccess(b().Attr(c.spec), c.expected);
  }
}
277
// Defaults whose value does not match the declared scalar type are rejected.
TEST_F(OpDefBuilderTest, AttrFailedDefaults) {
  struct Case {
    const char* spec;   // Argument passed to Attr().
    const char* error;  // Exact error message expected from Finalize().
  };
  const Case kCases[] = {
      {R"(a:int="foo")",
       "Could not parse default value '\"foo\"' from "
       "Attr(\"a:int=\"foo\"\") for Op Test"},
      {"a:float = [1.25]",
       "Could not parse default value '[1.25]' from Attr(\"a:float = "
       "[1.25]\") for Op Test"},
  };
  for (const Case& c : kCases) {
    ExpectFailure(b().Attr(c.spec), c.error);
  }
}
286
// List attrs with bracketed default values, including an empty list.
TEST_F(OpDefBuilderTest, AttrListWithDefaults) {
  struct Case {
    const char* spec;      // Argument passed to Attr().
    const char* expected;  // Expected OpDef fields in text format.
  };
  const Case kCases[] = {
      {R"(a:list(string)=["foo", "bar"])",
       "attr: { name: 'a' type: 'list(string)' "
       "default_value { list { s: ['foo', 'bar'] } } }"},
      {"a:list(bool)=[true, false, true]",
       "attr: { name: 'a' type: 'list(bool)' "
       "default_value { list { b: [true, false, true] } } }"},
      {R"(a:list(int)=[0, -1, 2, -4, 8])",
       "attr: { name: 'a' type: 'list(int)' "
       "default_value { list { i: [0, -1, 2, -4, 8] } } }"},
      {R"(a:list(int)=[ ])",
       "attr: { name: 'a' type: 'list(int)' "
       "default_value { list { i: [] } } }"},
  };
  for (const Case& c : kCases) {
    ExpectSuccess(b().Attr(c.spec), c.expected);
  }
}
301
// List defaults that are mistyped, nested, scalar, unterminated or missing
// are rejected with the listed error message.
TEST_F(OpDefBuilderTest, AttrFailedListDefaults) {
  struct Case {
    const char* spec;   // Argument passed to Attr().
    const char* error;  // Exact error message expected from Finalize().
  };
  const Case kCases[] = {
      {R"(a:list(int)=["foo"])",
       "Could not parse default value '[\"foo\"]' from "
       "Attr(\"a:list(int)=[\"foo\"]\") for Op Test"},
      {R"(a:list(int)=[7, "foo"])",
       "Could not parse default value '[7, \"foo\"]' from "
       "Attr(\"a:list(int)=[7, \"foo\"]\") for Op Test"},
      {"a:list(float) = [[1.25]]",
       "Could not parse default value '[[1.25]]' from "
       "Attr(\"a:list(float) = [[1.25]]\") for Op Test"},
      {"a:list(float) = 1.25",
       "Could not parse default value '1.25' from "
       "Attr(\"a:list(float) = 1.25\") for Op Test"},
      {R"(a:list(string)='foo')",
       "Could not parse default value ''foo'' from "
       "Attr(\"a:list(string)='foo'\") for Op Test"},
      {"a:list(float) = [",
       "Could not parse default value '[' from "
       "Attr(\"a:list(float) = [\") for Op Test"},
      {"a:list(float) = ",
       "Could not parse default value '' from "
       "Attr(\"a:list(float) = \") for Op Test"},
  };
  for (const Case& c : kCases) {
    ExpectFailure(b().Attr(c.spec), c.error);
  }
}
325
// Inputs and outputs with concrete types, including Ref(...) and the case
// where Input() and Output() calls are interleaved.
TEST_F(OpDefBuilderTest, InputOutput) {
  OpDefBuilder int32_input = b();
  int32_input.Input("a: int32");
  ExpectSuccess(int32_input, "input_arg: { name: 'a' type: DT_INT32 }");

  OpDefBuilder string_output = b();
  string_output.Output("b: string");
  ExpectSuccess(string_output, "output_arg: { name: 'b' type: DT_STRING }");

  OpDefBuilder float_input = b();
  float_input.Input("c: float ");
  ExpectSuccess(float_input, "input_arg: { name: 'c' type: DT_FLOAT }");

  OpDefBuilder ref_output = b();
  ref_output.Output("d: Ref ( bool ) ");
  ExpectSuccess(ref_output,
                "output_arg: { name: 'd' type: DT_BOOL is_ref: true }");

  // Inputs and outputs are declared interleaved; the expectation lists all
  // inputs first, then all outputs.
  OpDefBuilder interleaved = b();
  interleaved.Input("a: bool");
  interleaved.Output("c: complex64");
  interleaved.Input("b: int64");
  interleaved.Output("d: string");
  ExpectOrdered(interleaved,
                "input_arg: { name: 'a' type: DT_BOOL } "
                "input_arg: { name: 'b' type: DT_INT64 } "
                "output_arg: { name: 'c' type: DT_COMPLEX64 } "
                "output_arg: { name: 'd' type: DT_STRING }");
}
344
// An input/output whose type names a "type" attr is recorded via type_attr.
TEST_F(OpDefBuilderTest, PolymorphicInputOutput) {
  OpDefBuilder any_type_input = b();
  any_type_input.Input("a: foo");
  any_type_input.Attr("foo: type");
  ExpectSuccess(any_type_input,
                "input_arg: { name: 'a' type_attr: 'foo' } "
                "attr: { name: 'foo' type: 'type' }");

  OpDefBuilder restricted_output = b();
  restricted_output.Output("a: foo");
  restricted_output.Attr("foo: { bool, int32 }");
  ExpectSuccess(restricted_output,
                "output_arg: { name: 'a' type_attr: 'foo' } "
                "attr: { name: 'foo' type: 'type' "
                "allowed_values: { list { type: [DT_BOOL, DT_INT32] } } }");
}
354
// "n * T" arguments are recorded via number_attr, and the expected int attr
// carries a minimum of 1 although the spec states none.
TEST_F(OpDefBuilderTest, InputOutputListSameType) {
  OpDefBuilder fixed_type = b();
  fixed_type.Input("a: n * int32");
  fixed_type.Attr("n: int");
  ExpectSuccess(fixed_type,
                "input_arg: { name: 'a' number_attr: 'n' type: DT_INT32 } "
                "attr: { name: 'n' type: 'int' has_minimum: true minimum: 1 }");

  // Polymorphic case:
  OpDefBuilder attr_type = b();
  attr_type.Output("b: n * foo");
  attr_type.Attr("n: int");
  attr_type.Attr("foo: type");
  ExpectSuccess(attr_type,
                "output_arg: { name: 'b' number_attr: 'n' type_attr: 'foo' } "
                "attr: { name: 'n' type: 'int' has_minimum: true minimum: 1 } "
                "attr: { name: 'foo' type: 'type' }");
}
365
// An argument typed by a list(type) attr is recorded via type_list_attr, and
// the expected attr carries a minimum of 1.
TEST_F(OpDefBuilderTest, InputOutputListAnyType) {
  OpDefBuilder unrestricted = b();
  unrestricted.Input("c: foo");
  unrestricted.Attr("foo: list(type)");
  ExpectSuccess(
      unrestricted,
      "input_arg: { name: 'c' type_list_attr: 'foo' } "
      "attr: { name: 'foo' type: 'list(type)' has_minimum: true minimum: 1 }");

  OpDefBuilder restricted = b();
  restricted.Output("c: foo");
  restricted.Attr("foo: list({string, float})");
  ExpectSuccess(
      restricted,
      "output_arg: { name: 'c' type_list_attr: 'foo' } "
      "attr: { name: 'foo' type: 'list(type)' has_minimum: true minimum: 1 "
      "allowed_values: { list { type: [DT_STRING, DT_FLOAT] } } }");
}
377
// Malformed Input()/Output() specs and bad attr references make Finalize()
// fail with the listed error message.
TEST_F(OpDefBuilderTest, InputOutputFailure) {
  struct Case {
    bool is_input;              // true: Input(arg_spec); false: Output(...).
    const char* arg_spec;       // Argument spec under test.
    const char* attr_specs[2];  // Attr() calls applied afterwards, in order;
                                // unused slots are null.
    const char* error;          // Exact error message expected.
  };
  const Case kCases[] = {
      {true, "9: int32", {},
       "Trouble parsing 'name:' from Input(\"9: int32\") for Op Test"},
      {false, "_: int32", {},
       "Trouble parsing 'name:' from Output(\"_: int32\") for Op Test"},
      {true, ": int32", {},
       "Trouble parsing 'name:' from Input(\": int32\") for Op Test"},
      {false, "int32", {},
       "Trouble parsing 'name:' from Output(\"int32\") for Op Test"},
      {true, "CAPS: int32", {},
       "Trouble parsing 'name:' from Input(\"CAPS: int32\") for Op Test"},
      {true, "_underscore: int32", {},
       "Trouble parsing 'name:' from Input(\"_underscore: int32\") for Op Test"},
      {true, "0digit: int32", {},
       "Trouble parsing 'name:' from Input(\"0digit: int32\") for Op Test"},
      {true, "a: _", {},
       "Trouble parsing either a type or an attr name at '_' from "
       "Input(\"a: _\") for Op Test"},
      {true, "a: 9", {},
       "Trouble parsing either a type or an attr name at '9' from "
       "Input(\"a: 9\") for Op Test"},
      {true, "a: 9 * int32", {},
       "Trouble parsing either a type or an attr name at '9 * int32' "
       "from Input(\"a: 9 * int32\") for Op Test"},
      {true, "a: x * _", {"x: type"},
       "Extra '* _' unparsed at the end from Input(\"a: x * _\") for Op Test"},
      {true, "a: x * y extra", {"x: int", "y: type"},
       "Extra 'extra' unparsed at the end from Input(\"a: x * y "
       "extra\") for Op Test"},
      {true, "a: Ref(int32", {},
       "Did not find closing ')' for 'Ref(', instead found: '' from "
       "Input(\"a: Ref(int32\") for Op Test"},
      {true, "a: Ref", {},
       "Reference to unknown attr 'Ref' from Input(\"a: Ref\") for Op Test"},
      {true, "a: Ref(x y", {"x: type"},
       "Did not find closing ')' for 'Ref(', instead found: 'y' from "
       "Input(\"a: Ref(x y\") for Op Test"},
      {true, "a: x", {},
       "Reference to unknown attr 'x' from Input(\"a: x\") for Op Test"},
      {true, "a: x * y", {"x: int"},
       "Reference to unknown attr 'y' from Input(\"a: x * y\") for Op Test"},
      {true, "a: x", {"x: int"},
       "Reference to attr 'x' with type int that isn't type or "
       "list(type) from Input(\"a: x\") for Op Test"},
  };
  for (const Case& c : kCases) {
    OpDefBuilder builder = b();
    if (c.is_input) {
      builder.Input(c.arg_spec);
    } else {
      builder.Output(c.arg_spec);
    }
    for (const char* attr_spec : c.attr_specs) {
      if (attr_spec != nullptr) {
        builder.Attr(attr_spec);
      }
    }
    ExpectFailure(builder, c.error);
  }
}
431
// The SetIs*() calls set the corresponding boolean fields of the OpDef.
TEST_F(OpDefBuilderTest, Set) {
  OpDefBuilder stateful = b();
  stateful.SetIsStateful();
  ExpectSuccess(stateful, "is_stateful: true");

  OpDefBuilder commutative_aggregate = b();
  commutative_aggregate.SetIsCommutative();
  commutative_aggregate.SetIsAggregate();
  ExpectSuccess(commutative_aggregate,
                "is_commutative: true is_aggregate: true");
}
437
// Doc() text is split into a summary, an op description and per-argument
// descriptions, as spelled out in the expected proto below.
TEST_F(OpDefBuilderTest, DocUnpackSparseFeatures) {
  const char kDoc[] = R"doc(
Converts a vector of strings with dist_belief::SparseFeatures to tensors.

Note that indices, ids and weights are vectors of the same size and have
one-to-one correspondence between their elements. ids and weights are each
obtained by sequentially concatenating sf[i].id and sf[i].weight, for i in
1...size(sf). Note that if sf[i].weight is not set, the default value for the
weight is assumed to be 1.0. Also for any j, if ids[j] and weights[j] were
extracted from sf[i], then index[j] is set to i.

sf: vector of string, where each element is the string encoding of
SparseFeatures proto.
indices: vector of indices inside sf
ids: vector of id extracted from the SparseFeatures proto.
weights: vector of weight extracted from the SparseFeatures proto.
)doc";
  const char kExpected[] = R"proto(
input_arg {
name: "sf"
description: "vector of string, where each element is the string encoding of\nSparseFeatures proto."
type: DT_STRING
}
output_arg {
name: "indices"
description: "vector of indices inside sf"
type: DT_INT32
}
output_arg {
name: "ids"
description: "vector of id extracted from the SparseFeatures proto."
type: DT_INT64
}
output_arg {
name: "weights"
description: "vector of weight extracted from the SparseFeatures proto."
type: DT_FLOAT
}
summary: "Converts a vector of strings with dist_belief::SparseFeatures to tensors."
description: "Note that indices, ids and weights are vectors of the same size and have\none-to-one correspondence between their elements. ids and weights are each\nobtained by sequentially concatenating sf[i].id and sf[i].weight, for i in\n1...size(sf). Note that if sf[i].weight is not set, the default value for the\nweight is assumed to be 1.0. Also for any j, if ids[j] and weights[j] were\nextracted from sf[i], then index[j] is set to i."
)proto";
  ExpectOrdered(b().Input("sf: string")
                    .Output("indices: int32")
                    .Output("ids: int64")
                    .Output("weights: float")
                    .Doc(kDoc),
                kExpected);
}
484
// Doc() with multi-line argument descriptions and no op description; the
// attrs have no doc entries and appear without descriptions.
TEST_F(OpDefBuilderTest, DocConcat) {
  const char kDoc[] = R"doc(
Concatenate N Tensors along one dimension.

concat_dim: The (scalar) dimension along which to concatenate. Must be
in the range [0, rank(values...)).
values: The N Tensors to concatenate. Their ranks and types must match,
and their sizes must match in all dimensions except concat_dim.
output: A Tensor with the concatenation of values stacked along the
concat_dim dimension. This Tensor's shape matches the Tensors in
values, except in concat_dim where it has the sum of the sizes.
)doc";
  const char kExpected[] = R"proto(
input_arg {
name: "concat_dim"
description: "The (scalar) dimension along which to concatenate. Must be\nin the range [0, rank(values...))."
type: DT_INT32
}
input_arg {
name: "values"
description: "The N Tensors to concatenate. Their ranks and types must match,\nand their sizes must match in all dimensions except concat_dim."
type_attr: "dtype"
number_attr: "num_values"
}
output_arg {
name: "output"
description: "A Tensor with the concatenation of values stacked along the\nconcat_dim dimension. This Tensor\'s shape matches the Tensors in\nvalues, except in concat_dim where it has the sum of the sizes."
type_attr: "dtype"
}
summary: "Concatenate N Tensors along one dimension."
attr {
name: "dtype"
type: "type"
}
attr {
name: "num_values"
type: "int"
has_minimum: true
minimum: 2
}
)proto";
  ExpectOrdered(b().Input("concat_dim: int32")
                    .Input("values: num_values * dtype")
                    .Output("output: dtype")
                    .Attr("dtype: type")
                    .Attr("num_values: int >= 2")
                    .Doc(kDoc),
                kExpected);
}
532
533 TEST_F(OpDefBuilderTest, DocAttr) {
534 ExpectOrdered(b().Attr("i: int").Doc(R"doc(
535 Summary
536
537 i: How much to operate.
538 )doc"),
539 R"proto(
540 summary: "Summary"
541 attr {
542 name: "i"
543 type: "int"
544 description: "How much to operate."
545 }
546 )proto");
547 }
548
549 TEST_F(OpDefBuilderTest, DocCalledTwiceFailure) {
550 ExpectFailure(b().Doc("What's").Doc("up, doc?"),
551 "Extra call to Doc() for Op Test");
552 }
553
554 TEST_F(OpDefBuilderTest, DocFailureMissingName) {
555 ExpectFailure(
556 b().Input("a: int32").Doc(R"doc(
557 Summary
558
559 a: Something for a.
560 b: b is not defined.
561 )doc"),
562 "No matching input/output/attr for name 'b' from Doc() for Op Test");
563
564 ExpectFailure(
565 b().Input("a: int32").Doc(R"doc(
566 Summary
567
568 b: b is not defined and by itself.
569 )doc"),
570 "No matching input/output/attr for name 'b' from Doc() for Op Test");
571 }
572
573 TEST_F(OpDefBuilderTest, DefaultMinimum) {
574 ExpectSuccess(b().Input("values: num_values * dtype")
575 .Output("output: anything")
576 .Attr("anything: list(type)")
577 .Attr("dtype: type")
578 .Attr("num_values: int"),
579 R"proto(
580 input_arg {
581 name: "values"
582 type_attr: "dtype"
583 number_attr: "num_values"
584 }
585 output_arg {
586 name: "output"
587 type_list_attr: "anything"
588 }
589 attr {
590 name: "anything"
591 type: "list(type)"
592 has_minimum: true
593 minimum: 1
594 }
595 attr {
596 name: "dtype"
597 type: "type"
598 }
599 attr {
600 name: "num_values"
601 type: "int"
602 has_minimum: true
603 minimum: 1
604 }
605 )proto");
606 }
607
608 TEST_F(OpDefBuilderTest, SetShapeFn) {
609 auto fn = [](shape_inference::InferenceContext* c) {
610 return errors::Unknown("ShapeFn was called");
611 };
612 OpShapeInferenceFn fn_out;
613 ExpectSuccess(
614 b().SetShapeFn(fn).Attr("dtype: type"),
615 "attr { name: \"dtype\" type: \"type\" allowed_values { list { } } }",
616 &fn_out);
617 ASSERT_TRUE(fn_out != nullptr);
618 EXPECT_EQ("ShapeFn was called", fn_out(nullptr).error_message());
619 }
620
621 TEST_F(OpDefBuilderTest, SetShapeFnCalledTwiceFailure) {
622 auto fn = [](shape_inference::InferenceContext* c) {
623 return errors::Unknown("ShapeFn was called");
624 };
625 ExpectFailure(b().SetShapeFn(fn).SetShapeFn(fn),
626 "SetShapeFn called twice for Op Test");
627 }
628
629 TEST_F(OpDefBuilderTest, ResourceIsStateful) {
630 OpRegistrationData op_reg_data;
631 TF_EXPECT_OK(b().Input("a: resource").Finalize(&op_reg_data));
632 const OpDef& op_def = op_reg_data.op_def;
633 EXPECT_TRUE(op_def.is_stateful());
634 }
635
636 } // namespace
637 } // namespace tensorflow
638