1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/python/framework/python_op_gen.h"
17
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/op_def.pb.h"
20 #include "tensorflow/core/framework/op_gen_lib.h"
21 #include "tensorflow/core/platform/test.h"
22
23 namespace tensorflow {
24 namespace {
25
ExpectHasSubstr(const string & s,const string & expected)26 void ExpectHasSubstr(const string& s, const string& expected) {
27 EXPECT_TRUE(absl::StrContains(s, expected))
28 << "'Generated ops does not contain '" << expected << "'";
29 }
30
ExpectDoesNotHaveSubstr(const string & s,const string & expected)31 void ExpectDoesNotHaveSubstr(const string& s, const string& expected) {
32 EXPECT_FALSE(absl::StrContains(s, expected))
33 << "'Generated ops contains '" << expected << "'";
34 }
35
ExpectSubstrOrder(const string & s,const string & before,const string & after)36 void ExpectSubstrOrder(const string& s, const string& before,
37 const string& after) {
38 int before_pos = s.find(before);
39 int after_pos = s.find(after);
40 ASSERT_NE(std::string::npos, before_pos);
41 ASSERT_NE(std::string::npos, after_pos);
42 EXPECT_LT(before_pos, after_pos) << before << "' is not before '" << after;
43 }
44
// Generates type-annotated Python wrappers for every op in the global registry
// and spot-checks the TypeVar declarations and signatures emitted for
// FakeParam and ToBool (both the public function and its eager fallback).
TEST(PythonOpGen, TypeAnnotateAllOps) {
  OpList ops;
  OpRegistry::Global()->Export(false, &ops);

  ApiDefMap api_def_map(ops);

  std::unordered_set<string> type_annotate_ops;
  for (const auto& op : ops.op()) {
    type_annotate_ops.insert(op.name());
  }

  string code = GetPythonOps(ops, api_def_map, {}, "", type_annotate_ops);

  // Expected tail of a TypeVar declaration listing every dtype class; used for
  // attrs that (presumably) have no allowed_values restriction.
  const string all_types =
      ", _dtypes.BFloat16, _dtypes.Bool, _dtypes.Complex128, "
      "_dtypes.Complex64, "
      "_dtypes.Float16, _dtypes.Float32, _dtypes.Float64, _dtypes.Half, "
      "_dtypes.Int16, "
      "_dtypes.Int32, _dtypes.Int64, _dtypes.Int8, _dtypes.QInt16, "
      "_dtypes.QInt32, "
      "_dtypes.QInt8, _dtypes.QUInt16, _dtypes.QUInt8, _dtypes.Resource, "
      "_dtypes.String, "
      "_dtypes.UInt16, _dtypes.UInt32, _dtypes.UInt64, _dtypes.UInt8, "
      "_dtypes.Variant)";

  const string fake_param_typevar =
      "TV_FakeParam_dtype = TypeVar(\"TV_FakeParam_dtype\"" + all_types;
  // Public wrapper: same parameters as the fallback below, but with the
  // `name=None` default instead of the trailing `name, ctx` pair. Previously
  // this string duplicated the fallback signature, so the public def was never
  // checked.
  const string fake_param =
      "def fake_param(dtype: TV_FakeParam_dtype, shape, name=None) -> "
      "_ops.Tensor[TV_FakeParam_dtype]:";
  const string fake_param_fallback =
      "def fake_param_eager_fallback(dtype: TV_FakeParam_dtype, shape, name, "
      "ctx) -> _ops.Tensor[TV_FakeParam_dtype]:";

  ExpectHasSubstr(code, fake_param_typevar);
  ExpectHasSubstr(code, fake_param);
  ExpectHasSubstr(code, fake_param_fallback);

  const string to_bool_typevar =
      "TV_ToBool_T = TypeVar(\"TV_ToBool_T\"" + all_types;
  const string to_bool_ =
      "def to_bool(input: _ops.Tensor[TV_ToBool_T], name=None) -> "
      "_ops.Tensor[_dtypes.Bool]:";
  const string to_bool_fallback =
      "def to_bool_eager_fallback(input: _ops.Tensor[TV_ToBool_T], name, ctx) "
      "-> _ops.Tensor[_dtypes.Bool]:";

  ExpectHasSubstr(code, to_bool_typevar);
  ExpectHasSubstr(code, to_bool_);
  ExpectHasSubstr(code, to_bool_fallback);
}
96
// An op whose inputs and output all have fixed dtypes is annotated with the
// concrete _dtypes classes, and the unannotated signature is not emitted.
TEST(PythonOpGen, TypeAnnotateSingleTypeTensor) {
  constexpr char kBaseOpDef[] = R"(
  op {
    name: "Bar"
    input_arg {
      name: "x"
      type: DT_STRING
    }
    input_arg {
      name: "y"
      type: DT_QINT8
    }
    output_arg {
      name: "output"
      type: DT_BOOL
    }
    summary: "Summary for op Bar."
    description: "Description for op Bar."
  }
  )";

  std::unordered_set<string> type_annotate_ops{"Bar"};

  // TextFormat::ParseFromString clears the message before parsing, so
  // exporting the global registry into `op_defs` beforehand had no effect.
  // Fail immediately if the text proto above is malformed.
  OpList op_defs;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs));
  ApiDefMap api_def_map(op_defs);

  string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);

  const string typed_bar =
      "def bar(x: _ops.Tensor[_dtypes.String], y: _ops.Tensor[_dtypes.QInt8], "
      "name=None) -> _ops.Tensor[_dtypes.Bool]:";
  ExpectHasSubstr(code, typed_bar);

  const string untyped_bar = "def bar(x, y, name=None):";
  ExpectDoesNotHaveSubstr(code, untyped_bar);
}
135
// Tensors whose dtype comes from a type attr are annotated with the per-op
// TypeVar for that attr (TV_<Op>_<attr>), on both inputs and output.
TEST(PythonOpGen, TypeAnnotateMultiTypeTensor) {
  constexpr char kBaseOpDef[] = R"(
  op {
    name: "Foo"
    input_arg {
      name: "x"
      type_attr: "T"
    }
    input_arg {
      name: "y"
      type_attr: "T2"
    }
    output_arg {
      name: "output"
      type_attr: "T"
    }
    attr {
      name: "T"
      type: "type"
      allowed_values {
        list {
          type: DT_UINT8
          type: DT_INT8
        }
      }
    }
    attr {
      name: "T2"
      type: "type"
      allowed_values {
        list {
          type: DT_STRING
          type: DT_FLOAT
          type: DT_DOUBLE
        }
      }
    }
    summary: "Summary for op Foo."
    description: "Description for op Foo."
  }
  )";

  std::unordered_set<string> type_annotate_ops{
      "Foo",
  };

  // TextFormat::ParseFromString clears the message before parsing, so
  // exporting the global registry into `op_defs` beforehand had no effect.
  // Fail immediately if the text proto above is malformed.
  OpList op_defs;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs));
  ApiDefMap api_def_map(op_defs);

  string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);

  const string typed_foo =
      "def foo(x: _ops.Tensor[TV_Foo_T], y: _ops.Tensor[TV_Foo_T2], name=None) "
      "-> _ops.Tensor[TV_Foo_T]:";
  ExpectHasSubstr(code, typed_foo);
}
194
// Each type attr with allowed_values yields a TypeVar constrained to exactly
// those dtypes. The expected block lists them alphabetically by class name.
TEST(PythonOpGen, GenerateCorrectTypeVars) {
  constexpr char kBaseOpDef[] = R"(
  op {
    name: "Foo"
    input_arg {
      name: "x"
      type_attr: "T"
    }
    input_arg {
      name: "y"
      type_attr: "T2"
    }
    output_arg {
      name: "output"
      type_attr: "T"
    }
    attr {
      name: "T"
      type: "type"
      allowed_values {
        list {
          type: DT_UINT8
          type: DT_INT8
        }
      }
    }
    attr {
      name: "T2"
      type: "type"
      allowed_values {
        list {
          type: DT_STRING
          type: DT_FLOAT
          type: DT_DOUBLE
        }
      }
    }
    summary: "Summary for op Foo."
    description: "Description for op Foo."
  }
  )";

  std::unordered_set<string> type_annotate_ops{
      "Foo",
  };

  // TextFormat::ParseFromString clears the message before parsing, so
  // exporting the global registry into `op_defs` beforehand had no effect.
  // Fail immediately if the text proto above is malformed.
  OpList op_defs;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs));
  ApiDefMap api_def_map(op_defs);

  string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);

  // Raw-string content is matched verbatim, including the surrounding
  // newlines, so these lines must stay at column 0.
  const string typevars_foo = R"(
