1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/python/framework/python_op_gen_internal.h"
17
18 #include <float.h>
19 #include <stdio.h>
20 #include <iomanip>
21 #include <sstream>
22 #include <unordered_map>
23 #include "tensorflow/core/framework/api_def.pb.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/op.h"
26 #include "tensorflow/core/framework/op_def.pb.h"
27 #include "tensorflow/core/framework/op_def.pb_text.h"
28 #include "tensorflow/core/framework/op_def_util.h"
29 #include "tensorflow/core/framework/op_gen_lib.h"
30 #include "tensorflow/core/framework/tensor.pb.h"
31 #include "tensorflow/core/framework/tensor.pb_text.h"
32 #include "tensorflow/core/framework/tensor_shape.pb.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/lib/gtl/map_util.h"
36 #include "tensorflow/core/lib/gtl/stl_util.h"
37 #include "tensorflow/core/lib/strings/str_util.h"
38 #include "tensorflow/core/lib/strings/strcat.h"
39 #include "tensorflow/core/lib/strings/stringprintf.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/macros.h"
42 #include "tensorflow/core/platform/types.h"
43
44 namespace tensorflow {
45 namespace python_op_gen_internal {
46
// Column limit used when word-wrapping generated Python source and docstring
// text (see the WordWrap() and AppendWithinWidth() calls in this file).
const int kRightMargin = 78;

// Names specified in tf_export decorators are exported to TensorFlow 2.0 by
// default. Ops/endpoints with a non-zero deprecation_version that is no
// greater than this value are left out of the v2 export names.
const int kLatestAPIExportVersion = 2;
51
// Returns true if `s` cannot be used verbatim as a Python identifier in
// generated code: it is a Python 2 or Python 3 keyword, or one of the
// capitalized / dunder built-in names listed below.
bool IsPythonReserved(const std::string& s) {
  // Heap-allocated and intentionally never freed, so no destructor runs
  // during static teardown.
  static const std::set<std::string>* const kPythonReserved =
      new std::set<std::string>(
          {// Keywords in Python, from:
           //   import keyword
           //   print keyword.kwlist
           // run under both Python 2 ("exec", "print") and Python 3
           // ("nonlocal"; "async"/"await" are hard keywords since 3.7).
           "and", "as", "assert", "async", "await", "break", "class",
           "continue", "def", "del", "elif", "else", "except", "exec",
           "finally", "for", "from", "global", "if", "import", "in", "is",
           "lambda", "nonlocal", "not", "or", "pass", "print", "raise",
           "return", "try", "while", "with", "yield",
           // Built-in functions and types in Python, from:
           //   [x for x in dir(__builtins__) if not x[0].islower()]
           "ArithmeticError", "AssertionError", "AttributeError",
           "BaseException", "BufferError", "BytesWarning",
           "DeprecationWarning", "EOFError", "Ellipsis", "EnvironmentError",
           "Exception", "False", "FloatingPointError", "FutureWarning",
           "GeneratorExit", "IOError", "ImportError", "ImportWarning",
           "IndentationError", "IndexError", "KeyError", "KeyboardInterrupt",
           "LookupError", "MemoryError", "NameError", "None", "NotImplemented",
           "NotImplementedError", "OSError", "OverflowError",
           "PendingDeprecationWarning", "ReferenceError", "RuntimeError",
           "RuntimeWarning", "StandardError", "StopIteration", "SyntaxError",
           "SyntaxWarning", "SystemError", "SystemExit", "TabError", "True",
           "TypeError", "UnboundLocalError", "UnicodeDecodeError",
           "UnicodeEncodeError", "UnicodeError", "UnicodeTranslateError",
           "UnicodeWarning", "UserWarning", "ValueError", "Warning",
           "ZeroDivisionError", "__debug__", "__doc__", "__import__",
           "__name__", "__package__"});

  return kPythonReserved->count(s) > 0;
}
81
IsOpWithUnderscorePrefix(const string & s)82 bool IsOpWithUnderscorePrefix(const string& s) {
83 static const std::set<string>* const kUnderscoreOps = new std::set<string>(
84 {// Lowercase built-in functions and types in Python, from:
85 // [x for x in dir(__builtins__) if x[0].islower()] except "round".
86 // These need to be excluded so they don't conflict with actual built-in
87 // functions since we use '*' imports.
88 "abs", "all", "any", "apply", "bin", "bool", "buffer", "bytearray",
89 "bytes", "callable", "chr", "classmethod", "cmp", "coerce", "compile",
90 "complex", "copyright", "credits", "delattr", "dict", "dir", "divmod",
91 "enumerate", "eval", "execfile", "exit", "file", "filter", "float",
92 "format", "frozenset", "getattr", "globals", "hasattr", "hash", "help",
93 "hex", "id", "input", "int", "intern", "isinstance", "issubclass",
94 "iter", "len", "license", "list", "locals", "long", "map", "max",
95 "memoryview", "min", "next", "object", "oct", "open", "ord", "pow",
96 "print", "property", "quit", "range", "raw_input", "reduce", "reload",
97 "repr", "reversed", "set", "setattr", "slice", "sorted", "staticmethod",
98 "str", "sum", "super", "tuple", "type", "unichr", "unicode", "vars",
99 "xrange", "zip",
100 // These have the same name as ops defined in Python and might be used
101 // incorrectly depending on order of '*' imports.
102 // TODO(annarev): reduce usage of '*' imports and remove these from the
103 // list.
104 "fused_batch_norm", "histogram_fixed_width", "stack",
105 "batch_norm_with_global_normalization", "clip_by_value"});
106 return kUnderscoreOps->count(s) > 0;
107 }
108
// Returns `s` unchanged, or with a trailing underscore appended when it is
// reserved in Python (e.g. "lambda" -> "lambda_").
string AvoidPythonReserved(const string& s) {
  return IsPythonReserved(s) ? strings::StrCat(s, "_") : s;
}
113
114 // Indent the first line by "initial" spaces and all following lines
115 // by "rest" spaces.
Indent(int initial,int rest,StringPiece in)116 string Indent(int initial, int rest, StringPiece in) {
117 // TODO(josh11b): Also word-wrapping?
118 string copy(in.data(), in.size());
119 str_util::StripTrailingWhitespace(©);
120 std::vector<string> v = str_util::Split(copy, '\n');
121
122 string result;
123 bool first = true;
124 for (const string& line : v) {
125 if (first) {
126 result = strings::StrCat(Spaces(initial), line, "\n");
127 first = false;
128 } else {
129 if (line.empty()) {
130 strings::StrAppend(&result, "\n");
131 } else {
132 strings::StrAppend(&result, Spaces(rest), line, "\n");
133 }
134 }
135 }
136 return result;
137 }
138
139 // Adds append to *dest, with a space if the first line will be <= width,
140 // or a newline otherwise.
AppendWithinWidth(string * dest,StringPiece append,int width)141 void AppendWithinWidth(string* dest, StringPiece append, int width) {
142 auto first_line = append.find('\n');
143 if (first_line == string::npos) first_line = append.size();
144 if (dest->size() + first_line + 1 /* space */ > static_cast<size_t>(width)) {
145 strings::StrAppend(dest, "\n", append);
146 } else {
147 strings::StrAppend(dest, " ", append);
148 }
149 }
150
151 // Like DataTypeString() but uses the Python names for the
152 // float types.
PythonDataTypeString(DataType dtype)153 string PythonDataTypeString(DataType dtype) {
154 switch (dtype) {
155 case DT_FLOAT:
156 return "float32";
157 case DT_DOUBLE:
158 return "float64";
159 default:
160 return DataTypeString(dtype);
161 }
162 }
163
TypeString(DataType dtype,bool ref)164 string TypeString(DataType dtype, bool ref) {
165 if (ref) {
166 return strings::StrCat("mutable `", PythonDataTypeString(dtype), "`");
167 } else {
168 return strings::StrCat("`", PythonDataTypeString(dtype), "`");
169 }
170 }
171
TypeListString(const AttrValue & value)172 string TypeListString(const AttrValue& value) {
173 string ret;
174 for (int t : value.list().type()) {
175 if (!ret.empty()) strings::StrAppend(&ret, ", ");
176 DataType dtype = static_cast<DataType>(t);
177 if (IsRefType(dtype)) {
178 strings::StrAppend(&ret, PythonDataTypeString(RemoveRefType(dtype)),
179 " mutable");
180 } else {
181 strings::StrAppend(&ret, "`", PythonDataTypeString(dtype), "`");
182 }
183 }
184 return ret;
185 }
186
SingleTensorName(DataType dtype,bool is_ref)187 string SingleTensorName(DataType dtype, bool is_ref) {
188 const string type_str = TypeString(dtype, is_ref);
189 return strings::StrCat("A `Tensor` of type ", type_str, ".");
190 }
191
// Generic description used when an arg's type cannot be described precisely.
const char kUnknownTensorType[] = "A `Tensor`.";
193
ArgTypeName(const OpDef & op_def,const OpDef::ArgDef & arg,const std::unordered_map<string,string> & inferred_attrs,bool is_output)194 string ArgTypeName(const OpDef& op_def, const OpDef::ArgDef& arg,
195 const std::unordered_map<string, string>& inferred_attrs,
196 bool is_output) {
197 if (!arg.number_attr().empty()) {
198 // N Tensors with the same type
199 const string* original_arg =
200 gtl::FindOrNull(inferred_attrs, arg.number_attr());
201 string prefix;
202 if (original_arg == nullptr) {
203 prefix = strings::StrCat("A list of `", arg.number_attr(), "`");
204 } else if (*original_arg == arg.name()) {
205 const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
206 if (attr->has_minimum() && attr->minimum() > 0) {
207 prefix = strings::StrCat("A list of at least ", attr->minimum());
208 } else {
209 prefix = "A list of";
210 }
211 } else {
212 prefix = strings::StrCat("A list with the same length as `",
213 AvoidPythonReserved(*original_arg), "` of");
214 }
215
216 if (arg.type() != DT_INVALID) {
217 return strings::StrCat(prefix, " `Tensor` objects with type ",
218 TypeString(arg.type(), arg.is_ref()), ".");
219 } else {
220 original_arg = gtl::FindOrNull(inferred_attrs, arg.type_attr());
221 if (arg.is_ref()) {
222 strings::StrAppend(&prefix, " mutable");
223 }
224 if (original_arg == nullptr) {
225 return strings::StrCat(prefix, " `Tensor` objects with type `",
226 arg.type_attr(), "`.");
227 } else if (*original_arg == arg.name()) {
228 const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
229 if (attr->has_allowed_values()) {
230 return strings::StrCat(prefix,
231 " `Tensor` objects with the same type in: ",
232 TypeListString(attr->allowed_values()), ".");
233 } else {
234 return strings::StrCat(prefix,
235 " `Tensor` objects with the same type.");
236 }
237 } else {
238 return strings::StrCat(prefix,
239 " `Tensor` objects with the same type as `",
240 AvoidPythonReserved(*original_arg), "`.");
241 }
242 }
243 } else if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) {
244 const bool is_list = !arg.type_list_attr().empty();
245 const string attr_name = is_list ? arg.type_list_attr() : arg.type_attr();
246 const OpDef::AttrDef* attr = FindAttr(attr_name, op_def);
247 const string mutable_str = arg.is_ref() ? "mutable " : "";
248 const string prefix =
249 is_list ? strings::StrCat("A list of ", mutable_str, "`Tensor` objects")
250 : strings::StrCat("A ", mutable_str, "`Tensor`");
251 const string* original_arg = gtl::FindOrNull(inferred_attrs, attr_name);
252 if (original_arg == nullptr) {
253 return strings::StrCat(prefix, " of type `", attr_name, "`.");
254 } else if (*original_arg == arg.name()) {
255 if (attr->has_allowed_values()) {
256 if (is_list) {
257 return strings::StrCat(prefix, " with types from: ",
258 TypeListString(attr->allowed_values()), ".");
259 } else {
260 return strings::StrCat(
261 prefix, is_output ? ". Has one of the following types: "
262 : ". Must be one of the following types: ",
263 TypeListString(attr->allowed_values()), ".");
264 }
265 } else {
266 return strings::StrCat(prefix, ".");
267 }
268 } else {
269 return strings::StrCat(prefix,
270 is_output ? ". Has the same type as `"
271 : ". Must have the same type as `",
272 AvoidPythonReserved(*original_arg), "`.");
273 }
274 } else {
275 return SingleTensorName(arg.type(), arg.is_ref());
276 }
277 }
278
GetReturns(const OpDef & op_def,const std::vector<string> & output_type_string)279 string GetReturns(const OpDef& op_def,
280 const std::vector<string>& output_type_string) {
281 string result;
282 DCHECK_EQ(op_def.output_arg_size(), output_type_string.size());
283 const int num_outs = op_def.output_arg_size();
284 strings::StrAppend(&result, "\n Returns:\n");
285 if (num_outs == 0) {
286 strings::StrAppend(&result, " The created Operation.\n");
287 } else {
288 if (num_outs == 1) {
289 StringPiece description = op_def.output_arg(0).description();
290 if (ConsumeEquals(&description)) { // Skip the generated type info.
291 strings::StrAppend(&result, Indent(4, 4, description));
292 } else {
293 // Special case of one output, don't use the name of the output unless
294 // there is no description.
295 string desc = output_type_string.empty() ? kUnknownTensorType
296 : output_type_string[0];
297 if (desc == kUnknownTensorType) {
298 // Special case where we don't understand how the output tensor type
299 // depends on the input tensor types, just use the output arg
300 // description if we can.
301 if (!description.empty()) {
302 desc = op_def.output_arg(0).description();
303 } else if (!op_def.output_arg(0).name().empty()) {
304 desc = strings::StrCat(" The ", op_def.output_arg(0).name(),
305 " `Tensor`.");
306 }
307 } else if (!description.empty()) {
308 AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
309 }
310 strings::StrAppend(&result, Indent(4, 4, desc));
311 }
312 } else {
313 std::vector<string> out_names(num_outs);
314 for (int i = 0; i < num_outs; ++i) {
315 if (!op_def.output_arg(i).name().empty()) {
316 out_names[i] = op_def.output_arg(i).name();
317 } else {
318 out_names[i] = strings::StrCat("output", i);
319 }
320 }
321 strings::StrAppend(&result, " A tuple of `Tensor` objects (",
322 str_util::Join(out_names, ", "), ").\n\n");
323 for (int i = 0; i < num_outs; ++i) {
324 string desc = strings::StrCat(out_names[i], ": ");
325 StringPiece description = op_def.output_arg(i).description();
326 if (ConsumeEquals(&description)) { // Skip the generated type info.
327 strings::StrAppend(&desc, description);
328 } else {
329 const string type = static_cast<size_t>(i) < output_type_string.size()
330 ? output_type_string[i]
331 : kUnknownTensorType;
332 if (!description.empty()) {
333 if (type == kUnknownTensorType) {
334 // Special case where we don't understand how the output tensor
335 // type depends on the input tensor types, so we just use the
336 // output arg description.
337 strings::StrAppend(&desc, description);
338 } else {
339 strings::StrAppend(&desc, type, " ", description);
340 }
341 } else {
342 strings::StrAppend(&desc, type);
343 }
344 }
345 strings::StrAppend(&result, Indent(4, 6, desc));
346 }
347 }
348 }
349 return result;
350 }
351
// Returns `str` as a double-quoted literal whose contents are escaped with
// str_util::CEscape().
string StringToPython(const string& str) {
  const string escaped = str_util::CEscape(str);
  return strings::StrCat("\"", escaped, "\"");
}
355
DataTypeToPython(DataType dtype,const string & dtype_module)356 string DataTypeToPython(DataType dtype, const string& dtype_module) {
357 return strings::StrCat(dtype_module, PythonDataTypeString(dtype));
358 }
359
ShapeToPython(const TensorShapeProto & shape)360 string ShapeToPython(const TensorShapeProto& shape) {
361 if (shape.unknown_rank()) {
362 return "None";
363 }
364 string python = "[";
365 for (const auto& dim : shape.dim()) {
366 if (python.size() > 1) strings::StrAppend(&python, ", ");
367 if (!dim.name().empty()) {
368 strings::StrAppend(&python, "(", StringToPython(dim.name()), ", ",
369 dim.size(), ")");
370 } else {
371 strings::StrAppend(&python, dim.size());
372 }
373 }
374 strings::StrAppend(&python, "]");
375 return python;
376 }
377
TensorToPython(const TensorProto & proto)378 string TensorToPython(const TensorProto& proto) {
379 return ProtoShortDebugString(proto);
380 }
381
// Renders the elements of `value.list()` as comma-separated Python values,
// without the enclosing brackets. Only the first non-empty field is used,
// checked in the order s, i, f, b, type, shape, tensor, func; an empty list
// yields "". `dtype_module` prefixes the names of listed dtypes.
string AttrListToPython(const AttrValue& value,
                        const string& dtype_module = "tf.") {
  const auto& list = value.list();
  string ret;
  bool first = true;
  // Appends one rendered element, preceded by ", " unless it is the first.
  auto append_item = [&ret, &first](const string& item) {
    if (!first) strings::StrAppend(&ret, ", ");
    strings::StrAppend(&ret, item);
    first = false;
  };

  if (list.s_size() > 0) {
    for (const string& s : list.s()) append_item(StringToPython(s));
  } else if (list.i_size() > 0) {
    for (const auto i : list.i()) append_item(strings::StrCat(i));
  } else if (list.f_size() > 0) {
    for (const float f : list.f()) append_item(strings::StrCat(f));
  } else if (list.b_size() > 0) {
    for (const bool b : list.b()) append_item(b ? "True" : "False");
  } else if (list.type_size() > 0) {
    for (const int t : list.type()) {
      append_item(DataTypeToPython(static_cast<DataType>(t), dtype_module));
    }
  } else if (list.shape_size() > 0) {
    for (const auto& shape : list.shape()) append_item(ShapeToPython(shape));
  } else if (list.tensor_size() > 0) {
    for (const auto& tensor : list.tensor()) {
      append_item(TensorToPython(tensor));
    }
  } else if (list.func_size() > 0) {
    for (const auto& func : list.func()) {
      append_item(StringToPython(func.name()));
    }
  }
  return ret;
}
429
430 // NOTE: The return value may contain spaces (for example, it could be
431 // a string "foo bar" with an embedded space) and is not safe to pass
432 // to WordWrap().
AttrValueToPython(const string & type,const AttrValue & value,const string & dtype_module)433 string AttrValueToPython(const string& type, const AttrValue& value,
434 const string& dtype_module) {
435 if (type == "string") {
436 return StringToPython(value.s());
437 } else if (type == "int") {
438 return strings::StrCat(value.i());
439 } else if (type == "float") {
440 if (std::isnan(value.f()) || std::isinf(value.f())) {
441 return strings::StrCat("float('", value.f(), "')");
442 } else {
443 // Use locale-independent conversion.
444 static_assert(FLT_DIG < 10, "FLT_DIG is too big");
445 std::ostringstream s;
446 s.imbue(std::locale::classic());
447 s << std::setprecision(FLT_DIG) << value.f();
448 return s.str();
449 }
450 } else if (type == "bool") {
451 return value.b() ? "True" : "False";
452 } else if (type == "type") {
453 return DataTypeToPython(value.type(), dtype_module);
454 } else if (type == "shape") {
455 return ShapeToPython(value.shape());
456 } else if (type == "tensor") {
457 return TensorToPython(value.tensor());
458 } else if (type == "func") {
459 return StringToPython(value.func().name());
460 } else if (str_util::StartsWith(type, "list(")) {
461 return strings::StrCat("[", AttrListToPython(value, dtype_module), "]");
462 } else {
463 return "?";
464 }
465 }
466
// Appends the snake_case form of the CamelCase op name `str` to *result
// (which is not cleared first), e.g. "MatMul" -> "mat_mul",
// "CTCLoss" -> "ctc_loss", "Conv2D" -> "conv2d". A '_' is emitted before an
// uppercase letter only on a lower-to-upper transition, or when the uppercase
// letter starts a new word inside a run of capitals (the next char is lower).
void GenerateLowerCaseOpName(const std::string& str, std::string* result) {
  const char joiner = '_';
  const size_t size = str.size();
  // The <cctype> classifiers have undefined behavior for negative values other
  // than EOF, so each byte is widened through unsigned char first.
  auto byte_at = [&str](size_t i) {
    return static_cast<unsigned char>(str[i]);
  };
  for (size_t i = 0; i < size; ++i) {
    const unsigned char c = byte_at(i);
    if (i > 0 && isupper(c)) {
      const bool prev_is_lower = islower(byte_at(i - 1)) != 0;
      const bool next_is_lower = i + 1 < size && islower(byte_at(i + 1)) != 0;
      if (prev_is_lower || next_is_lower) result->push_back(joiner);
    }
    result->push_back(static_cast<char>(tolower(c)));
  }
}
482
AddDelimiter(string * append_to,const string & delim)483 static void AddDelimiter(string* append_to, const string& delim) {
484 if (!append_to->empty()) strings::StrAppend(append_to, delim);
485 }
486
FindAttr(StringPiece name,const ApiDef & api_def)487 const ApiDef::Attr* FindAttr(StringPiece name, const ApiDef& api_def) {
488 for (int i = 0; i < api_def.attr_size(); ++i) {
489 if (api_def.attr(i).name() == name) {
490 return &api_def.attr(i);
491 }
492 }
493 return nullptr;
494 }
495
GenPythonOp(const OpDef & op_def,const ApiDef & api_def,const string & function_name)496 GenPythonOp::GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
497 const string& function_name)
498 : op_def_(op_def),
499 api_def_(api_def),
500 function_name_(function_name),
501 num_outs_(op_def.output_arg_size()) {}
502
~GenPythonOp()503 GenPythonOp::~GenPythonOp() {}
504
Code()505 string GenPythonOp::Code() {
506 // This has all the input args followed by those attrs that don't have
507 // defaults.
508 std::vector<ParamNames> params_no_default;
509 // The parameters with defaults (these have to be listed after those without).
510 // No input args are included, just attrs.
511 std::vector<ParamNames> params_with_default;
512
513 for (int i = 0; i < api_def_.arg_order_size(); ++i) {
514 const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
515 const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
516 params_no_default.emplace_back(api_def_arg.name(), api_def_arg.rename_to());
517 if (!arg.type_attr().empty()) {
518 gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_attr(), arg.name());
519 } else if (!arg.type_list_attr().empty()) {
520 gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_list_attr(),
521 arg.name());
522 }
523 if (!arg.number_attr().empty()) {
524 gtl::InsertIfNotPresent(&inferred_attrs_, arg.number_attr(), arg.name());
525 }
526 }
527 for (int i = 0; i < api_def_.attr_size(); ++i) {
528 const auto& attr(api_def_.attr(i));
529 // Do not add inferred attrs to the Python function signature.
530 if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
531 if (attr.has_default_value()) {
532 params_with_default.emplace_back(attr.name(), attr.rename_to());
533 } else {
534 params_no_default.emplace_back(attr.name(), attr.rename_to());
535 }
536 }
537 }
538
539 // Save the list of attr parameters (attrs that won't be inferred),
540 // those with defaults go at the end.
541 // Get the attrs in the order we want by taking the attrs without defaults
542 // from the end of args_no_default, and adding args_no_default.
543 attrs_.reserve(params_no_default.size() - op_def_.input_arg_size() +
544 params_with_default.size());
545 for (int i = op_def_.input_arg_size(); i < params_no_default.size(); ++i) {
546 attrs_.push_back(params_no_default[i].GetName());
547 }
548 for (int i = 0; i < params_with_default.size(); ++i) {
549 attrs_.push_back(params_with_default[i].GetName());
550 }
551
552 param_names_.reserve(params_no_default.size() + params_with_default.size());
553 param_names_.insert(param_names_.begin(), params_no_default.begin(),
554 params_no_default.end());
555 for (const auto& param : params_with_default) {
556 param_names_.push_back(param);
557 }
558
559 string parameters;
560 for (const auto& param : params_no_default) {
561 AddDelimiter(¶meters, ", ");
562 strings::StrAppend(¶meters, param.GetRenameTo());
563 }
564 for (const auto& param_and_default : params_with_default) {
565 AddDelimiter(¶meters, ", ");
566 strings::StrAppend(¶meters, param_and_default.GetRenameTo(), "=None");
567 }
568 AddDelimiter(¶meters, ", ");
569 strings::StrAppend(¶meters, "name=None");
570
571 AddExport();
572 AddDefLine(parameters);
573 AddDocStringDescription();
574 AddDocStringArgs();
575 AddDocStringInputs();
576 AddDocStringAttrs();
577 AddDocStringNameArg();
578 AddOutputGlobals();
579 AddDocStringOutputs();
580 strings::StrAppend(&result_, " \"\"\"\n");
581 AddBody(" ");
582 strings::StrAppend(&result_, "\n\n");
583
584 return prelude_ + result_;
585 }
586
AddExport()587 void GenPythonOp::AddExport() {
588 if (api_def_.visibility() != ApiDef::VISIBLE) {
589 return;
590 }
591 // Whether op should be available in latest export version.
592 bool op_available_in_latest =
593 !api_def_.deprecation_version() ||
594 api_def_.deprecation_version() > kLatestAPIExportVersion;
595
596 string names;
597 string names_v1;
598 string deprecated_endpoints;
599
600 for (const auto& endpoint : api_def_.endpoint()) {
601 string endpoint_name;
602 python_op_gen_internal::GenerateLowerCaseOpName(endpoint.name(),
603 &endpoint_name);
604 if (endpoint.deprecated() || endpoint.deprecation_version() > 0) {
605 AddDelimiter(&deprecated_endpoints, ", ");
606 strings::StrAppend(&deprecated_endpoints, "'", endpoint_name, "'");
607 }
608 // Add all endpoints to TensorFlow 1.* API.
609 AddDelimiter(&names_v1, ", ");
610 strings::StrAppend(&names_v1, "'", endpoint_name, "'");
611 // Add non-deprecated endpoints to TensorFlow 2.* API.
612 if (op_available_in_latest &&
613 (!endpoint.deprecation_version() ||
614 endpoint.deprecation_version() > kLatestAPIExportVersion)) {
615 AddDelimiter(&names, ", ");
616 strings::StrAppend(&names, "'", endpoint_name, "'");
617 }
618 }
619
620 // tf_export decorator has the following format:
621 // @tf_export(v2_name, v2_name, v1=[v1_name, v1_name])
622 if (names != names_v1) {
623 AddDelimiter(&names, ", ");
624 strings::StrAppend(&names, "v1=[", names_v1, "]");
625 }
626 strings::StrAppend(&result_, "@tf_export(", names, ")\n");
627
628 // If all endpoints are deprecated, add @deprecated decorator.
629 if (!api_def_.deprecation_message().empty()) {
630 const string instructions = api_def_.deprecation_message();
631 strings::StrAppend(&result_, "@deprecated(None, '", instructions, "')\n");
632 }
633 // Add @deprecated_endpoints decorator.
634 if (!deprecated_endpoints.empty()) {
635 strings::StrAppend(&result_, "@deprecated_endpoints(", deprecated_endpoints,
636 ")\n");
637 }
638 }
639
AddDefLine(const string & function_name,const string & parameters)640 void GenPythonOp::AddDefLine(const string& function_name,
641 const string& parameters) {
642 strings::StrAppend(&result_, "def ", function_name, "(", parameters, "):\n");
643 }
644
AddDefLine(const string & parameters)645 void GenPythonOp::AddDefLine(const string& parameters) {
646 AddDefLine(function_name_, parameters);
647 }
648
AddDocStringDescription()649 void GenPythonOp::AddDocStringDescription() {
650 string comment;
651 if (api_def_.summary().empty()) {
652 comment = "TODO: add doc.\n";
653 } else {
654 comment = strings::StrCat(api_def_.summary(), "\n");
655 if (!api_def_.description().empty()) {
656 strings::StrAppend(&comment, "\n", Indent(2, 2, api_def_.description()));
657 }
658 }
659 strings::StrAppend(&result_, " r\"\"\"", comment, "\n");
660 }
661
AddDocStringArgs()662 void GenPythonOp::AddDocStringArgs() {
663 strings::StrAppend(&result_, " Args:\n");
664 }
665
AddDocStringInputs()666 void GenPythonOp::AddDocStringInputs() {
667 for (int i = 0; i < api_def_.arg_order_size(); ++i) {
668 const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
669 const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
670 StringPiece description = api_def_arg.description();
671 string desc;
672 if (ConsumeEquals(&description)) { // Skip the generated type info.
673 desc = strings::StrCat(param_names_[i].GetRenameTo(), ": ");
674 } else {
675 desc = strings::StrCat(param_names_[i].GetRenameTo(), ": ",
676 ArgTypeName(op_def_, arg, inferred_attrs_, false));
677 }
678 if (!description.empty()) {
679 AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
680 }
681 strings::StrAppend(&result_, Indent(4, 6, desc));
682 }
683 }
684
AddDocStringAttrs()685 void GenPythonOp::AddDocStringAttrs() {
686 for (const string& name : attrs_) {
687 const auto& attr = *FindAttr(name, op_def_);
688 const auto& api_def_attr = *FindAttr(name, api_def_);
689 string desc =
690 strings::StrCat(AvoidPythonReserved(api_def_attr.rename_to()), ": ");
691
692 static const char* const kAttrTypeName[][2] = {
693 {"string", "`string`"},
694 {"list(string)", "list of `strings`"},
695 {"int", "`int`"},
696 {"list(int)", "list of `ints`"},
697 {"float", "`float`"},
698 {"list(float)", "list of `floats`"},
699 {"bool", "`bool`"},
700 {"list(bool)", "list of `bools`"},
701 {"type", "`tf.DType`"},
702 {"list(type)", "list of `tf.DTypes`"},
703 {"shape", "`tf.TensorShape` or list of `ints`"},
704 {"list(shape)",
705 "list of shapes (each a `tf.TensorShape` or list of `ints`)"},
706 {"tensor", "`tf.TensorProto`"},
707 {"list(tensor)", "list of `tf.TensorProto` objects"},
708 {"func", "function decorated with @Defun"},
709 {"list(func)", "list of functions decorated with @Defun"},
710 };
711 for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) {
712 if (attr.type() == kAttrTypeName[i][0]) {
713 string s;
714 if (api_def_attr.has_default_value()) {
715 s = strings::StrCat("optional ", kAttrTypeName[i][1]);
716 } else {
717 s = kAttrTypeName[i][1];
718 }
719 if (s[0] == 'o' || (s[0] == '`' && (s[1] == 'i' || s[1] == 'o'))) {
720 strings::StrAppend(&desc, "An ", s);
721 } else {
722 strings::StrAppend(&desc, "A ", s);
723 }
724 break;
725 }
726 }
727
728 if (attr.has_allowed_values()) {
729 strings::StrAppend(&desc, " from: `",
730 AttrListToPython(attr.allowed_values()), "`");
731 }
732
733 if (attr.has_minimum()) {
734 if (attr.type() == "int") {
735 strings::StrAppend(&desc, " that is `>= ", attr.minimum(), "`");
736 } else if (attr.minimum() > 0) {
737 strings::StrAppend(&desc, " that has length `>= ", attr.minimum(), "`");
738 }
739 }
740
741 strings::StrAppend(&desc, ".");
742
743 if (api_def_attr.has_default_value()) {
744 strings::StrAppend(
745 &desc, " Defaults to `",
746 AttrValueToPython(attr.type(), api_def_attr.default_value()), "`.");
747 }
748 if (!api_def_attr.description().empty()) {
749 AppendWithinWidth(&desc, api_def_attr.description(),
750 kRightMargin - 4 /* indent */);
751 }
752 strings::StrAppend(&result_, Indent(4, 6, desc));
753 }
754 }
755
AddDocStringNameArg()756 void GenPythonOp::AddDocStringNameArg() {
757 strings::StrAppend(&result_,
758 " name: A name for the operation (optional).\n");
759 }
760
AddOutputGlobals()761 void GenPythonOp::AddOutputGlobals() {
762 // Prepare a NamedTuple type to hold the outputs, if there are multiple
763 if (num_outs_ > 1) {
764 // Prepare the list of output names
765 std::vector<string> out_names(num_outs_);
766 for (int i = 0; i < num_outs_; ++i) {
767 if (!api_def_.out_arg(i).rename_to().empty()) {
768 out_names[i] = api_def_.out_arg(i).rename_to();
769 } else {
770 out_names[i] = strings::StrCat("output", i);
771 }
772 }
773 string out_names_list =
774 strings::StrCat("[\"", str_util::Join(out_names, "\", \""), "\"]");
775
776 // Provide the output names as a Python list
777 string lower_op_name_outputs =
778 strings::StrCat("_", function_name_, "_outputs");
779 const string outputs_prefix = strings::StrCat(lower_op_name_outputs, " = ");
780 strings::StrAppend(&prelude_, "\n",
781 WordWrap(outputs_prefix, out_names_list, kRightMargin),
782 "\n");
783
784 strings::StrAppend(&prelude_, "_", op_def_.name(),
785 "Output = _collections.namedtuple(\n");
786 const string tuple_type_prefix = " ";
787 const string tuple_type_suffix = strings::StrCat(
788 "\"", op_def_.name(), "\", ", lower_op_name_outputs, ")");
789 strings::StrAppend(
790 &prelude_, WordWrap(tuple_type_prefix, tuple_type_suffix, kRightMargin),
791 "\n\n");
792 }
793 strings::StrAppend(&prelude_, "\n");
794 }
795
AddDocStringOutputs()796 void GenPythonOp::AddDocStringOutputs() {
797 std::vector<string> output_type_string;
798 output_type_string.reserve(num_outs_);
799 for (int i = 0; i < num_outs_; ++i) {
800 output_type_string.push_back(
801 ArgTypeName(op_def_, op_def_.output_arg(i), inferred_attrs_, true));
802 }
803 strings::StrAppend(&result_, GetReturns(op_def_, output_type_string));
804 }
805
AddBody(const string & prefix)806 void GenPythonOp::AddBody(const string& prefix) {
807 const string apply_prefix = strings::StrCat(
808 prefix, "_result = _op_def_lib.apply_op(\"", op_def_.name(), "\", ");
809 AddBodyNoReturn(apply_prefix);
810 if (num_outs_ > 1) {
811 strings::StrAppend(&result_, prefix, "_result = _", op_def_.name(),
812 "Output._make(_result)\n");
813 }
814 strings::StrAppend(&result_, prefix, "return _result\n");
815 }
816
AddBodyNoReturn(const string & apply_prefix)817 void GenPythonOp::AddBodyNoReturn(const string& apply_prefix) {
818 string args;
819 for (size_t i = 0; i < param_names_.size(); ++i) {
820 strings::StrAppend(&args, AvoidPythonReserved(param_names_[i].GetName()),
821 "=", param_names_[i].GetRenameTo(), ", ");
822 }
823 strings::StrAppend(&args, "name=name)");
824
825 strings::StrAppend(&result_,
826 // Wrap the arguments, and indent to the (.
827 WordWrap(apply_prefix, args, kRightMargin), "\n");
828 }
829
830 } // namespace python_op_gen_internal
831 } // namespace tensorflow
832