• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/python/framework/python_op_gen.h"
17 
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/op_def.pb.h"
20 #include "tensorflow/core/framework/op_gen_lib.h"
21 #include "tensorflow/core/platform/test.h"
22 
23 namespace tensorflow {
24 namespace {
25 
ExpectHasSubstr(const string & s,const string & expected)26 void ExpectHasSubstr(const string& s, const string& expected) {
27   EXPECT_TRUE(absl::StrContains(s, expected))
28       << "'Generated ops does not contain '" << expected << "'";
29 }
30 
ExpectDoesNotHaveSubstr(const string & s,const string & expected)31 void ExpectDoesNotHaveSubstr(const string& s, const string& expected) {
32   EXPECT_FALSE(absl::StrContains(s, expected))
33       << "'Generated ops contains '" << expected << "'";
34 }
35 
ExpectSubstrOrder(const string & s,const string & before,const string & after)36 void ExpectSubstrOrder(const string& s, const string& before,
37                        const string& after) {
38   int before_pos = s.find(before);
39   int after_pos = s.find(after);
40   ASSERT_NE(std::string::npos, before_pos);
41   ASSERT_NE(std::string::npos, after_pos);
42   EXPECT_LT(before_pos, after_pos) << before << "' is not before '" << after;
43 }
44 
// Type-annotating every op in the global registry must emit the TypeVar
// declaration plus annotated signatures for both the public wrapper and its
// eager fallback (checked here for FakeParam and ToBool).
TEST(PythonOpGen, TypeAnnotateAllOps) {
  OpList ops;
  OpRegistry::Global()->Export(false, &ops);

  ApiDefMap api_def_map(ops);

  std::unordered_set<string> type_annotate_ops;
  for (const auto& op : ops.op()) {
    type_annotate_ops.insert(op.name());
  }

  string code = GetPythonOps(ops, api_def_map, {}, "", type_annotate_ops);

  // Expected tail of the TypeVar declarations for the FakeParam/ToBool type
  // attrs: the allowed dtypes in alphabetical order, then the closing paren.
  const string all_types =
      ", _dtypes.BFloat16, _dtypes.Bool, _dtypes.Complex128, "
      "_dtypes.Complex64, "
      "_dtypes.Float16, _dtypes.Float32, _dtypes.Float64, _dtypes.Half, "
      "_dtypes.Int16, "
      "_dtypes.Int32, _dtypes.Int64, _dtypes.Int8, _dtypes.QInt16, "
      "_dtypes.QInt32, "
      "_dtypes.QInt8, _dtypes.QUInt16, _dtypes.QUInt8, _dtypes.Resource, "
      "_dtypes.String, "
      "_dtypes.UInt16, _dtypes.UInt32, _dtypes.UInt64, _dtypes.UInt8, "
      "_dtypes.Variant)";

  const string fake_param_typevar =
      "TV_FakeParam_dtype = TypeVar(\"TV_FakeParam_dtype\"" + all_types;
  // Public wrapper: takes `name=None`, unlike the fallback's `name, ctx`.
  // (Previously this duplicated the fallback string, so the wrapper's
  // signature was never checked.)
  const string fake_param =
      "def fake_param(dtype: TV_FakeParam_dtype, shape, name=None) -> "
      "_ops.Tensor[TV_FakeParam_dtype]:";
  const string fake_param_fallback =
      "def fake_param_eager_fallback(dtype: TV_FakeParam_dtype, shape, name, "
      "ctx) -> _ops.Tensor[TV_FakeParam_dtype]:";

  ExpectHasSubstr(code, fake_param_typevar);
  ExpectHasSubstr(code, fake_param);
  ExpectHasSubstr(code, fake_param_fallback);

  const string to_bool_typevar =
      "TV_ToBool_T = TypeVar(\"TV_ToBool_T\"" + all_types;
  const string to_bool =
      "def to_bool(input: _ops.Tensor[TV_ToBool_T], name=None) -> "
      "_ops.Tensor[_dtypes.Bool]:";
  const string to_bool_fallback =
      "def to_bool_eager_fallback(input: _ops.Tensor[TV_ToBool_T], name, ctx) "
      "-> _ops.Tensor[_dtypes.Bool]:";

  ExpectHasSubstr(code, to_bool_typevar);
  ExpectHasSubstr(code, to_bool);
  ExpectHasSubstr(code, to_bool_fallback);
}
96 
// An op whose inputs and output have fixed dtypes is annotated with concrete
// _dtypes classes, and the unannotated signature is not emitted.
TEST(PythonOpGen, TypeAnnotateSingleTypeTensor) {
  constexpr char kBaseOpDef[] = R"(
  op {
    name: "Bar"
    input_arg {
      name: "x"
      type: DT_STRING
    }
    input_arg {
      name: "y"
      type: DT_QINT8
    }
    output_arg {
      name: "output"
      type: DT_BOOL
    }
    summary: "Summary for op Bar."
    description: "Description for op Bar."
  }
  )";

  std::unordered_set<string> type_annotate_ops{"Bar"};

  // TextFormat::ParseFromString clears `op_defs` before parsing, so the list
  // holds only the op defined above. Check the result so a malformed op def
  // cannot silently produce empty code (which would make the
  // "does not have" check below pass vacuously).
  OpList op_defs;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs));
  ApiDefMap api_def_map(op_defs);

  string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);

  const string typed_bar =
      "def bar(x: _ops.Tensor[_dtypes.String], y: _ops.Tensor[_dtypes.QInt8], "
      "name=None) -> _ops.Tensor[_dtypes.Bool]:";
  ExpectHasSubstr(code, typed_bar);

  const string untyped_bar = "def bar(x, y, name=None):";
  ExpectDoesNotHaveSubstr(code, untyped_bar);
}
135 
// Inputs/outputs typed by a `type` attr are annotated with the op's TypeVars.
TEST(PythonOpGen, TypeAnnotateMultiTypeTensor) {
  constexpr char kBaseOpDef[] = R"(
  op {
    name: "Foo"
    input_arg {
      name: "x"
      type_attr: "T"
    }
    input_arg {
      name: "y"
      type_attr: "T2"
    }
    output_arg {
      name: "output"
      type_attr: "T"
    }
    attr {
      name: "T"
      type: "type"
      allowed_values {
        list {
          type: DT_UINT8
          type: DT_INT8
        }
      }
    }
    attr {
      name: "T2"
      type: "type"
      allowed_values {
        list {
          type: DT_STRING
          type: DT_FLOAT
          type: DT_DOUBLE
        }
      }
    }
    summary: "Summary for op Foo."
    description: "Description for op Foo."
  }
  )";

  std::unordered_set<string> type_annotate_ops{
      "Foo",
  };

  // TextFormat::ParseFromString clears `op_defs` before parsing, so the list
  // holds only the op defined above. Fail fast if the op def is malformed.
  OpList op_defs;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs));
  ApiDefMap api_def_map(op_defs);

  string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);

  const string typed_foo =
      "def foo(x: _ops.Tensor[TV_Foo_T], y: _ops.Tensor[TV_Foo_T2], name=None) "
      "-> _ops.Tensor[TV_Foo_T]:";
  ExpectHasSubstr(code, typed_foo);
}
194 
// Each `type` attr yields one TypeVar declaration listing its allowed dtypes;
// the expected text shows them alphabetically ordered (Int8 before UInt8).
TEST(PythonOpGen, GenerateCorrectTypeVars) {
  constexpr char kBaseOpDef[] = R"(
  op {
    name: "Foo"
    input_arg {
      name: "x"
      type_attr: "T"
    }
    input_arg {
      name: "y"
      type_attr: "T2"
    }
    output_arg {
      name: "output"
      type_attr: "T"
    }
    attr {
      name: "T"
      type: "type"
      allowed_values {
        list {
          type: DT_UINT8
          type: DT_INT8
        }
      }
    }
    attr {
      name: "T2"
      type: "type"
      allowed_values {
        list {
          type: DT_STRING
          type: DT_FLOAT
          type: DT_DOUBLE
        }
      }
    }
    summary: "Summary for op Foo."
    description: "Description for op Foo."
  }
  )";

  std::unordered_set<string> type_annotate_ops{
      "Foo",
  };

  // TextFormat::ParseFromString clears `op_defs` before parsing, so the list
  // holds only the op defined above. Fail fast if the op def is malformed.
  OpList op_defs;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs));
  ApiDefMap api_def_map(op_defs);

  string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);

  const string typevars_foo = R"(