TV_Foo_T = TypeVar("TV_Foo_T", _dtypes.Int8, _dtypes.UInt8)
TV_Foo_T2 = TypeVar("TV_Foo_T2", _dtypes.Float32, _dtypes.Float64, _dtypes.String)
)";

  ExpectHasSubstr(code, typevars_foo);
}
255
// The eager fallback function carries the same TypeVar annotations as the
// public wrapper, with the trailing `name, ctx` parameters left unannotated.
TEST(PythonOpGen, TypeAnnotateFallback) {
  constexpr char kBaseOpDef[] = R"(
  op {
    name: "Foo"
    input_arg {
      name: "x"
      type_attr: "T"
    }
    input_arg {
      name: "y"
      type_attr: "T2"
    }
    output_arg {
      name: "output"
      type_attr: "T"
    }
    attr {
      name: "T"
      type: "type"
      allowed_values {
        list {
          type: DT_UINT8
          type: DT_INT8
        }
      }
    }
    attr {
      name: "T2"
      type: "type"
      allowed_values {
        list {
          type: DT_STRING
          type: DT_FLOAT
          type: DT_DOUBLE
        }
      }
    }
    summary: "Summary for op Foo."
    description: "Description for op Foo."
  }
  )";

  std::unordered_set<string> type_annotate_ops{
      "Foo",
  };

  // TextFormat::ParseFromString clears the message before parsing, so
  // exporting the global registry into `op_defs` beforehand had no effect.
  // Fail immediately if the text proto above is malformed.
  OpList op_defs;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs));
  ApiDefMap api_def_map(op_defs);

  string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);

  const string typed_foo_fallback =
      "def foo_eager_fallback(x: _ops.Tensor[TV_Foo_T], y: "
      "_ops.Tensor[TV_Foo_T2], name, ctx) -> _ops.Tensor[TV_Foo_T]:";
  ExpectHasSubstr(code, typed_foo_fallback);
}
314
// The op's TypeVar declarations must appear before the `def` that references
// them, so the generated module can be imported.
TEST(PythonOpGen, GenerateTypeVarAboveOp) {
  constexpr char kBaseOpDef[] = R"(
  op {
    name: "Foo"
    input_arg {
      name: "x"
      type_attr: "T"
    }
    input_arg {
      name: "y"
      type_attr: "T2"
    }
    output_arg {
      name: "output"
      type_attr: "T"
    }
    attr {
      name: "T"
      type: "type"
      allowed_values {
        list {
          type: DT_UINT8
          type: DT_INT8
        }
      }
    }
    attr {
      name: "T2"
      type: "type"
      allowed_values {
        list {
          type: DT_STRING
          type: DT_FLOAT
          type: DT_DOUBLE
        }
      }
    }
    summary: "Summary for op Foo."
    description: "Description for op Foo."
  }
  )";

  std::unordered_set<string> type_annotate_ops{
      "Foo",
  };

  // TextFormat::ParseFromString clears the message before parsing, so
  // exporting the global registry into `op_defs` beforehand had no effect.
  // Fail immediately if the text proto above is malformed.
  OpList op_defs;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs));
  ApiDefMap api_def_map(op_defs);

  string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);

  const string typevar_foo = "TV_Foo_";
  const string def_foo = "def foo";
  ExpectSubstrOrder(code, typevar_foo, def_foo);
}
372
// Attributes with default values are annotated with their Python scalar type.
// The public wrapper renders them as `name:type=default` (no spaces), while
// the eager fallback uses `name: type` with no default.
TEST(PythonOpGen, TypeAnnotateDefaultParams) {
  constexpr char kBaseOpDef[] = R"(
  op {
    name: "FooBar"
    input_arg {
      name: "x"
      type: DT_FLOAT
    }
    output_arg {
      name: "output"
      type: DT_BOOL
    }
    attr {
      name: "t"
      type: "type"
      allowed_values {
        list {
          type: DT_HALF
          type: DT_INT8
        }
      }
    }
    attr {
      name: "var1"
      type: "bool"
      default_value {
        b: false
      }
    }
    attr {
      name: "var2"
      type: "int"
      default_value {
        i: 0
      }
    }
    summary: "Summary for op FooBar."
    description: "Description for op FooBar."
  }
  )";

  std::unordered_set<string> type_annotate_ops{
      "FooBar",
  };

  // TextFormat::ParseFromString clears the message before parsing, so
  // exporting the global registry into `op_defs` beforehand had no effect.
  // Fail immediately if the text proto above is malformed.
  OpList op_defs;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs));
  ApiDefMap api_def_map(op_defs);

  string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);

  const string params =
      "def foo_bar(x: _ops.Tensor[_dtypes.Float32], t: TV_FooBar_t, "
      "var1:bool=False, var2:int=0, name=None)";
  const string params_fallback =
      "def foo_bar_eager_fallback(x: _ops.Tensor[_dtypes.Float32], t: "
      "TV_FooBar_t, var1: bool, var2: int, name, ctx)";
  ExpectHasSubstr(code, params);
  ExpectHasSubstr(code, params_fallback);
}
434
// A list-typed input (number_attr / type_list_attr) is left unannotated, even
// though the op is in the type-annotation set.
TEST(PythonOpGen, NoTypingSequenceTensors) {
  constexpr char kBaseOpDef[] = R"(
  op {
    name: "Baz"
    input_arg {
      name: "inputs"
      number_attr: "N"
      type_list_attr: "T"
    }
    output_arg {
      name: "output1"
      type: DT_BOOL
    }
    output_arg {
      name: "output2"
      type: DT_BOOL
    }
    attr {
      name: "T"
      type: "bool"
    }
    attr {
      name: "N"
      type: "int"
    }
    summary: "Summary for op Baz."
    description: "Description for op Baz."
  }
  )";

  std::unordered_set<string> type_annotate_ops{"Baz"};

  // TextFormat::ParseFromString clears the message before parsing, so
  // exporting the global registry into `op_defs` beforehand had no effect.
  // Fail immediately if the text proto above is malformed.
  OpList op_defs;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs));
  ApiDefMap api_def_map(op_defs);

  string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);

  const string baz_def_line = "def baz(inputs, name=None):";

  ExpectHasSubstr(code, baz_def_line);
}
478
479 } // namespace
480 } // namespace tensorflow
481