• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/python/framework/python_op_gen_internal.h"
17 
18 #include <float.h>
19 #include <stdio.h>
20 
21 #include <iomanip>
22 #include <sstream>
23 #include <unordered_map>
24 
25 #include "absl/strings/escaping.h"
26 #include "absl/strings/str_replace.h"
27 #include "tensorflow/core/framework/api_def.pb.h"
28 #include "tensorflow/core/framework/attr_value.pb.h"
29 #include "tensorflow/core/framework/op.h"
30 #include "tensorflow/core/framework/op_def.pb.h"
31 #include "tensorflow/core/framework/op_def_util.h"
32 #include "tensorflow/core/framework/op_gen_lib.h"
33 #include "tensorflow/core/framework/tensor.pb.h"
34 #include "tensorflow/core/framework/tensor_shape.pb.h"
35 #include "tensorflow/core/framework/types.h"
36 #include "tensorflow/core/framework/types.pb.h"
37 #include "tensorflow/core/lib/gtl/map_util.h"
38 #include "tensorflow/core/lib/strings/str_util.h"
39 #include "tensorflow/core/lib/strings/strcat.h"
40 #include "tensorflow/core/lib/strings/stringprintf.h"
41 #include "tensorflow/core/platform/logging.h"
42 #include "tensorflow/core/platform/macros.h"
43 #include "tensorflow/core/platform/types.h"
44 
45 namespace tensorflow {
46 namespace python_op_gen_internal {
47 
// Right margin for generated docstrings. Callers subtract their own
// indentation before passing it as the width to AppendWithinWidth().
const int kRightMargin{78};
// Names specified in tf_export decorators are exported to
// TensorFlow 2.0 by default.
const int kLatestAPIExportVersion{2};
52 
// Returns true if `s` must not be used verbatim as a Python identifier in
// generated code. This covers Python 2 and Python 3 keywords plus the listed
// built-in constants, exception types and dunder names.
bool IsPythonReserved(const std::string& s) {
  // Intentionally leaked to avoid static destruction-order issues.
  static const std::set<std::string>* const kPythonReserved =
      new std::set<std::string>(
          {// Keywords in Python, from:
           //   import keyword
           //   print keyword.kwlist
           "and", "as", "assert", "break", "class", "continue", "def", "del",
           "elif", "else", "except", "exec", "finally", "for", "from",
           "global", "if", "import", "in", "is", "lambda", "not", "or",
           "pass", "print", "raise", "return", "try", "while", "with",
           "yield",
           // Python 3 keywords that are absent from the Python 2 list above.
           // Emitting any of them as a parameter name is a SyntaxError
           // (async/await became reserved in Python 3.7).
           "async", "await", "nonlocal",
           // Built-in functions and types in Python, from:
           //   [x for x in dir(__builtins__) if not x[0].islower()]
           "ArithmeticError", "AssertionError", "AttributeError",
           "BaseException", "BufferError", "BytesWarning",
           "DeprecationWarning", "EOFError", "Ellipsis", "EnvironmentError",
           "Exception", "False", "FloatingPointError", "FutureWarning",
           "GeneratorExit", "IOError", "ImportError", "ImportWarning",
           "IndentationError", "IndexError", "KeyError", "KeyboardInterrupt",
           "LookupError", "MemoryError", "NameError", "None",
           "NotImplemented", "NotImplementedError", "OSError",
           "OverflowError", "PendingDeprecationWarning", "ReferenceError",
           "RuntimeError", "RuntimeWarning", "StandardError", "StopIteration",
           "SyntaxError", "SyntaxWarning", "SystemError", "SystemExit",
           "TabError", "True", "TypeError", "UnboundLocalError",
           "UnicodeDecodeError", "UnicodeEncodeError", "UnicodeError",
           "UnicodeTranslateError", "UnicodeWarning", "UserWarning",
           "ValueError", "Warning", "ZeroDivisionError", "__debug__",
           "__doc__", "__import__", "__name__", "__package__"});

  return kPythonReserved->count(s) > 0;
}
82 
// Returns true if the generated Python wrapper for op `s` should get a
// leading underscore. This happens when the name would shadow a lowercase
// Python built-in, or would clash with a same-named Python-level op under
// '*' imports.
bool IsOpWithUnderscorePrefix(const std::string& s) {
  // Intentionally leaked to avoid static destruction-order issues.
  static const std::set<std::string>* const kUnderscoreOps =
      new std::set<std::string>({
          // Lowercase built-in functions and types in Python, from:
          // [x for x in dir(__builtins__) if x[0].islower()] except "round".
          // These need to be excluded so they don't conflict with actual
          // built-in functions since we use '*' imports.
          "abs", "all", "any", "apply", "bin", "bool", "buffer",
          "bytearray", "bytes", "callable", "chr", "classmethod", "cmp",
          "coerce", "compile", "complex", "copyright", "credits", "delattr",
          "dict", "dir", "divmod", "enumerate", "eval", "execfile", "exit",
          "file", "filter", "float", "format", "frozenset", "getattr",
          "globals", "hasattr", "hash", "help", "hex", "id", "input", "int",
          "intern", "isinstance", "issubclass", "iter", "len", "license",
          "list", "locals", "long", "map", "max", "memoryview", "min",
          "next", "object", "oct", "open", "ord", "pow", "print",
          "property", "quit", "range", "raw_input", "reduce", "reload",
          "repr", "reversed", "set", "setattr", "slice", "sorted",
          "staticmethod", "str", "sum", "super", "tuple", "type", "unichr",
          "unicode", "vars", "xrange", "zip",
          // These have the same name as ops defined in Python and might be
          // used incorrectly depending on order of '*' imports.
          // TODO(annarev): reduce usage of '*' imports and remove these from
          // the list.
          "fused_batch_norm", "histogram_fixed_width", "stack",
          "batch_norm_with_global_normalization", "clip_by_value",
      });
  const auto hit = kUnderscoreOps->find(s);
  return hit != kUnderscoreOps->end();
}
109 
// Maps a name to a Python-safe identifier. Namespace separators ('>') become
// '_', and a trailing '_' is added when the result is in the reserved set
// checked by IsPythonReserved().
string AvoidPythonReserved(const string& s) {
  string safe_name = absl::StrReplaceAll(s, {{">", "_"}});
  if (!IsPythonReserved(safe_name)) return safe_name;
  safe_name.push_back('_');
  return safe_name;
}
117 
118 // Indent the first line by "initial" spaces and all following lines
119 // by "rest" spaces.
Indent(int initial,int rest,StringPiece in)120 string Indent(int initial, int rest, StringPiece in) {
121   // TODO(josh11b): Also word-wrapping?
122   string copy(in.data(), in.size());
123   absl::StripTrailingAsciiWhitespace(&copy);
124   std::vector<string> v = str_util::Split(copy, '\n');
125 
126   string result;
127   bool first = true;
128   for (const string& line : v) {
129     if (first) {
130       result = strings::StrCat(Spaces(initial), line, "\n");
131       first = false;
132     } else {
133       if (line.empty()) {
134         strings::StrAppend(&result, "\n");
135       } else {
136         strings::StrAppend(&result, Spaces(rest), line, "\n");
137       }
138     }
139   }
140   return result;
141 }
142 
143 // Adds append to *dest, with a space if the first line will be <= width,
144 // or a newline otherwise.
AppendWithinWidth(string * dest,StringPiece append,int width)145 void AppendWithinWidth(string* dest, StringPiece append, int width) {
146   auto first_line = append.find('\n');
147   if (first_line == string::npos) first_line = append.size();
148   if (dest->size() + first_line + 1 /* space */ > static_cast<size_t>(width)) {
149     strings::StrAppend(dest, "\n", append);
150   } else {
151     strings::StrAppend(dest, " ", append);
152   }
153 }
154 
155 // Like DataTypeString() but uses the Python names for the
156 // float types.
PythonDataTypeString(DataType dtype)157 string PythonDataTypeString(DataType dtype) {
158   switch (dtype) {
159     case DT_FLOAT:
160       return "float32";
161     case DT_DOUBLE:
162       return "float64";
163     default:
164       return DataTypeString(dtype);
165   }
166 }
167 
TypeString(DataType dtype,bool ref)168 string TypeString(DataType dtype, bool ref) {
169   if (ref) {
170     return strings::StrCat("mutable `", PythonDataTypeString(dtype), "`");
171   } else {
172     return strings::StrCat("`", PythonDataTypeString(dtype), "`");
173   }
174 }
175 
TypeListString(const AttrValue & value)176 string TypeListString(const AttrValue& value) {
177   string ret;
178   for (int t : value.list().type()) {
179     if (!ret.empty()) strings::StrAppend(&ret, ", ");
180     DataType dtype = static_cast<DataType>(t);
181     if (IsRefType(dtype)) {
182       strings::StrAppend(&ret, PythonDataTypeString(RemoveRefType(dtype)),
183                          " mutable");
184     } else {
185       strings::StrAppend(&ret, "`", PythonDataTypeString(dtype), "`");
186     }
187   }
188   return ret;
189 }
190 
SingleTensorName(DataType dtype,bool is_ref)191 string SingleTensorName(DataType dtype, bool is_ref) {
192   const string type_str = TypeString(dtype, is_ref);
193   return strings::StrCat("A `Tensor` of type ", type_str, ".");
194 }
195 
// Generic description used when an output's type cannot be described more
// precisely. GetReturns() compares against this exact text.
const char kUnknownTensorType[] = "A `Tensor`.";
197 
ArgTypeName(const OpDef & op_def,const OpDef::ArgDef & arg,const std::unordered_map<string,string> & inferred_attrs,bool is_output)198 string ArgTypeName(const OpDef& op_def, const OpDef::ArgDef& arg,
199                    const std::unordered_map<string, string>& inferred_attrs,
200                    bool is_output) {
201   if (!arg.number_attr().empty()) {
202     // N Tensors with the same type
203     const string* original_arg =
204         gtl::FindOrNull(inferred_attrs, arg.number_attr());
205     string prefix;
206     if (original_arg == nullptr) {
207       prefix = strings::StrCat("A list of `", arg.number_attr(), "`");
208     } else if (*original_arg == arg.name()) {
209       const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
210       if (attr->has_minimum() && attr->minimum() > 0) {
211         prefix = strings::StrCat("A list of at least ", attr->minimum());
212       } else {
213         prefix = "A list of";
214       }
215     } else {
216       prefix = strings::StrCat("A list with the same length as `",
217                                AvoidPythonReserved(*original_arg), "` of");
218     }
219 
220     if (arg.type() != DT_INVALID) {
221       return strings::StrCat(prefix, " `Tensor` objects with type ",
222                              TypeString(arg.type(), arg.is_ref()), ".");
223     } else {
224       original_arg = gtl::FindOrNull(inferred_attrs, arg.type_attr());
225       if (arg.is_ref()) {
226         strings::StrAppend(&prefix, " mutable");
227       }
228       if (original_arg == nullptr) {
229         return strings::StrCat(prefix, " `Tensor` objects with type `",
230                                arg.type_attr(), "`.");
231       } else if (*original_arg == arg.name()) {
232         const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
233         if (attr->has_allowed_values()) {
234           return strings::StrCat(prefix,
235                                  " `Tensor` objects with the same type in: ",
236                                  TypeListString(attr->allowed_values()), ".");
237         } else {
238           return strings::StrCat(prefix,
239                                  " `Tensor` objects with the same type.");
240         }
241       } else {
242         return strings::StrCat(prefix,
243                                " `Tensor` objects with the same type as `",
244                                AvoidPythonReserved(*original_arg), "`.");
245       }
246     }
247   } else if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) {
248     const bool is_list = !arg.type_list_attr().empty();
249     const string attr_name = is_list ? arg.type_list_attr() : arg.type_attr();
250     const OpDef::AttrDef* attr = FindAttr(attr_name, op_def);
251     const string mutable_str = arg.is_ref() ? "mutable " : "";
252     const string prefix =
253         is_list ? strings::StrCat("A list of ", mutable_str, "`Tensor` objects")
254                 : strings::StrCat("A ", mutable_str, "`Tensor`");
255     const string* original_arg = gtl::FindOrNull(inferred_attrs, attr_name);
256     if (original_arg == nullptr) {
257       return strings::StrCat(prefix, " of type `", attr_name, "`.");
258     } else if (*original_arg == arg.name()) {
259       if (attr->has_allowed_values()) {
260         if (is_list) {
261           return strings::StrCat(prefix, " with types from: ",
262                                  TypeListString(attr->allowed_values()), ".");
263         } else {
264           return strings::StrCat(
265               prefix, is_output ? ". Has one of the following types: "
266                                 : ". Must be one of the following types: ",
267               TypeListString(attr->allowed_values()), ".");
268         }
269       } else {
270         return strings::StrCat(prefix, ".");
271       }
272     } else {
273       return strings::StrCat(prefix,
274                              is_output ? ". Has the same type as `"
275                                        : ". Must have the same type as `",
276                              AvoidPythonReserved(*original_arg), "`.");
277     }
278   } else {
279     return SingleTensorName(arg.type(), arg.is_ref());
280   }
281 }
282 
GetReturns(const OpDef & op_def,const std::vector<string> & output_type_string)283 string GetReturns(const OpDef& op_def,
284                   const std::vector<string>& output_type_string) {
285   string result;
286   DCHECK_EQ(op_def.output_arg_size(), output_type_string.size());
287   const int num_outs = op_def.output_arg_size();
288   strings::StrAppend(&result, "\n  Returns:\n");
289   if (num_outs == 0) {
290     strings::StrAppend(&result, "    The created Operation.\n");
291   } else {
292     if (num_outs == 1) {
293       StringPiece description = op_def.output_arg(0).description();
294       if (ConsumeEquals(&description)) {  // Skip the generated type info.
295         strings::StrAppend(&result, Indent(4, 4, description));
296       } else {
297         // Special case of one output, don't use the name of the output unless
298         // there is no description.
299         string desc = output_type_string.empty() ? kUnknownTensorType
300                                                  : output_type_string[0];
301         if (desc == kUnknownTensorType) {
302           // Special case where we don't understand how the output tensor type
303           // depends on the input tensor types, just use the output arg
304           // description if we can.
305           if (!description.empty()) {
306             desc = op_def.output_arg(0).description();
307           } else if (!op_def.output_arg(0).name().empty()) {
308             desc = strings::StrCat(" The ", op_def.output_arg(0).name(),
309                                    " `Tensor`.");
310           }
311         } else if (!description.empty()) {
312           AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
313         }
314         strings::StrAppend(&result, Indent(4, 4, desc));
315       }
316     } else {
317       std::vector<string> out_names(num_outs);
318       for (int i = 0; i < num_outs; ++i) {
319         if (!op_def.output_arg(i).name().empty()) {
320           out_names[i] = op_def.output_arg(i).name();
321         } else {
322           out_names[i] = strings::StrCat("output", i);
323         }
324       }
325       strings::StrAppend(&result, "    A tuple of `Tensor` objects (",
326                          absl::StrJoin(out_names, ", "), ").\n\n");
327       for (int i = 0; i < num_outs; ++i) {
328         string desc = strings::StrCat(out_names[i], ": ");
329         StringPiece description = op_def.output_arg(i).description();
330         if (ConsumeEquals(&description)) {  // Skip the generated type info.
331           strings::StrAppend(&desc, description);
332         } else {
333           const string type = static_cast<size_t>(i) < output_type_string.size()
334                                   ? output_type_string[i]
335                                   : kUnknownTensorType;
336           if (!description.empty()) {
337             if (type == kUnknownTensorType) {
338               // Special case where we don't understand how the output tensor
339               // type depends on the input tensor types, so we just use the
340               // output arg description.
341               strings::StrAppend(&desc, description);
342             } else {
343               strings::StrAppend(&desc, type, " ", description);
344             }
345           } else {
346             strings::StrAppend(&desc, type);
347           }
348         }
349         strings::StrAppend(&result, Indent(4, 6, desc));
350       }
351     }
352   }
353   return result;
354 }
355 
// Renders `str` as a double-quoted Python string literal whose contents are
// C-escaped by absl::CEscape().
string StringToPython(const string& str) {
  string literal = "\"";
  strings::StrAppend(&literal, absl::CEscape(str), "\"");
  return literal;
}
359 
DataTypeToPython(DataType dtype,const string & dtype_module)360 string DataTypeToPython(DataType dtype, const string& dtype_module) {
361   return strings::StrCat(dtype_module, PythonDataTypeString(dtype));
362 }
363 
ShapeToPython(const TensorShapeProto & shape)364 string ShapeToPython(const TensorShapeProto& shape) {
365   if (shape.unknown_rank()) {
366     return "None";
367   }
368   string python = "[";
369   for (const auto& dim : shape.dim()) {
370     if (python.size() > 1) strings::StrAppend(&python, ", ");
371     if (!dim.name().empty()) {
372       strings::StrAppend(&python, "(", StringToPython(dim.name()), ", ",
373                          dim.size(), ")");
374     } else {
375       strings::StrAppend(&python, dim.size());
376     }
377   }
378   strings::StrAppend(&python, "]");
379   return python;
380 }
381 
TensorToPython(const TensorProto & proto)382 string TensorToPython(const TensorProto& proto) {
383   return proto.ShortDebugString();
384 }
385 
// Renders the elements of `value.list()` as a comma-separated Python
// expression list, without surrounding brackets. Only the first non-empty
// field is rendered, checked in the order s, i, f, b, type, shape, tensor,
// func. Returns "" when all are empty.
string AttrListToPython(const AttrValue& value,
                        const string& dtype_module = "tf.") {
  const auto& list = value.list();
  string ret;
  // Emits ", " before every element except the first.
  auto separate = [&ret](int index) {
    if (index > 0) strings::StrAppend(&ret, ", ");
  };

  if (list.s_size() > 0) {
    for (int i = 0; i < list.s_size(); ++i) {
      separate(i);
      strings::StrAppend(&ret, StringToPython(list.s(i)));
    }
  } else if (list.i_size() > 0) {
    for (int i = 0; i < list.i_size(); ++i) {
      separate(i);
      strings::StrAppend(&ret, list.i(i));
    }
  } else if (list.f_size() > 0) {
    for (int i = 0; i < list.f_size(); ++i) {
      separate(i);
      strings::StrAppend(&ret, list.f(i));
    }
  } else if (list.b_size() > 0) {
    for (int i = 0; i < list.b_size(); ++i) {
      separate(i);
      strings::StrAppend(&ret, list.b(i) ? "True" : "False");
    }
  } else if (list.type_size() > 0) {
    for (int i = 0; i < list.type_size(); ++i) {
      separate(i);
      strings::StrAppend(&ret, DataTypeToPython(list.type(i), dtype_module));
    }
  } else if (list.shape_size() > 0) {
    for (int i = 0; i < list.shape_size(); ++i) {
      separate(i);
      strings::StrAppend(&ret, ShapeToPython(list.shape(i)));
    }
  } else if (list.tensor_size() > 0) {
    for (int i = 0; i < list.tensor_size(); ++i) {
      separate(i);
      strings::StrAppend(&ret, TensorToPython(list.tensor(i)));
    }
  } else if (list.func_size() > 0) {
    for (int i = 0; i < list.func_size(); ++i) {
      separate(i);
      strings::StrAppend(&ret, StringToPython(list.func(i).name()));
    }
  }
  return ret;
}
433 
434 // NOTE: The return value may contain spaces (for example, it could be
435 // a string "foo bar" with an embedded space) and is not safe to pass
436 // to WordWrap().
AttrValueToPython(const string & type,const AttrValue & value,const string & dtype_module)437 string AttrValueToPython(const string& type, const AttrValue& value,
438                          const string& dtype_module) {
439   if (type == "string") {
440     return StringToPython(value.s());
441   } else if (type == "int") {
442     return strings::StrCat(value.i());
443   } else if (type == "float") {
444     if (std::isnan(value.f()) || std::isinf(value.f())) {
445       return strings::StrCat("float('", value.f(), "')");
446     } else {
447       // Use locale-independent conversion.
448       static_assert(FLT_DIG < 10, "FLT_DIG is too big");
449       std::ostringstream s;
450       s.imbue(std::locale::classic());
451       s << std::setprecision(FLT_DIG) << value.f();
452       // If there is no I/O error for `std::ostringstream s` return s.str(),
453       // otherwise fallback to strings::StrCat(value.f()).
454       if (s.good()) {
455         return s.str();
456       }
457       return strings::StrCat(value.f());
458     }
459   } else if (type == "bool") {
460     return value.b() ? "True" : "False";
461   } else if (type == "type") {
462     return DataTypeToPython(value.type(), dtype_module);
463   } else if (type == "shape") {
464     return ShapeToPython(value.shape());
465   } else if (type == "tensor") {
466     return TensorToPython(value.tensor());
467   } else if (type == "func") {
468     return StringToPython(value.func().name());
469   } else if (absl::StartsWith(type, "list(")) {
470     return strings::StrCat("[", AttrListToPython(value, dtype_module), "]");
471   } else {
472     return "?";
473   }
474 }
475 
// Converts a CamelCase op name to snake_case and appends it to `*result`;
// existing contents of `*result` are kept. Namespace separators ('>') become
// '_'. An uppercase letter at position > 0 gets a '_' in front of it when the
// previous character is lowercase or the next character is lowercase, unless
// the previous character is a '>'. Examples: "MatMul" -> "mat_mul",
// "HTTPRequest" -> "http_request", "Foo>Bar" -> "foo_bar".
void GenerateLowerCaseOpName(const std::string& str, std::string* result) {
  const char joiner = '_';
  const char namespace_separator = '>';
  const size_t size = str.size();
  for (size_t i = 0; i < size; ++i) {
    const char c = str[i];
    // Convert namespace separators ('>' characters) to joiners.
    if (c == namespace_separator) {
      result->push_back(joiner);
      continue;
    }

    // The <cctype> classifiers require a value representable as unsigned
    // char. Passing a negative char (e.g. a UTF-8 lead byte) is undefined
    // behavior, so convert before every call.
    const unsigned char uc = static_cast<unsigned char>(c);

    // Emit a joiner only if a previous-lower-to-now-upper or a
    // now-upper-to-next-lower transition happens (but don't emit an extra
    // joiner if we just saw a namespace separator).
    if (i > 0 && isupper(uc)) {
      const bool prev_is_lower =
          islower(static_cast<unsigned char>(str[i - 1])) != 0;
      const bool next_is_lower =
          i + 1 < size && islower(static_cast<unsigned char>(str[i + 1])) != 0;
      if ((prev_is_lower || next_is_lower) &&
          str[i - 1] != namespace_separator) {
        result->push_back(joiner);
      }
    }
    result->push_back(static_cast<char>(tolower(uc)));
  }
}
501 
// Appends `delim` to `*append_to` unless `*append_to` is still empty. This is
// used to build comma-separated lists incrementally.
static void AddDelimiter(std::string* append_to, const std::string& delim) {
  if (append_to->empty()) return;
  append_to->append(delim);
}
505 
FindAttr(StringPiece name,const ApiDef & api_def)506 const ApiDef::Attr* FindAttr(StringPiece name, const ApiDef& api_def) {
507   for (int i = 0; i < api_def.attr_size(); ++i) {
508     if (api_def.attr(i).name() == name) {
509       return &api_def.attr(i);
510     }
511   }
512   return nullptr;
513 }
514 
GenPythonOp(const OpDef & op_def,const ApiDef & api_def,const string & function_name,bool add_type_annotations)515 GenPythonOp::GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
516                          const string& function_name, bool add_type_annotations)
517     : op_def_(op_def),
518       api_def_(api_def),
519       function_name_(function_name),
520       add_type_annotations_(add_type_annotations),
521       num_outs_(op_def.output_arg_size()) {}
522 
~GenPythonOp()523 GenPythonOp::~GenPythonOp() {}
524 
Code()525 string GenPythonOp::Code() {
526   // This has all the input args followed by those attrs that don't have
527   // defaults.
528   std::vector<ParamNames> params_no_default;
529   // The parameters with defaults (these have to be listed after those without).
530   // No input args are included, just attrs.
531   std::vector<ParamNames> params_with_default;
532 
533   for (int i = 0; i < api_def_.arg_order_size(); ++i) {
534     const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
535     const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
536     params_no_default.emplace_back(api_def_arg.name(), api_def_arg.rename_to());
537     if (!arg.type_attr().empty()) {
538       gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_attr(), arg.name());
539     } else if (!arg.type_list_attr().empty()) {
540       gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_list_attr(),
541                               arg.name());
542     }
543     if (!arg.number_attr().empty()) {
544       gtl::InsertIfNotPresent(&inferred_attrs_, arg.number_attr(), arg.name());
545     }
546   }
547   for (int i = 0; i < api_def_.attr_size(); ++i) {
548     const auto& attr(api_def_.attr(i));
549     // Do not add inferred attrs to the Python function signature.
550     if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
551       if (attr.has_default_value()) {
552         params_with_default.emplace_back(attr.name(), attr.rename_to());
553       } else {
554         params_no_default.emplace_back(attr.name(), attr.rename_to());
555       }
556     }
557   }
558 
559   // Save the list of attr parameters (attrs that won't be inferred),
560   // those with defaults go at the end.
561   // Get the attrs in the order we want by taking the attrs without defaults
562   // from the end of args_no_default, and adding args_no_default.
563   attrs_.reserve(params_no_default.size() - op_def_.input_arg_size() +
564                  params_with_default.size());
565   for (int i = op_def_.input_arg_size(), end = params_no_default.size();
566        i < end; ++i) {
567     attrs_.push_back(params_no_default[i].GetName());
568   }
569   for (int i = 0, end = params_with_default.size(); i < end; ++i) {
570     attrs_.push_back(params_with_default[i].GetName());
571   }
572 
573   param_names_.reserve(params_no_default.size() + params_with_default.size());
574   param_names_.insert(param_names_.begin(), params_no_default.begin(),
575                       params_no_default.end());
576   for (const auto& param : params_with_default) {
577     param_names_.push_back(param);
578   }
579 
580   string parameters;
581   for (const auto& param : params_no_default) {
582     AddDelimiter(&parameters, ", ");
583     strings::StrAppend(&parameters, param.GetRenameTo());
584   }
585   for (const auto& param_and_default : params_with_default) {
586     AddDelimiter(&parameters, ", ");
587     strings::StrAppend(&parameters, param_and_default.GetRenameTo(), "=None");
588   }
589   AddDelimiter(&parameters, ", ");
590   strings::StrAppend(&parameters, "name=None");
591 
592   AddExport();
593   AddDefLine(parameters);
594   AddDocStringDescription();
595   AddDocStringArgs();
596   AddDocStringInputs();
597   AddDocStringAttrs();
598   AddDocStringNameArg();
599   AddOutputGlobals();
600   AddDocStringOutputs();
601   strings::StrAppend(&result_, "  \"\"\"\n");
602   AddBody("  ");
603   strings::StrAppend(&result_, "\n\n");
604 
605   return prelude_ + result_;
606 }
607 
AddExport()608 void GenPythonOp::AddExport() {
609   if (api_def_.visibility() != ApiDef::VISIBLE) {
610     return;
611   }
612   // Whether op should be available in latest export version.
613   bool op_available_in_latest =
614       !api_def_.deprecation_version() ||
615       api_def_.deprecation_version() > kLatestAPIExportVersion;
616 
617   string names;
618   string names_v1;
619   string deprecated_endpoints;
620 
621   for (const auto& endpoint : api_def_.endpoint()) {
622     string endpoint_name;
623     python_op_gen_internal::GenerateLowerCaseOpName(endpoint.name(),
624                                                     &endpoint_name);
625     if (endpoint.deprecated() || endpoint.deprecation_version() > 0) {
626       AddDelimiter(&deprecated_endpoints, ", ");
627       strings::StrAppend(&deprecated_endpoints, "'", endpoint_name, "'");
628     }
629     // Add all endpoints to TensorFlow 1.* API.
630     AddDelimiter(&names_v1, ", ");
631     strings::StrAppend(&names_v1, "'", endpoint_name, "'");
632     // Add non-deprecated endpoints to TensorFlow 2.* API.
633     if (op_available_in_latest &&
634         (!endpoint.deprecation_version() ||
635          endpoint.deprecation_version() > kLatestAPIExportVersion)) {
636       AddDelimiter(&names, ", ");
637       strings::StrAppend(&names, "'", endpoint_name, "'");
638     }
639   }
640 
641   // tf_export decorator has the following format:
642   // @tf_export(v2_name, v2_name, v1=[v1_name, v1_name])
643   if (names != names_v1) {
644     AddDelimiter(&names, ", ");
645     strings::StrAppend(&names, "v1=[", names_v1, "]");
646   }
647   strings::StrAppend(&result_, "@tf_export(", names, ")\n");
648 
649   // If all endpoints are deprecated, add @deprecated decorator.
650   if (!api_def_.deprecation_message().empty()) {
651     const string instructions = api_def_.deprecation_message();
652     strings::StrAppend(&result_, "@deprecated(None, '", instructions, "')\n");
653   }
654   // Add @deprecated_endpoints decorator.
655   if (!deprecated_endpoints.empty()) {
656     strings::StrAppend(&result_, "@deprecated_endpoints(", deprecated_endpoints,
657                        ")\n");
658   }
659 }
660 
AddDefLine(const string & function_name,const string & parameters)661 void GenPythonOp::AddDefLine(const string& function_name,
662                              const string& parameters) {
663   strings::StrAppend(&result_, "def ", function_name, "(", parameters, "):\n");
664 }
665 
AddDefLine(const string & parameters)666 void GenPythonOp::AddDefLine(const string& parameters) {
667   AddDefLine(function_name_, parameters);
668 }
669 
AddDocStringDescription()670 void GenPythonOp::AddDocStringDescription() {
671   string comment;
672   if (api_def_.summary().empty()) {
673     comment = "TODO: add doc.\n";
674   } else {
675     comment = strings::StrCat(api_def_.summary(), "\n");
676     if (!api_def_.description().empty()) {
677       strings::StrAppend(&comment, "\n", Indent(2, 2, api_def_.description()));
678     }
679   }
680   strings::StrAppend(&result_, "  r\"\"\"", comment, "\n");
681 }
682 
AddDocStringArgs()683 void GenPythonOp::AddDocStringArgs() {
684   strings::StrAppend(&result_, "  Args:\n");
685 }
686 
AddDocStringInputs()687 void GenPythonOp::AddDocStringInputs() {
688   for (int i = 0; i < api_def_.arg_order_size(); ++i) {
689     const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
690     const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
691     StringPiece description = api_def_arg.description();
692     string desc;
693     if (ConsumeEquals(&description)) {  // Skip the generated type info.
694       desc = strings::StrCat(param_names_[i].GetRenameTo(), ": ");
695     } else {
696       desc = strings::StrCat(param_names_[i].GetRenameTo(), ": ",
697                              ArgTypeName(op_def_, arg, inferred_attrs_, false));
698     }
699     if (!description.empty()) {
700       AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
701     }
702     strings::StrAppend(&result_, Indent(4, 6, desc));
703   }
704 }
705 
AddDocStringAttrs()706 void GenPythonOp::AddDocStringAttrs() {
707   for (const string& name : attrs_) {
708     const auto& attr = *FindAttr(name, op_def_);
709     const auto& api_def_attr = *FindAttr(name, api_def_);
710     string desc =
711         strings::StrCat(AvoidPythonReserved(api_def_attr.rename_to()), ": ");
712 
713     static const char* const kAttrTypeName[][2] = {
714         {"string", "`string`"},
715         {"list(string)", "list of `strings`"},
716         {"int", "`int`"},
717         {"list(int)", "list of `ints`"},
718         {"float", "`float`"},
719         {"list(float)", "list of `floats`"},
720         {"bool", "`bool`"},
721         {"list(bool)", "list of `bools`"},
722         {"type", "`tf.DType`"},
723         {"list(type)", "list of `tf.DTypes`"},
724         {"shape", "`tf.TensorShape` or list of `ints`"},
725         {"list(shape)",
726          "list of shapes (each a `tf.TensorShape` or list of `ints`)"},
727         {"tensor", "`tf.TensorProto`"},
728         {"list(tensor)", "list of `tf.TensorProto` objects"},
729         {"func", "function decorated with @Defun"},
730         {"list(func)", "list of functions decorated with @Defun"},
731     };
732     for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) {
733       if (attr.type() == kAttrTypeName[i][0]) {
734         string s;
735         if (api_def_attr.has_default_value()) {
736           s = strings::StrCat("optional ", kAttrTypeName[i][1]);
737         } else {
738           s = kAttrTypeName[i][1];
739         }
740         if (s[0] == 'o' || (s[0] == '`' && (s[1] == 'i' || s[1] == 'o'))) {
741           strings::StrAppend(&desc, "An ", s);
742         } else {
743           strings::StrAppend(&desc, "A ", s);
744         }
745         break;
746       }
747     }
748 
749     if (attr.has_allowed_values()) {
750       strings::StrAppend(&desc, " from: `",
751                          AttrListToPython(attr.allowed_values()), "`");
752     }
753 
754     if (attr.has_minimum()) {
755       if (attr.type() == "int") {
756         strings::StrAppend(&desc, " that is `>= ", attr.minimum(), "`");
757       } else if (attr.minimum() > 0) {
758         strings::StrAppend(&desc, " that has length `>= ", attr.minimum(), "`");
759       }
760     }
761 
762     strings::StrAppend(&desc, ".");
763 
764     if (api_def_attr.has_default_value()) {
765       strings::StrAppend(
766           &desc, " Defaults to `",
767           AttrValueToPython(attr.type(), api_def_attr.default_value()), "`.");
768     }
769     if (!api_def_attr.description().empty()) {
770       AppendWithinWidth(&desc, api_def_attr.description(),
771                         kRightMargin - 4 /* indent */);
772     }
773     strings::StrAppend(&result_, Indent(4, 6, desc));
774   }
775 }
776 
AddDocStringNameArg()777 void GenPythonOp::AddDocStringNameArg() {
778   strings::StrAppend(&result_,
779                      "    name: A name for the operation (optional).\n");
780 }
781 
AddOutputGlobals()782 void GenPythonOp::AddOutputGlobals() {
783   // Generate a namedtuple class to hold the outputs, if there are multiple.
784   // Example:
785   //
786   // _OpOutputs = collections.namedtuple(
787   //     "_OpOutputs",
788   //     "out1 out2 out3")
789   if (num_outs_ > 1) {
790     std::vector<string> out_names;
791     out_names.reserve(num_outs_);
792     for (int i = 0; i < num_outs_; ++i) {
793       const string out_name = !api_def_.out_arg(i).rename_to().empty()
794                                   ? api_def_.out_arg(i).rename_to()
795                                   : strings::StrCat("output", i);
796       out_names.push_back(strings::StrCat("\"", out_name, "\""));
797     }
798 
799     strings::StrAppend(&prelude_, "_", AvoidPythonReserved(op_def_.name()),
800                        "Output = collections.namedtuple(\n");
801     strings::StrAppend(&prelude_, "    \"", AvoidPythonReserved(op_def_.name()),
802                        "\",\n");
803     strings::StrAppend(&prelude_, "    [", absl::StrJoin(out_names, ", "),
804                        "])");
805     strings::StrAppend(&prelude_, "\n\n");
806   }
807   strings::StrAppend(&prelude_, "\n");
808 }
809 
AddDocStringOutputs()810 void GenPythonOp::AddDocStringOutputs() {
811   std::vector<string> output_type_string;
812   output_type_string.reserve(num_outs_);
813   for (int i = 0; i < num_outs_; ++i) {
814     output_type_string.push_back(
815         ArgTypeName(op_def_, op_def_.output_arg(i), inferred_attrs_, true));
816   }
817   strings::StrAppend(&result_, GetReturns(op_def_, output_type_string));
818 }
819 
AddBody(const string & prefix)820 void GenPythonOp::AddBody(const string& prefix) {
821   const string apply_prefix = strings::StrCat(
822       prefix, "_result = _op_def_lib.apply_op(\"", op_def_.name(), "\", ");
823   AddBodyNoReturn(apply_prefix);
824   if (num_outs_ > 1) {
825     strings::StrAppend(&result_, prefix, "_result = _",
826                        AvoidPythonReserved(op_def_.name()),
827                        "Output._make(_result)\n");
828   }
829   strings::StrAppend(&result_, prefix, "return _result\n");
830 }
831 
AddBodyNoReturn(const string & apply_prefix)832 void GenPythonOp::AddBodyNoReturn(const string& apply_prefix) {
833   string args;
834   for (size_t i = 0; i < param_names_.size(); ++i) {
835     strings::StrAppend(&args, AvoidPythonReserved(param_names_[i].GetName()),
836                        "=", param_names_[i].GetRenameTo(), ", ");
837   }
838   strings::StrAppend(&args, "name=name)");
839 
840   strings::StrAppend(&result_,
841                      // Wrap the arguments, and indent to the (.
842                      WordWrap(apply_prefix, args, kRightMargin), "\n");
843 }
844 
845 }  // namespace python_op_gen_internal
846 }  // namespace tensorflow
847