TV_Foo_T = TypeVar("TV_Foo_T", _dtypes.Int8, _dtypes.UInt8)
TV_Foo_T2 = TypeVar("TV_Foo_T2", _dtypes.Float32, _dtypes.Float64, _dtypes.String)
)";

  ExpectHasSubstr(code, typevars_foo);
}
255 
// The eager fallback function carries the same TypeVar annotations as the
// public wrapper, with `name, ctx` in place of `name=None`.
TEST(PythonOpGen, TypeAnnotateFallback) {
  constexpr char kBaseOpDef[] = R"(
  op {
    name: "Foo"
    input_arg {
      name: "x"
      type_attr: "T"
    }
    input_arg {
      name: "y"
      type_attr: "T2"
    }
    output_arg {
      name: "output"
      type_attr: "T"
    }
    attr {
      name: "T"
      type: "type"
      allowed_values {
        list {
          type: DT_UINT8
          type: DT_INT8
        }
      }
    }
    attr {
      name: "T2"
      type: "type"
      allowed_values {
        list {
          type: DT_STRING
          type: DT_FLOAT
          type: DT_DOUBLE
        }
      }
    }
    summary: "Summary for op Foo."
    description: "Description for op Foo."
  }
  )";

  std::unordered_set<string> type_annotate_ops{
      "Foo",
  };

  // TextFormat::ParseFromString clears `op_defs` before parsing, so the list
  // holds only the op defined above. Fail fast if the op def is malformed.
  OpList op_defs;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs));
  ApiDefMap api_def_map(op_defs);

  string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);

  const string typed_foo_fallback =
      "def foo_eager_fallback(x: _ops.Tensor[TV_Foo_T], y: "
      "_ops.Tensor[TV_Foo_T2], name, ctx) -> _ops.Tensor[TV_Foo_T]:";
  ExpectHasSubstr(code, typed_foo_fallback);
}
314 
// The op's TypeVar declarations must appear before its `def`, since Python
// evaluates the annotations when the function is defined.
TEST(PythonOpGen, GenerateTypeVarAboveOp) {
  constexpr char kBaseOpDef[] = R"(
  op {
    name: "Foo"
    input_arg {
      name: "x"
      type_attr: "T"
    }
    input_arg {
      name: "y"
      type_attr: "T2"
    }
    output_arg {
      name: "output"
      type_attr: "T"
    }
    attr {
      name: "T"
      type: "type"
      allowed_values {
        list {
          type: DT_UINT8
          type: DT_INT8
        }
      }
    }
    attr {
      name: "T2"
      type: "type"
      allowed_values {
        list {
          type: DT_STRING
          type: DT_FLOAT
          type: DT_DOUBLE
        }
      }
    }
    summary: "Summary for op Foo."
    description: "Description for op Foo."
  }
  )";

  std::unordered_set<string> type_annotate_ops{
      "Foo",
  };

  // TextFormat::ParseFromString clears `op_defs` before parsing, so the list
  // holds only the op defined above. Fail fast if the op def is malformed.
  OpList op_defs;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs));
  ApiDefMap api_def_map(op_defs);

  string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);

  const string typevar_foo = "TV_Foo_";
  const string def_foo = "def foo";
  ExpectSubstrOrder(code, typevar_foo, def_foo);
}
372 
// Non-tensor attrs are annotated too. In the public wrapper, defaulted attrs
// are rendered as `name:type=default`; in the eager fallback they are plain
// `name: type`. A `type` attr not used by any arg is annotated with its TypeVar.
TEST(PythonOpGen, TypeAnnotateDefaultParams) {
  constexpr char kBaseOpDef[] = R"(
  op {
    name: "FooBar"
    input_arg {
      name: "x"
      type: DT_FLOAT
    }
    output_arg {
      name: "output"
      type: DT_BOOL
    }
    attr {
      name: "t"
      type: "type"
      allowed_values {
        list {
          type: DT_HALF
          type: DT_INT8
        }
      }
    }
    attr {
      name: "var1"
      type: "bool"
      default_value {
        b: false
      }
    }
    attr {
      name: "var2"
      type: "int"
      default_value {
        i: 0
      }
    }
    summary: "Summary for op FooBar."
    description: "Description for op FooBar."
  }
  )";

  std::unordered_set<string> type_annotate_ops{
      "FooBar",
  };

  // TextFormat::ParseFromString clears `op_defs` before parsing, so the list
  // holds only the op defined above. Fail fast if the op def is malformed.
  OpList op_defs;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs));
  ApiDefMap api_def_map(op_defs);

  string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);

  const string params =
      "def foo_bar(x: _ops.Tensor[_dtypes.Float32], t: TV_FooBar_t, "
      "var1:bool=False, var2:int=0, name=None)";
  const string params_fallback =
      "def foo_bar_eager_fallback(x: _ops.Tensor[_dtypes.Float32], t: "
      "TV_FooBar_t, var1: bool, var2: int, name, ctx)";
  ExpectHasSubstr(code, params);
  ExpectHasSubstr(code, params_fallback);
}
434 
// A sequence-of-tensors input is left unannotated even when the op is in
// `type_annotate_ops`.
TEST(PythonOpGen, NoTypingSequenceTensors) {
  // NOTE(review): `inputs` sets both number_attr and type_list_attr, and "T" is
  // declared as "bool" rather than "list(type)". op_def.proto does not list
  // that as a legal ArgDef combination. The generator accepts it here, but
  // confirm this is intentional.
  constexpr char kBaseOpDef[] = R"(
  op {
    name: "Baz"
    input_arg {
      name: "inputs"
      number_attr: "N"
      type_list_attr: "T"
    }
    output_arg {
      name: "output1"
      type: DT_BOOL
    }
    output_arg {
      name: "output2"
      type: DT_BOOL
    }
    attr {
      name: "T"
      type: "bool"
    }
    attr {
      name: "N"
      type: "int"
    }
    summary: "Summary for op Baz."
    description: "Description for op Baz."
  }
  )";

  std::unordered_set<string> type_annotate_ops{"Baz"};

  // TextFormat::ParseFromString clears `op_defs` before parsing, so the list
  // holds only the op defined above. Fail fast if the op def is malformed.
  OpList op_defs;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs));
  ApiDefMap api_def_map(op_defs);

  string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);

  const string baz_def_line = "def baz(inputs, name=None):";

  ExpectHasSubstr(code, baz_def_line);
}
478 
479 }  // namespace
480 }  // namespace tensorflow
481