1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/python/framework/python_op_gen_internal.h"
17
18 #include <float.h>
19 #include <stdio.h>
20
21 #include <iomanip>
22 #include <sstream>
23 #include <unordered_map>
24
25 #include "absl/strings/escaping.h"
26 #include "absl/strings/str_replace.h"
27 #include "tensorflow/core/framework/api_def.pb.h"
28 #include "tensorflow/core/framework/attr_value.pb.h"
29 #include "tensorflow/core/framework/op.h"
30 #include "tensorflow/core/framework/op_def.pb.h"
31 #include "tensorflow/core/framework/op_def_util.h"
32 #include "tensorflow/core/framework/op_gen_lib.h"
33 #include "tensorflow/core/framework/tensor.pb.h"
34 #include "tensorflow/core/framework/tensor_shape.pb.h"
35 #include "tensorflow/core/framework/types.h"
36 #include "tensorflow/core/framework/types.pb.h"
37 #include "tensorflow/core/lib/gtl/map_util.h"
38 #include "tensorflow/core/lib/strings/str_util.h"
39 #include "tensorflow/core/lib/strings/strcat.h"
40 #include "tensorflow/core/lib/strings/stringprintf.h"
41 #include "tensorflow/core/platform/logging.h"
42 #include "tensorflow/core/platform/macros.h"
43 #include "tensorflow/core/platform/types.h"
44
45 namespace tensorflow {
46 namespace python_op_gen_internal {
47
// Width budget for generated Python text; passed to WordWrap() and (minus the
// docstring indent) to AppendWithinWidth() further down in this file.
const int kRightMargin = 78;
// Names specified in tf_export decorators are exported to
// TensorFlow 2.0 by default. AddExport() compares deprecation versions
// against this value to decide which endpoints remain in the latest API.
const int kLatestAPIExportVersion = 2;
52
// Returns true when `s` exactly matches one of the Python keywords or
// built-in names listed below (the match is case-sensitive).
bool IsPythonReserved(const string& s) {
  // Allocated once on first use and intentionally never freed.
  static const std::set<string>& reserved_names = *new std::set<string>(
      {// Keywords in Python, from:
       //   import keyword
       //   print keyword.kwlist
       "and", "as", "assert", "break", "class", "continue", "def", "del",
       "elif", "else", "except", "exec", "finally", "for", "from", "global",
       "if", "import", "in", "is", "lambda", "not", "or", "pass", "print",
       "raise", "return", "try", "while", "with", "yield",
       // Built-in functions and types in Python, from:
       //   [x for x in dir(__builtins__) if not x[0].islower()]
       "ArithmeticError", "AssertionError", "AttributeError", "BaseException",
       "BufferError", "BytesWarning", "DeprecationWarning", "EOFError",
       "Ellipsis", "EnvironmentError", "Exception", "False",
       "FloatingPointError", "FutureWarning", "GeneratorExit", "IOError",
       "ImportError", "ImportWarning", "IndentationError", "IndexError",
       "KeyError", "KeyboardInterrupt", "LookupError", "MemoryError",
       "NameError", "None", "NotImplemented", "NotImplementedError", "OSError",
       "OverflowError", "PendingDeprecationWarning", "ReferenceError",
       "RuntimeError", "RuntimeWarning", "StandardError", "StopIteration",
       "SyntaxError", "SyntaxWarning", "SystemError", "SystemExit", "TabError",
       "True", "TypeError", "UnboundLocalError", "UnicodeDecodeError",
       "UnicodeEncodeError", "UnicodeError", "UnicodeTranslateError",
       "UnicodeWarning", "UserWarning", "ValueError", "Warning",
       "ZeroDivisionError", "__debug__", "__doc__", "__import__", "__name__",
       "__package__"});

  return reserved_names.find(s) != reserved_names.end();
}
82
// Returns true when `s` exactly matches one of the names listed below:
// lowercase Python built-ins, plus a handful of op names that are also
// defined in Python.
bool IsOpWithUnderscorePrefix(const string& s) {
  // Allocated once on first use and intentionally never freed.
  static const std::set<string>& underscore_ops = *new std::set<string>(
      {// Lowercase built-in functions and types in Python, from:
       // [x for x in dir(__builtins__) if x[0].islower()] except "round".
       // These need to be excluded so they don't conflict with actual built-in
       // functions since we use '*' imports.
       "abs", "all", "any", "apply", "bin", "bool", "buffer", "bytearray",
       "bytes", "callable", "chr", "classmethod", "cmp", "coerce", "compile",
       "complex", "copyright", "credits", "delattr", "dict", "dir", "divmod",
       "enumerate", "eval", "execfile", "exit", "file", "filter", "float",
       "format", "frozenset", "getattr", "globals", "hasattr", "hash", "help",
       "hex", "id", "input", "int", "intern", "isinstance", "issubclass",
       "iter", "len", "license", "list", "locals", "long", "map", "max",
       "memoryview", "min", "next", "object", "oct", "open", "ord", "pow",
       "print", "property", "quit", "range", "raw_input", "reduce", "reload",
       "repr", "reversed", "set", "setattr", "slice", "sorted", "staticmethod",
       "str", "sum", "super", "tuple", "type", "unichr", "unicode", "vars",
       "xrange", "zip",
       // These have the same name as ops defined in Python and might be used
       // incorrectly depending on order of '*' imports.
       // TODO(annarev): reduce usage of '*' imports and remove these from the
       // list.
       "fused_batch_norm", "histogram_fixed_width", "stack",
       "batch_norm_with_global_normalization", "clip_by_value"});
  return underscore_ops.find(s) != underscore_ops.end();
}
109
// Returns `s` with every '>' (namespace separator) replaced by '_'. If the
// resulting name is reserved according to IsPythonReserved(), a trailing '_'
// is added as well.
string AvoidPythonReserved(const string& s) {
  // Convert namespace separators ('>' characters) to joiners.
  string sanitized = absl::StrReplaceAll(s, {{">", "_"}});

  if (IsPythonReserved(sanitized)) {
    strings::StrAppend(&sanitized, "_");
  }
  return sanitized;
}
117
118 // Indent the first line by "initial" spaces and all following lines
119 // by "rest" spaces.
Indent(int initial,int rest,StringPiece in)120 string Indent(int initial, int rest, StringPiece in) {
121 // TODO(josh11b): Also word-wrapping?
122 string copy(in.data(), in.size());
123 absl::StripTrailingAsciiWhitespace(©);
124 std::vector<string> v = str_util::Split(copy, '\n');
125
126 string result;
127 bool first = true;
128 for (const string& line : v) {
129 if (first) {
130 result = strings::StrCat(Spaces(initial), line, "\n");
131 first = false;
132 } else {
133 if (line.empty()) {
134 strings::StrAppend(&result, "\n");
135 } else {
136 strings::StrAppend(&result, Spaces(rest), line, "\n");
137 }
138 }
139 }
140 return result;
141 }
142
143 // Adds append to *dest, with a space if the first line will be <= width,
144 // or a newline otherwise.
AppendWithinWidth(string * dest,StringPiece append,int width)145 void AppendWithinWidth(string* dest, StringPiece append, int width) {
146 auto first_line = append.find('\n');
147 if (first_line == string::npos) first_line = append.size();
148 if (dest->size() + first_line + 1 /* space */ > static_cast<size_t>(width)) {
149 strings::StrAppend(dest, "\n", append);
150 } else {
151 strings::StrAppend(dest, " ", append);
152 }
153 }
154
155 // Like DataTypeString() but uses the Python names for the
156 // float types.
PythonDataTypeString(DataType dtype)157 string PythonDataTypeString(DataType dtype) {
158 switch (dtype) {
159 case DT_FLOAT:
160 return "float32";
161 case DT_DOUBLE:
162 return "float64";
163 default:
164 return DataTypeString(dtype);
165 }
166 }
167
TypeString(DataType dtype,bool ref)168 string TypeString(DataType dtype, bool ref) {
169 if (ref) {
170 return strings::StrCat("mutable `", PythonDataTypeString(dtype), "`");
171 } else {
172 return strings::StrCat("`", PythonDataTypeString(dtype), "`");
173 }
174 }
175
TypeListString(const AttrValue & value)176 string TypeListString(const AttrValue& value) {
177 string ret;
178 for (int t : value.list().type()) {
179 if (!ret.empty()) strings::StrAppend(&ret, ", ");
180 DataType dtype = static_cast<DataType>(t);
181 if (IsRefType(dtype)) {
182 strings::StrAppend(&ret, PythonDataTypeString(RemoveRefType(dtype)),
183 " mutable");
184 } else {
185 strings::StrAppend(&ret, "`", PythonDataTypeString(dtype), "`");
186 }
187 }
188 return ret;
189 }
190
SingleTensorName(DataType dtype,bool is_ref)191 string SingleTensorName(DataType dtype, bool is_ref) {
192 const string type_str = TypeString(dtype, is_ref);
193 return strings::StrCat("A `Tensor` of type ", type_str, ".");
194 }
195
// Type description used by GetReturns() when no type string is available for
// an output; GetReturns() also compares type strings against this value to
// detect the "type unknown" case.
const char kUnknownTensorType[] = {"A `Tensor`."};
197
ArgTypeName(const OpDef & op_def,const OpDef::ArgDef & arg,const std::unordered_map<string,string> & inferred_attrs,bool is_output)198 string ArgTypeName(const OpDef& op_def, const OpDef::ArgDef& arg,
199 const std::unordered_map<string, string>& inferred_attrs,
200 bool is_output) {
201 if (!arg.number_attr().empty()) {
202 // N Tensors with the same type
203 const string* original_arg =
204 gtl::FindOrNull(inferred_attrs, arg.number_attr());
205 string prefix;
206 if (original_arg == nullptr) {
207 prefix = strings::StrCat("A list of `", arg.number_attr(), "`");
208 } else if (*original_arg == arg.name()) {
209 const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
210 if (attr->has_minimum() && attr->minimum() > 0) {
211 prefix = strings::StrCat("A list of at least ", attr->minimum());
212 } else {
213 prefix = "A list of";
214 }
215 } else {
216 prefix = strings::StrCat("A list with the same length as `",
217 AvoidPythonReserved(*original_arg), "` of");
218 }
219
220 if (arg.type() != DT_INVALID) {
221 return strings::StrCat(prefix, " `Tensor` objects with type ",
222 TypeString(arg.type(), arg.is_ref()), ".");
223 } else {
224 original_arg = gtl::FindOrNull(inferred_attrs, arg.type_attr());
225 if (arg.is_ref()) {
226 strings::StrAppend(&prefix, " mutable");
227 }
228 if (original_arg == nullptr) {
229 return strings::StrCat(prefix, " `Tensor` objects with type `",
230 arg.type_attr(), "`.");
231 } else if (*original_arg == arg.name()) {
232 const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
233 if (attr->has_allowed_values()) {
234 return strings::StrCat(prefix,
235 " `Tensor` objects with the same type in: ",
236 TypeListString(attr->allowed_values()), ".");
237 } else {
238 return strings::StrCat(prefix,
239 " `Tensor` objects with the same type.");
240 }
241 } else {
242 return strings::StrCat(prefix,
243 " `Tensor` objects with the same type as `",
244 AvoidPythonReserved(*original_arg), "`.");
245 }
246 }
247 } else if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) {
248 const bool is_list = !arg.type_list_attr().empty();
249 const string attr_name = is_list ? arg.type_list_attr() : arg.type_attr();
250 const OpDef::AttrDef* attr = FindAttr(attr_name, op_def);
251 const string mutable_str = arg.is_ref() ? "mutable " : "";
252 const string prefix =
253 is_list ? strings::StrCat("A list of ", mutable_str, "`Tensor` objects")
254 : strings::StrCat("A ", mutable_str, "`Tensor`");
255 const string* original_arg = gtl::FindOrNull(inferred_attrs, attr_name);
256 if (original_arg == nullptr) {
257 return strings::StrCat(prefix, " of type `", attr_name, "`.");
258 } else if (*original_arg == arg.name()) {
259 if (attr->has_allowed_values()) {
260 if (is_list) {
261 return strings::StrCat(prefix, " with types from: ",
262 TypeListString(attr->allowed_values()), ".");
263 } else {
264 return strings::StrCat(
265 prefix, is_output ? ". Has one of the following types: "
266 : ". Must be one of the following types: ",
267 TypeListString(attr->allowed_values()), ".");
268 }
269 } else {
270 return strings::StrCat(prefix, ".");
271 }
272 } else {
273 return strings::StrCat(prefix,
274 is_output ? ". Has the same type as `"
275 : ". Must have the same type as `",
276 AvoidPythonReserved(*original_arg), "`.");
277 }
278 } else {
279 return SingleTensorName(arg.type(), arg.is_ref());
280 }
281 }
282
// Builds the "Returns:" section of the generated docstring for `op_def`.
//
// `output_type_string` holds one type description per output of `op_def`
// (see the DCHECK below); GenPythonOp::AddDocStringOutputs() fills it with
// ArgTypeName() results. Three layouts are produced: no outputs, exactly one
// output, and several outputs (described as a tuple, one entry per output).
string GetReturns(const OpDef& op_def,
                  const std::vector<string>& output_type_string) {
  string result;
  DCHECK_EQ(op_def.output_arg_size(), output_type_string.size());
  const int num_outs = op_def.output_arg_size();
  strings::StrAppend(&result, "\n Returns:\n");
  if (num_outs == 0) {
    strings::StrAppend(&result, " The created Operation.\n");
  } else {
    if (num_outs == 1) {
      StringPiece description = op_def.output_arg(0).description();
      // NOTE(review): ConsumeEquals() is defined elsewhere; presumably it
      // strips a leading '=' marker from `description` and reports whether
      // one was present. Confirm against op_gen_lib.
      if (ConsumeEquals(&description)) {  // Skip the generated type info.
        strings::StrAppend(&result, Indent(4, 4, description));
      } else {
        // Special case of one output, don't use the name of the output unless
        // there is no description.
        string desc = output_type_string.empty() ? kUnknownTensorType
                                                 : output_type_string[0];
        if (desc == kUnknownTensorType) {
          // Special case where we don't understand how the output tensor type
          // depends on the input tensor types, just use the output arg
          // description if we can.
          if (!description.empty()) {
            desc = op_def.output_arg(0).description();
          } else if (!op_def.output_arg(0).name().empty()) {
            desc = strings::StrCat(" The ", op_def.output_arg(0).name(),
                                   " `Tensor`.");
          }
        } else if (!description.empty()) {
          // Known type: type sentence first, then the description, joined on
          // one line if it fits.
          AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
        }
        strings::StrAppend(&result, Indent(4, 4, desc));
      }
    } else {
      // Several outputs: collect a display name for each one, falling back to
      // "output<i>" for unnamed outputs.
      std::vector<string> out_names(num_outs);
      for (int i = 0; i < num_outs; ++i) {
        if (!op_def.output_arg(i).name().empty()) {
          out_names[i] = op_def.output_arg(i).name();
        } else {
          out_names[i] = strings::StrCat("output", i);
        }
      }
      strings::StrAppend(&result, " A tuple of `Tensor` objects (",
                         absl::StrJoin(out_names, ", "), ").\n\n");
      // One "<name>: <type and/or description>" entry per output.
      for (int i = 0; i < num_outs; ++i) {
        string desc = strings::StrCat(out_names[i], ": ");
        StringPiece description = op_def.output_arg(i).description();
        if (ConsumeEquals(&description)) {  // Skip the generated type info.
          strings::StrAppend(&desc, description);
        } else {
          // Fall back to the unknown-type text if no type string was supplied
          // for this output.
          const string type = static_cast<size_t>(i) < output_type_string.size()
                                  ? output_type_string[i]
                                  : kUnknownTensorType;
          if (!description.empty()) {
            if (type == kUnknownTensorType) {
              // Special case where we don't understand how the output tensor
              // type depends on the input tensor types, so we just use the
              // output arg description.
              strings::StrAppend(&desc, description);
            } else {
              strings::StrAppend(&desc, type, " ", description);
            }
          } else {
            strings::StrAppend(&desc, type);
          }
        }
        strings::StrAppend(&result, Indent(4, 6, desc));
      }
    }
  }
  return result;
}
355
// Returns `str` as a double-quoted literal, escaped with absl::CEscape().
string StringToPython(const string& str) {
  const string escaped = absl::CEscape(str);
  return strings::StrCat("\"", escaped, "\"");
}
359
DataTypeToPython(DataType dtype,const string & dtype_module)360 string DataTypeToPython(DataType dtype, const string& dtype_module) {
361 return strings::StrCat(dtype_module, PythonDataTypeString(dtype));
362 }
363
ShapeToPython(const TensorShapeProto & shape)364 string ShapeToPython(const TensorShapeProto& shape) {
365 if (shape.unknown_rank()) {
366 return "None";
367 }
368 string python = "[";
369 for (const auto& dim : shape.dim()) {
370 if (python.size() > 1) strings::StrAppend(&python, ", ");
371 if (!dim.name().empty()) {
372 strings::StrAppend(&python, "(", StringToPython(dim.name()), ", ",
373 dim.size(), ")");
374 } else {
375 strings::StrAppend(&python, dim.size());
376 }
377 }
378 strings::StrAppend(&python, "]");
379 return python;
380 }
381
TensorToPython(const TensorProto & proto)382 string TensorToPython(const TensorProto& proto) {
383 return proto.ShortDebugString();
384 }
385
// Renders the elements of `value.list()` as a ", "-separated string, without
// enclosing brackets. Only the first non-empty field of the list is rendered,
// checked in this order: s, i, f, b, type, shape, tensor, func. The result is
// empty when all of those fields are empty. `dtype_module` is prepended to
// every rendered type name.
string AttrListToPython(const AttrValue& value,
                        const string& dtype_module = "tf.") {
  const auto& list = value.list();
  string ret;
  int num_emitted = 0;
  // Appends one rendered element, preceded by ", " for all but the first.
  const auto emit = [&ret, &num_emitted](const string& piece) {
    if (num_emitted++ > 0) strings::StrAppend(&ret, ", ");
    strings::StrAppend(&ret, piece);
  };

  if (list.s_size() > 0) {
    for (const auto& str : list.s()) emit(StringToPython(str));
  } else if (list.i_size() > 0) {
    for (const auto int_value : list.i()) emit(strings::StrCat(int_value));
  } else if (list.f_size() > 0) {
    for (const float float_value : list.f()) {
      emit(strings::StrCat(float_value));
    }
  } else if (list.b_size() > 0) {
    for (const bool bool_value : list.b()) {
      emit(bool_value ? "True" : "False");
    }
  } else if (list.type_size() > 0) {
    for (const int type_enum : list.type()) {
      emit(DataTypeToPython(static_cast<DataType>(type_enum), dtype_module));
    }
  } else if (list.shape_size() > 0) {
    for (const auto& shape : list.shape()) emit(ShapeToPython(shape));
  } else if (list.tensor_size() > 0) {
    for (const auto& tensor : list.tensor()) emit(TensorToPython(tensor));
  } else if (list.func_size() > 0) {
    for (const auto& func : list.func()) emit(StringToPython(func.name()));
  }
  return ret;
}
433
434 // NOTE: The return value may contain spaces (for example, it could be
435 // a string "foo bar" with an embedded space) and is not safe to pass
436 // to WordWrap().
AttrValueToPython(const string & type,const AttrValue & value,const string & dtype_module)437 string AttrValueToPython(const string& type, const AttrValue& value,
438 const string& dtype_module) {
439 if (type == "string") {
440 return StringToPython(value.s());
441 } else if (type == "int") {
442 return strings::StrCat(value.i());
443 } else if (type == "float") {
444 if (std::isnan(value.f()) || std::isinf(value.f())) {
445 return strings::StrCat("float('", value.f(), "')");
446 } else {
447 // Use locale-independent conversion.
448 static_assert(FLT_DIG < 10, "FLT_DIG is too big");
449 std::ostringstream s;
450 s.imbue(std::locale::classic());
451 s << std::setprecision(FLT_DIG) << value.f();
452 // If there is no I/O error for `std::ostringstream s` return s.str(),
453 // otherwise fallback to strings::StrCat(value.f()).
454 if (s.good()) {
455 return s.str();
456 }
457 return strings::StrCat(value.f());
458 }
459 } else if (type == "bool") {
460 return value.b() ? "True" : "False";
461 } else if (type == "type") {
462 return DataTypeToPython(value.type(), dtype_module);
463 } else if (type == "shape") {
464 return ShapeToPython(value.shape());
465 } else if (type == "tensor") {
466 return TensorToPython(value.tensor());
467 } else if (type == "func") {
468 return StringToPython(value.func().name());
469 } else if (absl::StartsWith(type, "list(")) {
470 return strings::StrCat("[", AttrListToPython(value, dtype_module), "]");
471 } else {
472 return "?";
473 }
474 }
475
// Appends a lower-cased, '_'-joined version of the op name `str` to *result
// (existing contents of *result are kept).
//
//   * Every '>' (namespace separator) becomes '_'.
//   * A '_' is inserted before an uppercase letter when the previous
//     character is lowercase or the next character is lowercase, except at
//     the start of the string and directly after a namespace separator.
//   * Every other character is appended after tolower().
//
// For example "MatMul" yields "mat_mul" and "RGBToHSV" yields "rgb_to_hsv".
void GenerateLowerCaseOpName(const string& str, string* result) {
  const char joiner = '_';
  const char namespace_separator = '>';

  // The <ctype.h> functions have undefined behavior for negative arguments
  // other than EOF, so every char is converted to unsigned char first.
  const auto is_upper = [](char ch) {
    return isupper(static_cast<unsigned char>(ch)) != 0;
  };
  const auto is_lower = [](char ch) {
    return islower(static_cast<unsigned char>(ch)) != 0;
  };

  const size_t size = str.size();
  for (size_t i = 0; i < size; ++i) {
    const char c = str[i];
    // Convert namespace separators ('>' characters) to joiners.
    if (c == namespace_separator) {
      result->push_back(joiner);
      continue;
    }

    // Emit a joiner only if a previous-lower-to-now-upper or a
    // now-upper-to-next-lower transition happens (but don't emit an extra
    // joiner if we just saw a namespace separator).
    if (i > 0 && is_upper(c)) {
      const bool prev_is_lower = is_lower(str[i - 1]);
      const bool next_is_lower = (i + 1 < size) && is_lower(str[i + 1]);
      if ((prev_is_lower || next_is_lower) &&
          str[i - 1] != namespace_separator) {
        result->push_back(joiner);
      }
    }
    result->push_back(
        static_cast<char>(tolower(static_cast<unsigned char>(c))));
  }
}
501
AddDelimiter(string * append_to,const string & delim)502 static void AddDelimiter(string* append_to, const string& delim) {
503 if (!append_to->empty()) strings::StrAppend(append_to, delim);
504 }
505
FindAttr(StringPiece name,const ApiDef & api_def)506 const ApiDef::Attr* FindAttr(StringPiece name, const ApiDef& api_def) {
507 for (int i = 0; i < api_def.attr_size(); ++i) {
508 if (api_def.attr(i).name() == name) {
509 return &api_def.attr(i);
510 }
511 }
512 return nullptr;
513 }
514
// Stores the op definition, its API definition, the name of the Python
// function to generate, and the type-annotation flag. The number of outputs
// is cached from `op_def` at construction time.
GenPythonOp::GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
                         const string& function_name, bool add_type_annotations)
    : op_def_(op_def),
      api_def_(api_def),
      function_name_(function_name),
      add_type_annotations_(add_type_annotations),
      num_outs_(op_def.output_arg_size()) {}
522
~GenPythonOp()523 GenPythonOp::~GenPythonOp() {}
524
Code()525 string GenPythonOp::Code() {
526 // This has all the input args followed by those attrs that don't have
527 // defaults.
528 std::vector<ParamNames> params_no_default;
529 // The parameters with defaults (these have to be listed after those without).
530 // No input args are included, just attrs.
531 std::vector<ParamNames> params_with_default;
532
533 for (int i = 0; i < api_def_.arg_order_size(); ++i) {
534 const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
535 const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
536 params_no_default.emplace_back(api_def_arg.name(), api_def_arg.rename_to());
537 if (!arg.type_attr().empty()) {
538 gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_attr(), arg.name());
539 } else if (!arg.type_list_attr().empty()) {
540 gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_list_attr(),
541 arg.name());
542 }
543 if (!arg.number_attr().empty()) {
544 gtl::InsertIfNotPresent(&inferred_attrs_, arg.number_attr(), arg.name());
545 }
546 }
547 for (int i = 0; i < api_def_.attr_size(); ++i) {
548 const auto& attr(api_def_.attr(i));
549 // Do not add inferred attrs to the Python function signature.
550 if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
551 if (attr.has_default_value()) {
552 params_with_default.emplace_back(attr.name(), attr.rename_to());
553 } else {
554 params_no_default.emplace_back(attr.name(), attr.rename_to());
555 }
556 }
557 }
558
559 // Save the list of attr parameters (attrs that won't be inferred),
560 // those with defaults go at the end.
561 // Get the attrs in the order we want by taking the attrs without defaults
562 // from the end of args_no_default, and adding args_no_default.
563 attrs_.reserve(params_no_default.size() - op_def_.input_arg_size() +
564 params_with_default.size());
565 for (int i = op_def_.input_arg_size(), end = params_no_default.size();
566 i < end; ++i) {
567 attrs_.push_back(params_no_default[i].GetName());
568 }
569 for (int i = 0, end = params_with_default.size(); i < end; ++i) {
570 attrs_.push_back(params_with_default[i].GetName());
571 }
572
573 param_names_.reserve(params_no_default.size() + params_with_default.size());
574 param_names_.insert(param_names_.begin(), params_no_default.begin(),
575 params_no_default.end());
576 for (const auto& param : params_with_default) {
577 param_names_.push_back(param);
578 }
579
580 string parameters;
581 for (const auto& param : params_no_default) {
582 AddDelimiter(¶meters, ", ");
583 strings::StrAppend(¶meters, param.GetRenameTo());
584 }
585 for (const auto& param_and_default : params_with_default) {
586 AddDelimiter(¶meters, ", ");
587 strings::StrAppend(¶meters, param_and_default.GetRenameTo(), "=None");
588 }
589 AddDelimiter(¶meters, ", ");
590 strings::StrAppend(¶meters, "name=None");
591
592 AddExport();
593 AddDefLine(parameters);
594 AddDocStringDescription();
595 AddDocStringArgs();
596 AddDocStringInputs();
597 AddDocStringAttrs();
598 AddDocStringNameArg();
599 AddOutputGlobals();
600 AddDocStringOutputs();
601 strings::StrAppend(&result_, " \"\"\"\n");
602 AddBody(" ");
603 strings::StrAppend(&result_, "\n\n");
604
605 return prelude_ + result_;
606 }
607
AddExport()608 void GenPythonOp::AddExport() {
609 if (api_def_.visibility() != ApiDef::VISIBLE) {
610 return;
611 }
612 // Whether op should be available in latest export version.
613 bool op_available_in_latest =
614 !api_def_.deprecation_version() ||
615 api_def_.deprecation_version() > kLatestAPIExportVersion;
616
617 string names;
618 string names_v1;
619 string deprecated_endpoints;
620
621 for (const auto& endpoint : api_def_.endpoint()) {
622 string endpoint_name;
623 python_op_gen_internal::GenerateLowerCaseOpName(endpoint.name(),
624 &endpoint_name);
625 if (endpoint.deprecated() || endpoint.deprecation_version() > 0) {
626 AddDelimiter(&deprecated_endpoints, ", ");
627 strings::StrAppend(&deprecated_endpoints, "'", endpoint_name, "'");
628 }
629 // Add all endpoints to TensorFlow 1.* API.
630 AddDelimiter(&names_v1, ", ");
631 strings::StrAppend(&names_v1, "'", endpoint_name, "'");
632 // Add non-deprecated endpoints to TensorFlow 2.* API.
633 if (op_available_in_latest &&
634 (!endpoint.deprecation_version() ||
635 endpoint.deprecation_version() > kLatestAPIExportVersion)) {
636 AddDelimiter(&names, ", ");
637 strings::StrAppend(&names, "'", endpoint_name, "'");
638 }
639 }
640
641 // tf_export decorator has the following format:
642 // @tf_export(v2_name, v2_name, v1=[v1_name, v1_name])
643 if (names != names_v1) {
644 AddDelimiter(&names, ", ");
645 strings::StrAppend(&names, "v1=[", names_v1, "]");
646 }
647 strings::StrAppend(&result_, "@tf_export(", names, ")\n");
648
649 // If all endpoints are deprecated, add @deprecated decorator.
650 if (!api_def_.deprecation_message().empty()) {
651 const string instructions = api_def_.deprecation_message();
652 strings::StrAppend(&result_, "@deprecated(None, '", instructions, "')\n");
653 }
654 // Add @deprecated_endpoints decorator.
655 if (!deprecated_endpoints.empty()) {
656 strings::StrAppend(&result_, "@deprecated_endpoints(", deprecated_endpoints,
657 ")\n");
658 }
659 }
660
AddDefLine(const string & function_name,const string & parameters)661 void GenPythonOp::AddDefLine(const string& function_name,
662 const string& parameters) {
663 strings::StrAppend(&result_, "def ", function_name, "(", parameters, "):\n");
664 }
665
AddDefLine(const string & parameters)666 void GenPythonOp::AddDefLine(const string& parameters) {
667 AddDefLine(function_name_, parameters);
668 }
669
AddDocStringDescription()670 void GenPythonOp::AddDocStringDescription() {
671 string comment;
672 if (api_def_.summary().empty()) {
673 comment = "TODO: add doc.\n";
674 } else {
675 comment = strings::StrCat(api_def_.summary(), "\n");
676 if (!api_def_.description().empty()) {
677 strings::StrAppend(&comment, "\n", Indent(2, 2, api_def_.description()));
678 }
679 }
680 strings::StrAppend(&result_, " r\"\"\"", comment, "\n");
681 }
682
AddDocStringArgs()683 void GenPythonOp::AddDocStringArgs() {
684 strings::StrAppend(&result_, " Args:\n");
685 }
686
AddDocStringInputs()687 void GenPythonOp::AddDocStringInputs() {
688 for (int i = 0; i < api_def_.arg_order_size(); ++i) {
689 const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
690 const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
691 StringPiece description = api_def_arg.description();
692 string desc;
693 if (ConsumeEquals(&description)) { // Skip the generated type info.
694 desc = strings::StrCat(param_names_[i].GetRenameTo(), ": ");
695 } else {
696 desc = strings::StrCat(param_names_[i].GetRenameTo(), ": ",
697 ArgTypeName(op_def_, arg, inferred_attrs_, false));
698 }
699 if (!description.empty()) {
700 AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
701 }
702 strings::StrAppend(&result_, Indent(4, 6, desc));
703 }
704 }
705
AddDocStringAttrs()706 void GenPythonOp::AddDocStringAttrs() {
707 for (const string& name : attrs_) {
708 const auto& attr = *FindAttr(name, op_def_);
709 const auto& api_def_attr = *FindAttr(name, api_def_);
710 string desc =
711 strings::StrCat(AvoidPythonReserved(api_def_attr.rename_to()), ": ");
712
713 static const char* const kAttrTypeName[][2] = {
714 {"string", "`string`"},
715 {"list(string)", "list of `strings`"},
716 {"int", "`int`"},
717 {"list(int)", "list of `ints`"},
718 {"float", "`float`"},
719 {"list(float)", "list of `floats`"},
720 {"bool", "`bool`"},
721 {"list(bool)", "list of `bools`"},
722 {"type", "`tf.DType`"},
723 {"list(type)", "list of `tf.DTypes`"},
724 {"shape", "`tf.TensorShape` or list of `ints`"},
725 {"list(shape)",
726 "list of shapes (each a `tf.TensorShape` or list of `ints`)"},
727 {"tensor", "`tf.TensorProto`"},
728 {"list(tensor)", "list of `tf.TensorProto` objects"},
729 {"func", "function decorated with @Defun"},
730 {"list(func)", "list of functions decorated with @Defun"},
731 };
732 for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) {
733 if (attr.type() == kAttrTypeName[i][0]) {
734 string s;
735 if (api_def_attr.has_default_value()) {
736 s = strings::StrCat("optional ", kAttrTypeName[i][1]);
737 } else {
738 s = kAttrTypeName[i][1];
739 }
740 if (s[0] == 'o' || (s[0] == '`' && (s[1] == 'i' || s[1] == 'o'))) {
741 strings::StrAppend(&desc, "An ", s);
742 } else {
743 strings::StrAppend(&desc, "A ", s);
744 }
745 break;
746 }
747 }
748
749 if (attr.has_allowed_values()) {
750 strings::StrAppend(&desc, " from: `",
751 AttrListToPython(attr.allowed_values()), "`");
752 }
753
754 if (attr.has_minimum()) {
755 if (attr.type() == "int") {
756 strings::StrAppend(&desc, " that is `>= ", attr.minimum(), "`");
757 } else if (attr.minimum() > 0) {
758 strings::StrAppend(&desc, " that has length `>= ", attr.minimum(), "`");
759 }
760 }
761
762 strings::StrAppend(&desc, ".");
763
764 if (api_def_attr.has_default_value()) {
765 strings::StrAppend(
766 &desc, " Defaults to `",
767 AttrValueToPython(attr.type(), api_def_attr.default_value()), "`.");
768 }
769 if (!api_def_attr.description().empty()) {
770 AppendWithinWidth(&desc, api_def_attr.description(),
771 kRightMargin - 4 /* indent */);
772 }
773 strings::StrAppend(&result_, Indent(4, 6, desc));
774 }
775 }
776
AddDocStringNameArg()777 void GenPythonOp::AddDocStringNameArg() {
778 strings::StrAppend(&result_,
779 " name: A name for the operation (optional).\n");
780 }
781
AddOutputGlobals()782 void GenPythonOp::AddOutputGlobals() {
783 // Generate a namedtuple class to hold the outputs, if there are multiple.
784 // Example:
785 //
786 // _OpOutputs = collections.namedtuple(
787 // "_OpOutputs",
788 // "out1 out2 out3")
789 if (num_outs_ > 1) {
790 std::vector<string> out_names;
791 out_names.reserve(num_outs_);
792 for (int i = 0; i < num_outs_; ++i) {
793 const string out_name = !api_def_.out_arg(i).rename_to().empty()
794 ? api_def_.out_arg(i).rename_to()
795 : strings::StrCat("output", i);
796 out_names.push_back(strings::StrCat("\"", out_name, "\""));
797 }
798
799 strings::StrAppend(&prelude_, "_", AvoidPythonReserved(op_def_.name()),
800 "Output = collections.namedtuple(\n");
801 strings::StrAppend(&prelude_, " \"", AvoidPythonReserved(op_def_.name()),
802 "\",\n");
803 strings::StrAppend(&prelude_, " [", absl::StrJoin(out_names, ", "),
804 "])");
805 strings::StrAppend(&prelude_, "\n\n");
806 }
807 strings::StrAppend(&prelude_, "\n");
808 }
809
AddDocStringOutputs()810 void GenPythonOp::AddDocStringOutputs() {
811 std::vector<string> output_type_string;
812 output_type_string.reserve(num_outs_);
813 for (int i = 0; i < num_outs_; ++i) {
814 output_type_string.push_back(
815 ArgTypeName(op_def_, op_def_.output_arg(i), inferred_attrs_, true));
816 }
817 strings::StrAppend(&result_, GetReturns(op_def_, output_type_string));
818 }
819
AddBody(const string & prefix)820 void GenPythonOp::AddBody(const string& prefix) {
821 const string apply_prefix = strings::StrCat(
822 prefix, "_result = _op_def_lib.apply_op(\"", op_def_.name(), "\", ");
823 AddBodyNoReturn(apply_prefix);
824 if (num_outs_ > 1) {
825 strings::StrAppend(&result_, prefix, "_result = _",
826 AvoidPythonReserved(op_def_.name()),
827 "Output._make(_result)\n");
828 }
829 strings::StrAppend(&result_, prefix, "return _result\n");
830 }
831
AddBodyNoReturn(const string & apply_prefix)832 void GenPythonOp::AddBodyNoReturn(const string& apply_prefix) {
833 string args;
834 for (size_t i = 0; i < param_names_.size(); ++i) {
835 strings::StrAppend(&args, AvoidPythonReserved(param_names_[i].GetName()),
836 "=", param_names_[i].GetRenameTo(), ", ");
837 }
838 strings::StrAppend(&args, "name=name)");
839
840 strings::StrAppend(&result_,
841 // Wrap the arguments, and indent to the (.
842 WordWrap(apply_prefix, args, kRightMargin), "\n");
843 }
844
845 } // namespace python_op_gen_internal
846 } // namespace tensorflow
847