/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/framework/python_op_gen.h"

#include <stdio.h>

#include <algorithm>  // for std::sort, used in GenerateTypeVars()
#include <sstream>
#include <unordered_map>

#include "absl/strings/escaping.h"
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/framework/python_op_gen_internal.h"

namespace tensorflow {
namespace {

const int kRightMargin = 78;

constexpr char kEagerFallbackSuffix[] = "_eager_fallback";

// Maps Python dtype expressions (e.g. "_dtypes.float32") to the names of
// the corresponding DType classes (e.g. "_dtypes.Float32"), for use in
// type annotations.
const std::unordered_map<string, string> dtype_type{
    {"_dtypes.float16", "_dtypes.Float16"},
    {"_dtypes.half", "_dtypes.Half"},
    {"_dtypes.float32", "_dtypes.Float32"},
    {"_dtypes.float64", "_dtypes.Float64"},
    {"_dtypes.bfloat16", "_dtypes.BFloat16"},
    {"_dtypes.complex64", "_dtypes.Complex64"},
    {"_dtypes.complex128", "_dtypes.Complex128"},
    {"_dtypes.int8", "_dtypes.Int8"},
    {"_dtypes.uint8", "_dtypes.UInt8"},
    {"_dtypes.uint16", "_dtypes.UInt16"},
    {"_dtypes.uint32", "_dtypes.UInt32"},
    {"_dtypes.uint64", "_dtypes.UInt64"},
    {"_dtypes.int16", "_dtypes.Int16"},
    {"_dtypes.int32", "_dtypes.Int32"},
    {"_dtypes.int64", "_dtypes.Int64"},
    {"_dtypes.bool", "_dtypes.Bool"},
    {"_dtypes.string", "_dtypes.String"},
    {"_dtypes.qint8", "_dtypes.QInt8"},
    {"_dtypes.quint8", "_dtypes.QUInt8"},
    {"_dtypes.qint16", "_dtypes.QInt16"},
    {"_dtypes.quint16", "_dtypes.QUInt16"},
    {"_dtypes.qint32", "_dtypes.QInt32"},
    {"_dtypes.resource", "_dtypes.Resource"},
    {"_dtypes.variant", "_dtypes.Variant"}};

string AttrVarName(const string& attr_name,
                   std::unordered_map<string, string>* attr_expressions) {
  const string var = strings::StrCat("_attr_", attr_name);
  if (attr_expressions != nullptr) (*attr_expressions)[attr_name] = var;
  return var;
}

void AddInferredAttr(const string& indentation, const string& attr_name,
                     const string& value_expression, string* result,
                     std::unordered_map<string, string>* attr_expressions) {
  strings::StrAppend(result, indentation,
                     AttrVarName(attr_name, attr_expressions), " = ",
                     value_expression, "\n");
}

string VectorToTuple(const std::vector<string>& l) {
  if (l.size() == 1) return strings::StrCat("(", l.front(), ",)");
  string ret = "(";
  for (int i = 0, end = l.size(); i < end; ++i) {
    if (i > 0) {
      strings::StrAppend(&ret, ", ");
    }
    strings::StrAppend(&ret, l[i]);
  }
  strings::StrAppend(&ret, ")");
  return ret;
}
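
// Illustrative examples (not part of the original source; inputs assumed):
//   VectorToTuple({"a"})           -> "(a,)"    (the trailing comma keeps a
//                                                one-element tuple a tuple)
//   VectorToTuple({"a", "b", "c"}) -> "(a, b, c)"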

void Unflatten(const string& prefix, const std::vector<string>& output_sizes,
               const string& var, string* result) {
  for (int i = 0, end = output_sizes.size(); i < end; ++i) {
    if (!output_sizes[i].empty()) {
      strings::StrAppend(result, prefix, var, " = ");
      if (i > 0) strings::StrAppend(result, var, "[:", i, "] + ");
      if (i + 1 < end) {
        // Special case i == 0 to avoid "0 +" in the generated code.
        if (i == 0) {
          strings::StrAppend(result, "[", var, "[:", output_sizes[i], "]] + ",
                             var, "[", output_sizes[i], ":]");
        } else {
          strings::StrAppend(result, "[", var, "[", i, ":", i, " + ",
                             output_sizes[i], "]] + ", var, "[", i, " + ",
                             output_sizes[i], ":]");
        }
      } else {
        strings::StrAppend(result, "[", var, "[", i, ":]]");
      }
      strings::StrAppend(result, "\n");
    }
  }
}
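
// Illustrative example (not part of the original source; names assumed):
// for output_sizes = {"_attr_N", ""} and var = "_result", the generated
// Python regroups the list-valued first output in place:
//   _result = [_result[:_attr_N]] + _result[_attr_N:]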

string TensorPBString(const TensorProto& pb) {
  // Note: This gets used in the argument list, and so must survive naive
  // word wrapping.
  return strings::StrCat("\"\"\"", pb.ShortDebugString(), "\"\"\"");
}

class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
 public:
  GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
                   const string& function_name, bool add_type_annotations)
      : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name,
                                            add_type_annotations) {
    op_name_ = function_name_;
    absl::ConsumePrefix(&op_name_, "_");
  }
  ~GenEagerPythonOp() override {}

  string Code() override;

 protected:
  void HandleGraphMode(const string& function_setup,
                       const std::vector<string>& output_sizes);

  string GetEagerNotAllowedError();
  void ExpectListArg(const string& indentation, const string& arg_name,
                     string* output);
  bool GetEagerFunctionSetup(const string& indentation, string* function_setup);
  void GetOutputSizesAndNumOutputsExpr(std::vector<string>* output_sizes,
                                       string* num_outputs_expr);

  void AddEagerFunctionTeardown(const string& indentation,
                                const std::vector<string>& output_sizes,
                                bool execute_record_gradient);

  bool AddEagerFastPathAndGraphCode(
      const string& parameters, const std::vector<string>& output_sizes,
      const string& eager_not_allowed_error,
      const std::unordered_map<string, string>& type_annotations);
  bool AddEagerFallbackCode(
      const string& parameters, const std::vector<string>& output_sizes,
      const string& num_outputs_expr, const string& eager_not_allowed_error,
      const std::unordered_map<string, string>& type_annotations);
  void AddEagerFastPathExecute();

  void AddEagerInferredAttrs(const string& indentation);
  void AddEagerInputCasts(const string& indentation);
  void AddEagerAttrs(const string& indentation);
  void AddEagerExecute(const string& indentation,
                       const string& num_outputs_expr);
  void AddDispatch(const string& prefix);

  void AddRawOpExport(const string& parameters);

  std::unordered_map<string, string> GetTypeAnnotations();

  void GenerateTypeVars(
      const std::unordered_map<string, string>& type_annotations);

  void AddReturnTypeAnnotation(
      const std::unordered_map<string, string>& type_annotations);

  void AddAttrForArg(const string& attr, int arg_index) {
    gtl::InsertIfNotPresent(&inferred_attrs_, attr,
                            op_def_.input_arg(arg_index).name());
    auto iter = attr_to_args_.find(attr);
    if (iter == attr_to_args_.end()) {
      attr_to_args_.insert(AttrToArgMap::value_type(attr, {arg_index}));
    } else {
      iter->second.push_back(arg_index);
    }
  }

  // Returns a string expression representing a flattened list of all
  // the inputs given by `*input_indices` (or all inputs if
  // `input_indices` is nullptr).  `*output_sizes` can be used to unflatten.
  string FlattenInputs(const std::vector<int>* input_indices,
                       std::vector<string>* output_sizes) const;

  StringPiece op_name_;
  typedef std::unordered_map<string, std::vector<int>> AttrToArgMap;
  AttrToArgMap attr_to_args_;
  std::unordered_map<string, string> attr_expressions_;
  // This has all the input args followed by those attrs that don't have
  // defaults.
  std::vector<python_op_gen_internal::ParamNames> params_no_default_;
  // The parameters with defaults (these have to be listed after those without).
  // No input args are included, just attrs.
  std::vector<std::pair<python_op_gen_internal::ParamNames, string>>
      params_with_default_;
};

string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
                        const string& function_name,
                        bool add_type_annotations) {
  return GenEagerPythonOp(op_def, api_def, function_name, add_type_annotations)
      .Code();
}

string GenEagerPythonOp::FlattenInputs(
    const std::vector<int>* input_indices,
    std::vector<string>* output_sizes) const {
  string inputs;
  enum { STARTING, WAS_LIST_INPUT, WAS_SOLO_INPUT } inputs_state = STARTING;
  const int n = input_indices != nullptr ? input_indices->size()
                                         : op_def_.input_arg_size();
  for (int j = 0; j < n; ++j) {
    const int i = input_indices ? (*input_indices)[j] : j;
    const auto& arg(op_def_.input_arg(i));
    const bool is_list =
        !arg.type_list_attr().empty() || !arg.number_attr().empty();
    if (is_list) {
      if (inputs_state == WAS_SOLO_INPUT) {
        strings::StrAppend(&inputs, "] + ");
      } else if (inputs_state == WAS_LIST_INPUT) {
        strings::StrAppend(&inputs, " + ");
      }
      strings::StrAppend(&inputs, "list(", param_names_[i].GetRenameTo(), ")");
      inputs_state = WAS_LIST_INPUT;
      if (output_sizes != nullptr) {
        if (!arg.number_attr().empty()) {
          output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr));
        } else {
          output_sizes->emplace_back(
              strings::StrCat("len(", param_names_[i].GetRenameTo(), ")"));
        }
      }
    } else {
      if (inputs_state == WAS_SOLO_INPUT) {
        strings::StrAppend(&inputs, ", ");
      } else if (inputs_state == WAS_LIST_INPUT) {
        strings::StrAppend(&inputs, " + [");
      } else {
        strings::StrAppend(&inputs, "[");
      }
      strings::StrAppend(&inputs, param_names_[i].GetRenameTo());
      inputs_state = WAS_SOLO_INPUT;
      if (output_sizes != nullptr) output_sizes->emplace_back();
    }
  }
  if (inputs_state == STARTING) return "[]";
  if (inputs_state == WAS_SOLO_INPUT) {
    strings::StrAppend(&inputs, "]");
  }
  return inputs;
}
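
// Illustrative example (not part of the original source; op shape assumed):
// for an op with inputs (x: T, ys: N * T, z: T), FlattenInputs(nullptr,
// &sizes) returns the Python expression
//   "[x] + list(ys) + [z]"
// and fills sizes with {"", "_attr_N", ""} so Unflatten() can later
// regroup the flat result.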

string GenEagerPythonOp::Code() {
  if (api_def_.visibility() == ApiDef::SKIP) {
    return "";
  }

  for (int i = 0; i < api_def_.arg_order_size(); ++i) {
    const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
    const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
    params_no_default_.emplace_back(api_def_arg.name(),
                                    api_def_arg.rename_to());
    if (!arg.type_attr().empty()) {
      AddAttrForArg(arg.type_attr(), i);
    } else if (!arg.type_list_attr().empty()) {
      AddAttrForArg(arg.type_list_attr(), i);
    }
    if (!arg.number_attr().empty()) {
      AddAttrForArg(arg.number_attr(), i);
    }
  }
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    const auto& api_def_attr(api_def_.attr(i));
    // Do not add inferred attrs to the Python function signature.
    if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
      if (api_def_attr.has_default_value()) {
        if (attr.type() == "tensor") {
          params_with_default_.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              strings::StrCat(
                  "_execute.make_tensor(",
                  TensorPBString(api_def_attr.default_value().tensor()), ", \"",
                  api_def_attr.rename_to(), "\")"));
        } else if (attr.type() == "list(tensor)") {
          std::vector<string> pbtxt;
          for (const auto& pb : api_def_attr.default_value().list().tensor()) {
            pbtxt.emplace_back(TensorPBString(pb));
          }
          params_with_default_.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              strings::StrCat("[_execute.make_tensor(_pb, \"",
                              api_def_attr.rename_to(), "\") for _pb in ",
                              VectorToTuple(pbtxt), "]"));
        } else {
          params_with_default_.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              python_op_gen_internal::AttrValueToPython(
                  attr.type(), api_def_attr.default_value(), "_dtypes."));
        }
      } else {
        params_no_default_.emplace_back(api_def_attr.name(),
                                        api_def_attr.rename_to());
      }
    }
  }

  // Save the list of attr parameters (attrs that won't be inferred),
  // those with defaults go at the end.
  // Get the attrs in the order we want by taking the attrs without defaults
  // from the end of params_no_default_, and adding params_with_default_.
  attrs_.reserve(params_no_default_.size() - op_def_.input_arg_size() +
                 params_with_default_.size());
  for (int i = op_def_.input_arg_size(), end = params_no_default_.size();
       i < end; ++i) {
    attrs_.push_back(params_no_default_[i].GetName());
  }
  for (const auto& p : params_with_default_) {
    attrs_.push_back(p.first.GetName());
  }

  // TODO(slebedev): call AvoidPythonReserved on each param?
  param_names_.reserve(params_no_default_.size() + params_with_default_.size());
  param_names_.insert(param_names_.begin(), params_no_default_.begin(),
                      params_no_default_.end());
  for (const auto& param_and_default : params_with_default_) {
    param_names_.push_back(param_and_default.first);
  }

  std::unordered_map<string, string> type_annotations;
  // Only populate map for allowlisted ops
  if (add_type_annotations_) {
    type_annotations = GetTypeAnnotations();
  }

  string parameters;
  // Param can be an input or an attr
  for (const auto& param : params_no_default_) {
    if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
    strings::StrAppend(&parameters, param.GetRenameTo());

    if (type_annotations.find(param.GetName()) != type_annotations.end()) {
      strings::StrAppend(&parameters, ": ",
                         type_annotations.at(param.GetName()));
    }
  }

  string parameters_with_defaults = parameters;
  for (const auto& param_and_default : params_with_default_) {
    if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
    if (!parameters_with_defaults.empty())
      strings::StrAppend(&parameters_with_defaults, ", ");

    strings::StrAppend(&parameters, param_and_default.first.GetRenameTo());
    strings::StrAppend(&parameters_with_defaults,
                       param_and_default.first.GetRenameTo());
    if (type_annotations.find(param_and_default.first.GetName()) !=
        type_annotations.end()) {
      const string param_type =
          type_annotations.at(param_and_default.first.GetName());
      // Append to parameters and parameters_with_defaults because multiple
      // functions are generated by AddEagerFastPathAndGraphCode() and
      // AddEagerFallbackCode()
      strings::StrAppend(&parameters, ": ", param_type);
      strings::StrAppend(&parameters_with_defaults, ":", param_type);
    }

    strings::StrAppend(&parameters_with_defaults, "=",
                       param_and_default.second);
  }

  strings::StrAppend(&parameters, parameters.empty() ? "" : ", ", "name");
  strings::StrAppend(&parameters_with_defaults,
                     parameters_with_defaults.empty() ? "" : ", ", "name=None");

  // Add attr_expressions_ for attrs that are params.
  for (int i = 0, end = attrs_.size(); i < end; ++i) {
    const string& attr_name = attrs_[i];
    const string& attr_api_name =
        param_names_[i + op_def_.input_arg_size()].GetRenameTo();
    attr_expressions_[attr_name] = attr_api_name;
  }
  // Add attr_expressions_ for attrs that are inferred.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    if (attr.type() == "int") {
      auto arg_list = attr_to_args_.find(attr.name());
      if (arg_list != attr_to_args_.end()) {
        AttrVarName(attr.name(), &attr_expressions_);
      }
    }
  }

  string num_outputs_expr;
  std::vector<string> output_sizes(num_outs_);
  GetOutputSizesAndNumOutputsExpr(&output_sizes, &num_outputs_expr);

  string eager_not_allowed_error = GetEagerNotAllowedError();

  if (!AddEagerFastPathAndGraphCode(parameters_with_defaults, output_sizes,
                                    eager_not_allowed_error,
                                    type_annotations)) {
    return result_;
  }

  if (!AddEagerFallbackCode(parameters, output_sizes, num_outputs_expr,
                            eager_not_allowed_error, type_annotations)) {
    return result_;
  }

  return prelude_ + result_;
}
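
// Illustrative sketch (not part of the original source; op and parameter
// names assumed) of the overall wrapper Code() emits for a visible op:
//
//   @_dispatch.add_dispatch_list
//   @tf_export("some_namespace.some_op")
//   def some_op(x, dtype, name=None):
//     """..."""
//     _ctx = _context._context or _context.context()
//     tld = _ctx._thread_local_data
//     if tld.is_eager:
//       ...fast path, falling back to some_op_eager_fallback...
//     ...graph-mode path via _op_def_library._apply_op_helper...
//
//   SomeOp = tf_export("raw_ops.SomeOp")(_ops.to_raw_op(some_op))
//
//   def some_op_eager_fallback(x, dtype, name, ctx):
//     ...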

std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotations() {
  std::unordered_map<string, string> type_annotations;
  // Map attrs to TypeVars
  for (const auto& attr : op_def_.attr()) {
    if (attr.type() == "type") {
      const string type_var_name = "TV_" + op_def_.name() + "_" + attr.name();
      type_annotations[attr.name()] = type_var_name;
    } else if (attr.type() == "bool" || attr.type() == "float" ||
               attr.type() == "int" || attr.type() == "bytes") {
      type_annotations[attr.name()] = attr.type();
    } else if (attr.type() == "string") {
      type_annotations[attr.name()] = "str";
    }
  }

  // Map input Tensors to their types
  for (const auto& arg : op_def_.input_arg()) {
    // TODO(rahulkamat): Add type annotations to args that accept a sequence of
    // Tensors
    if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) continue;
    type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
  }

  // TODO(rahulkamat): Add type annotations to handle return types of a sequence
  // of Tensors. Map output Tensor to its type
  if (op_def_.output_arg_size() == 1) {
    const auto& arg = op_def_.output_arg(0);
    if (arg.number_attr().empty() && arg.type_list_attr().empty())
      type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
  }

  return type_annotations;
}

// Generate TypeVars using attrs
void GenEagerPythonOp::GenerateTypeVars(
    const std::unordered_map<string, string>& type_annotations) {
  bool added_typevar = false;
  for (const auto& attr : op_def_.attr()) {
    if (attr.type() == "type") {
      std::vector<string> allowed_types;
      for (int t : attr.allowed_values().list().type()) {
        DataType dtype = static_cast<DataType>(t);
        const string py_dtype =
            python_op_gen_internal::DataTypeToPython(dtype, "_dtypes.");
        allowed_types.emplace_back(dtype_type.at(py_dtype));
      }

      // When a Tensor does not have any dtypes specified, all dtypes are
      // allowed
      if (allowed_types.empty()) {
        for (std::pair<string, string> map_dtype : dtype_type) {
          allowed_types.emplace_back(map_dtype.second);
        }
      }

      std::sort(allowed_types.begin(), allowed_types.end());

      string typevar_dtypes;
      for (std::vector<string>::iterator it = allowed_types.begin();
           it != allowed_types.end(); ++it) {
        if (!typevar_dtypes.empty()) strings::StrAppend(&typevar_dtypes, ", ");
        strings::StrAppend(&typevar_dtypes, *it);
      }

      const string type_var_name = type_annotations.at(attr.name());
      strings::StrAppend(&result_, type_var_name, " = TypeVar(\"",
                         type_var_name, "\", ", typevar_dtypes, ")\n");
      added_typevar = true;
    }
  }

  if (added_typevar) strings::StrAppend(&result_, "\n");
}
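
// Illustrative output (not part of the original source; op and attr names
// assumed): for an op "Cumsum" with a "type" attr T restricted to int32 and
// float32, this emits roughly
//   TV_Cumsum_T = TypeVar("TV_Cumsum_T", _dtypes.Float32, _dtypes.Int32)
// (the allowed DType class names are sorted alphabetically).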

void GenEagerPythonOp::AddReturnTypeAnnotation(
    const std::unordered_map<string, string>& type_annotations) {
  if (op_def_.output_arg_size() == 1) {
    const auto& arg = op_def_.output_arg(0);
    if (arg.number_attr().empty() && arg.type_list_attr().empty()) {
      const string return_type = type_annotations.at(arg.name());
      // TODO(rahulkamat): Modify AddDefLine() to add return type annotation to
      // avoid erasing ":\n" from the end of the def line
      result_.erase(result_.length() - 2);
      strings::StrAppend(&result_, " -> ", return_type, ":\n");
    }
  }
}

void GenEagerPythonOp::HandleGraphMode(
    const string& function_setup, const std::vector<string>& output_sizes) {
  strings::StrAppend(&result_, "  # Add nodes to the TensorFlow graph.\n");
  strings::StrAppend(&result_, function_setup);
  if (api_def_.visibility() == ApiDef::VISIBLE) {
    strings::StrAppend(&result_, "  try:\n  ");
  }
  strings::StrAppend(
      &result_, "  _, _, _op, _outputs = _op_def_library._apply_op_helper(\n");
  AddBodyNoReturn(strings::StrCat("        \"", op_def_.name(), "\", "));
  AddDispatch("  ");

  if (num_outs_ > 0) {
    strings::StrAppend(&result_, "  _result = _outputs[:]\n");
    // Special case handling for stateful op with single list output
    // that might be empty.
    if (num_outs_ == 1 && op_def_.is_stateful() &&
        (!op_def_.output_arg(0).number_attr().empty() ||
         !op_def_.output_arg(0).type_list_attr().empty())) {
      // TODO(josh11b): Can skip this if the number_attr/type_list_attr has
      // a constraint indicating that this can never be empty.
      strings::StrAppend(&result_,
                         "  if not _result:\n"
                         "    return _op\n");
    }

    // Compute graph-mode attrs when we need to record a gradient.
    strings::StrAppend(&result_, "  if _execute.must_record_gradient():\n");
    if (op_def_.attr_size() > 0) {
      string attr_values;
      for (int i = 0; i < op_def_.attr_size(); ++i) {
        if (i > 0) strings::StrAppend(&attr_values, ", ");
        const auto& attr_name(op_def_.attr(i).name());
        if (op_def_.attr(i).type() == "type") {
          strings::StrAppend(&attr_values, "\"", attr_name,
                             "\", _op._get_attr_type(\"", attr_name, "\")");
        } else if (op_def_.attr(i).type() == "bool") {
          strings::StrAppend(&attr_values, "\"", attr_name,
                             "\", _op._get_attr_bool(\"", attr_name, "\")");
        } else if (op_def_.attr(i).type() == "int") {
          strings::StrAppend(&attr_values, "\"", attr_name,
                             "\", _op._get_attr_int(\"", attr_name, "\")");
        } else {
          strings::StrAppend(&attr_values, "\"", attr_name,
                             "\", _op.get_attr(\"", attr_name, "\")");
        }
      }
      strings::StrAppend(&attr_values, ")");
      strings::StrAppend(&result_,
                         WordWrap("    _attrs = (", attr_values, kRightMargin),
                         "\n");

    } else {
      strings::StrAppend(&result_, "    _attrs = ()\n");
    }

    strings::StrAppend(&result_, "    _inputs_flat = _op.inputs\n");
    strings::StrAppend(&result_, "    _execute.record_gradient(\n",
                       "        \"", op_def_.name(),
                       "\", _inputs_flat, _attrs, _result)\n");

    if (num_outs_ == 1 && !output_sizes[0].empty()) {
      // Single list result.
    } else if (num_outs_ == 1) {
      // Execute returns a single-element list which we need to destructure.
      strings::StrAppend(&result_, "  ", "_result, = _result\n");
    } else {
      // Have multiple outputs, so we will need to reformat the return
      // value of execute() to be a list with one entry per op output
      // (that entry will be a list of tensors if that output is of list
      // type).
      // For list outputs, convert the right subrange of _result into a list.
      Unflatten("  ", output_sizes, "_result", &result_);
      // Convert to a named tuple.
      strings::StrAppend(
          &result_, "  _result = _",
          python_op_gen_internal::AvoidPythonReserved(op_def_.name()),
          "Output._make(_result)\n");
    }
    strings::StrAppend(&result_, "  return _result\n\n");
  } else {
    strings::StrAppend(&result_, "  return _op\n");
  }
}
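
// Illustrative output (not part of the original source; op "Identity" with
// one input and a single "type" attr T assumed) for a visible op:
//   # Add nodes to the TensorFlow graph.
//   try:
//     _, _, _op, _outputs = _op_def_library._apply_op_helper(
//         "Identity", ...)
//   except (TypeError, ValueError):
//     ...dispatch block emitted by AddDispatch("  ")...
//   _result = _outputs[:]
//   if _execute.must_record_gradient():
//     _attrs = ("T", _op._get_attr_type("T"))
//     _inputs_flat = _op.inputs
//     _execute.record_gradient(
//         "Identity", _inputs_flat, _attrs, _result)
//   _result, = _result
//   return _result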

string GenEagerPythonOp::GetEagerNotAllowedError() {
  bool eager_allowed = true;
  string ref_arg;
  for (int i = 0; i < op_def_.input_arg_size(); ++i) {
    const auto& arg = op_def_.input_arg(i);
    if (arg.is_ref()) {
      eager_allowed = false;
      DCHECK_EQ(op_def_.input_arg(i).name(), api_def_.in_arg(i).name());
      ref_arg = api_def_.in_arg(i).rename_to();
    }
  }
  for (int i = 0; i < op_def_.output_arg_size(); ++i) {
    const auto& arg = op_def_.output_arg(i);
    if (arg.is_ref()) {
      eager_allowed = false;
      DCHECK_EQ(op_def_.output_arg(i).name(), api_def_.out_arg(i).name());
      ref_arg = api_def_.out_arg(i).rename_to();
    }
  }

  if (eager_allowed) return "";

  return strings::StrCat("raise RuntimeError(\"", op_name_,
                         " op does not support eager execution. ", "Arg '",
                         ref_arg, "' is a ref.\")\n");
}

void GenEagerPythonOp::ExpectListArg(const string& indentation,
                                     const string& arg_name, string* output) {
  strings::StrAppend(output, indentation, "if not isinstance(", arg_name,
                     ", (list, tuple)):\n", indentation, "  raise TypeError(\n",
                     indentation, "      \"Expected list for '", arg_name,
                     "' argument to \"\n", indentation, "      \"'", op_name_,
                     "' Op, not %r.\" % ", arg_name, ")\n");
}
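
// Illustrative output (not part of the original source; arg and op names
// assumed) for arg_name = "values" on op "concat" with two-space indentation:
//   if not isinstance(values, (list, tuple)):
//     raise TypeError(
//         "Expected list for 'values' argument to "
//         "'concat' Op, not %r." % values)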

bool GenEagerPythonOp::GetEagerFunctionSetup(const string& indentation,
                                             string* function_setup) {
  // Validate list inputs, infer length attrs.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    if (attr.type() == "int") {
      auto arg_list = attr_to_args_.find(attr.name());
      if (arg_list != attr_to_args_.end()) {
        // Inferred int attrs are the lengths of inputs. Validate those
        // inputs are lists and have the same length.
        for (auto iter = arg_list->second.begin();
             iter != arg_list->second.end(); ++iter) {
          const string& arg_api_name = param_names_[*iter].GetRenameTo();
          ExpectListArg(indentation, arg_api_name, function_setup);
          if (iter == arg_list->second.begin()) {
            AddInferredAttr(indentation, attr.name(),
                            strings::StrCat("len(", arg_api_name, ")"),
                            function_setup, &attr_expressions_);
          } else {
            const auto& attr_var = attr_expressions_[attr.name()];
            strings::StrAppend(
                function_setup, indentation, "if len(", arg_api_name,
                ") != ", attr_var, ":\n", indentation, "  raise ValueError(\n",
                indentation, "      \"List argument '", arg_api_name, "' to '",
                op_name_, "' Op with length %d \"\n", indentation,
                "      \"must match length %d of argument '",
                inferred_attrs_[attr.name()], "'.\" %\n", indentation,
                "      (len(", arg_api_name, "), ", attr_var, "))\n");
          }
        }
      }
    }
  }

  for (int i = 0, end = attrs_.size(); i < end; ++i) {
    const string& attr_name = attrs_[i];
    const auto& param = param_names_[i + op_def_.input_arg_size()];
    const auto& attr = *FindAttr(attr_name, op_def_);
    const string& attr_api_name = param.GetRenameTo();
    StringPiece attr_type = attr.type();
    attr_expressions_[attr_name] = attr_api_name;
    const int default_index = i - (attrs_.size() - params_with_default_.size());
    if (default_index >= 0) {
      const string& default_value = params_with_default_[default_index].second;
      strings::StrAppend(function_setup, indentation, "if ", attr_api_name,
                         " is None:\n");
      strings::StrAppend(function_setup, indentation, "  ", attr_api_name,
                         " = ", default_value, "\n");
    }
    if (absl::StartsWith(attr_type, "list(")) {
      ExpectListArg(indentation, attr_api_name, function_setup);
    }

    if (attr_type == "string") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_str(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(string)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_str(_s, \"", attr_api_name,
                         "\") for _s in ", attr_api_name, "]\n");
    } else if (attr_type == "int") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_int(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(int)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_int(_i, \"", attr_api_name,
                         "\") for _i in ", attr_api_name, "]\n");
    } else if (attr_type == "float") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_float(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(float)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_float(_f, \"", attr_api_name,
                         "\") for _f in ", attr_api_name, "]\n");
    } else if (attr_type == "bool") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_bool(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(bool)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_bool(_b, \"", attr_api_name,
                         "\") for _b in ", attr_api_name, "]\n");
    } else if (attr_type == "type") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_type(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(type)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_type(_t, \"", attr_api_name,
                         "\") for _t in ", attr_api_name, "]\n");
    } else if (attr_type == "shape") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_shape(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(shape)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_shape(_s, \"", attr_api_name,
                         "\") for _s in ", attr_api_name, "]\n");
    } else if (attr_type == "tensor") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_tensor(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(tensor)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_tensor(_t, \"", attr_api_name,
                         "\") for _t in ", attr_api_name, "]\n");
    } else if (attr_type != "func" && attr_type != "list(func)") {
      *function_setup =
          strings::StrCat("# No definition for ", function_name_,
                          " since we don't support attrs with type\n"
                          "# '",
                          attr_type, "' right now.\n\n");
      return false;
    }
  }
  return true;
}

// If output i is a list output, output_sizes[i] will be set to a
// string with the Python expression that will evaluate to its
// length. output_sizes[i] is empty for non-list outputs.
void GenEagerPythonOp::GetOutputSizesAndNumOutputsExpr(
    std::vector<string>* output_sizes, string* num_outputs_expr) {
  // Expression representing the number of outputs.
  int num_fixed_outputs = 0;
  for (int i = 0; i < num_outs_; ++i) {
    const auto& arg(op_def_.output_arg(i));
    if (!arg.number_attr().empty()) {
      if (!num_outputs_expr->empty()) {
        strings::StrAppend(num_outputs_expr, " + ");
      }
      (*output_sizes)[i] = attr_expressions_[arg.number_attr()];
      strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
    } else if (!arg.type_list_attr().empty()) {
      if (!num_outputs_expr->empty()) {
        strings::StrAppend(num_outputs_expr, " + ");
      }
      // Have to be careful to use an expression that works in both
      // graph and eager paths here.
      const auto iter = inferred_attrs_.find(arg.type_list_attr());
      if (iter == inferred_attrs_.end()) {
        (*output_sizes)[i] = strings::StrCat(
            "len(", attr_expressions_[arg.type_list_attr()], ")");
      } else {
        (*output_sizes)[i] = strings::StrCat("len(", iter->second, ")");
      }
      strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
    } else {
      ++num_fixed_outputs;
    }
  }
  if (num_fixed_outputs > 0) {
    if (!num_outputs_expr->empty()) {
      strings::StrAppend(num_outputs_expr, " + ");
    }
    strings::StrAppend(num_outputs_expr, num_fixed_outputs);
  } else if (num_outputs_expr->empty()) {
    *num_outputs_expr = "0";
  }
}
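
// Illustrative example (not part of the original source; op shape assumed):
// for an op whose outputs are (values: N * T, idx: int32), this produces
// output_sizes = {"_attr_N", ""} and num_outputs_expr = "_attr_N + 1".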

void GenEagerPythonOp::AddEagerFunctionTeardown(
    const string& indentation, const std::vector<string>& output_sizes,
    bool execute_record_gradient) {
  if (num_outs_ > 0) {
    if (execute_record_gradient) {
      strings::StrAppend(&result_, indentation,
                         "if _execute.must_record_gradient():\n");
      strings::StrAppend(&result_, indentation, "  _execute.record_gradient(\n",
                         "        \"", op_def_.name(),
                         "\", _inputs_flat, _attrs, _result)\n");
    }
    if (num_outs_ == 1 && !output_sizes[0].empty()) {
      // Single list result.
    } else if (num_outs_ == 1) {
      // Execute returns a single-element list which we need to destructure.
      strings::StrAppend(&result_, indentation, "_result, = _result\n");
    } else {
      // Have multiple outputs, so we will need to reformat the return
      // value of execute() to be a list with one entry per op output
      // (that entry will be a list of tensors if that output is of list
      // type).
      // For list outputs, convert the right subrange of _result into a list.
      Unflatten(indentation, output_sizes, "_result", &result_);
      // Convert to a named tuple.
      strings::StrAppend(
          &result_, indentation, "_result = _",
          python_op_gen_internal::AvoidPythonReserved(op_def_.name()),
          "Output._make(_result)\n");
    }
  } else {
    strings::StrAppend(&result_, indentation, "_result = None\n");
  }
  strings::StrAppend(&result_, indentation, "return _result\n\n");
}

bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
    const string& parameters, const std::vector<string>& output_sizes,
    const string& eager_not_allowed_error,
    const std::unordered_map<string, string>& type_annotations) {
  if (add_type_annotations_) {
    GenerateTypeVars(type_annotations);
  }
  if (api_def_.visibility() == ApiDef::VISIBLE) {
    strings::StrAppend(&result_, "@_dispatch.add_dispatch_list\n");
  }

  AddExport();
  AddDefLine(function_name_, parameters);
  if (add_type_annotations_) {
    AddReturnTypeAnnotation(type_annotations);
  }
  AddDocStringDescription();
  AddDocStringArgs();
  AddDocStringInputs();
  AddDocStringAttrs();
  AddDocStringNameArg();
  AddOutputGlobals();  // Added to prelude_
  AddDocStringOutputs();
  strings::StrAppend(&result_, "  \"\"\"\n");

  strings::StrAppend(&result_,
                     "  _ctx = _context._context or _context.context()\n"
                     "  tld = _ctx._thread_local_data\n",
                     "  if tld.is_eager:", "\n");
  if (eager_not_allowed_error.empty()) {
    AddEagerFastPathExecute();
  } else {
    strings::StrAppend(&result_, "    ", eager_not_allowed_error);
  }

  // Handle graph-mode case
  string function_setup;
  if (!GetEagerFunctionSetup("  ", &function_setup)) {
    result_ = function_setup;
    return false;
  }
  HandleGraphMode(function_setup, output_sizes);

  AddRawOpExport(parameters);
  strings::StrAppend(&result_, "\n\n");
  return true;
}

bool GenEagerPythonOp::AddEagerFallbackCode(
    const string& parameters, const std::vector<string>& output_sizes,
    const string& num_outputs_expr, const string& eager_not_allowed_error,
    const std::unordered_map<string, string>& type_annotations) {
  AddDefLine(
      strings::StrCat(function_name_, kEagerFallbackSuffix),
      strings::StrCat(parameters, parameters.empty() ? "" : ", ", "ctx"));
  if (add_type_annotations_) {
    AddReturnTypeAnnotation(type_annotations);
  }
  if (!eager_not_allowed_error.empty()) {
    strings::StrAppend(&result_, "  ", eager_not_allowed_error);
    return true;
  }

  string function_setup;
  if (!GetEagerFunctionSetup("  ", &function_setup)) {
    result_ = function_setup;
    return false;
  }
  strings::StrAppend(&result_, function_setup);

  AddEagerInferredAttrs("  ");
  AddEagerInputCasts("  ");
  strings::StrAppend(
      &result_, "  _inputs_flat = ", FlattenInputs(nullptr, nullptr), "\n");
  AddEagerAttrs("  ");
  AddEagerExecute("  ", num_outputs_expr);

  AddEagerFunctionTeardown("  ", output_sizes,
                           true /* execute_record_gradient */);

  return true;
}

void GenEagerPythonOp::AddEagerFastPathExecute() {
  string fastpath_execute_params =
      strings::StrCat("_ctx, \"", op_def_.name(), "\", ", "name");
  string fallback_params;

  for (int i = 0; i < api_def_.in_arg_size(); i++) {
    const string param_name = param_names_[i].GetRenameTo();
    strings::StrAppend(&fastpath_execute_params, ", ", param_name);
    if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
    strings::StrAppend(&fallback_params, param_name);
  }

  for (const auto& attr : api_def_.attr()) {
    if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
      strings::StrAppend(&fastpath_execute_params, ", \"", attr.name(), "\", ",
                         attr.rename_to());

      if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
      strings::StrAppend(&fallback_params, attr.rename_to(), "=",
                         attr.rename_to());
    }
  }

  if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
  strings::StrAppend(&fallback_params, "name=name");

  strings::StrAppend(&result_, "    try:\n");
  strings::StrAppend(
      &result_, "      ", "_result = pywrap_tfe.TFE_Py_FastPathExecute(\n",
      WordWrap(strings::StrCat("        "),
               strings::StrCat(fastpath_execute_params, ")"), kRightMargin),
      "\n");

  if (op_def_.output_arg_size() > 1) {
    const string output_tuple_name = strings::StrCat(
        "_", python_op_gen_internal::AvoidPythonReserved(op_def_.name()),
        "Output");
    strings::StrAppend(&result_, "      ", "_result = ", output_tuple_name,
                       "._make(_result)\n");
  }
  strings::StrAppend(&result_, "      ", "return _result\n");

  // Handle fallback.
  if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
  strings::StrAppend(&fallback_params, "ctx=_ctx");

  // Any errors thrown from execute need to be unwrapped from
  // _NotOkStatusException.
  strings::StrAppend(&result_, "    ",
                     "except _core._NotOkStatusException as e:\n");
  strings::StrAppend(&result_, "      ",
                     "_ops.raise_from_not_ok_status(e, name)\n");

  strings::StrAppend(&result_, "    ", "except _core._FallbackException:\n");
  strings::StrAppend(&result_, "      pass\n");
  strings::StrAppend(&result_, "    try:\n");
  strings::StrAppend(
      &result_, "      ", "return ", function_name_, kEagerFallbackSuffix,
      "(\n",
      WordWrap(strings::StrCat("          "),
               strings::StrCat(fallback_params, ")"), kRightMargin),
      "\n");
  strings::StrAppend(&result_, "    except _core._SymbolicException:\n");
  strings::StrAppend(&result_,
                     "      pass  # Add nodes to the TensorFlow graph.\n");
  AddDispatch("    ");
}
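
// Illustrative output (not part of the original source; op "Relu" with one
// input "features" and no non-inferred attrs assumed):
//   try:
//     _result = pywrap_tfe.TFE_Py_FastPathExecute(
//       _ctx, "Relu", name, features)
//     return _result
//   except _core._NotOkStatusException as e:
//     _ops.raise_from_not_ok_status(e, name)
//   except _core._FallbackException:
//     pass
//   try:
//     return relu_eager_fallback(
//         features, name=name, ctx=_ctx)
//   except _core._SymbolicException:
//     pass  # Add nodes to the TensorFlow graph.
//   ...dispatch block emitted by AddDispatch("    ")...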

void GenEagerPythonOp::AddEagerInferredAttrs(const string& indentation) {
  // Figure out values for inferred attrs, and cast to eager tensors.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    const auto& api_def_attr(api_def_.attr(i));
    auto arg_list = attr_to_args_.find(attr.name());
    if (arg_list != attr_to_args_.end()) {
      if (attr.type() == "type") {
        std::vector<string> output_sizes;
        const string flattened =
            FlattenInputs(&arg_list->second, &output_sizes);
        string conversion = strings::StrCat("_execute.args_to_matching_eager(",
                                            flattened, ", ctx");

        strings::StrAppend(&conversion, ", [");
        for (int t : attr.allowed_values().list().type()) {
          DataType dtype = static_cast<DataType>(t);
          const string py_dtype =
              python_op_gen_internal::DataTypeToPython(dtype, "_dtypes.");
          strings::StrAppend(&conversion, py_dtype, ", ");
        }
        strings::StrAppend(&conversion, "]");

        if (attr.has_default_value()) {
          strings::StrAppend(
              &conversion, ", ",
              python_op_gen_internal::AttrValueToPython(
                  attr.type(), api_def_attr.default_value(), "_dtypes."));
        }
        strings::StrAppend(&conversion, ")");
        const string var_name = AttrVarName(attr.name(), &attr_expressions_);
        if (output_sizes.size() == 1) {
          // Avoid creating a temporary variable in the case where
          // we can easily assign to the right value directly.
          const string inputs_var =
              param_names_[arg_list->second.front()].GetRenameTo();
          if (output_sizes.front().empty()) {
            strings::StrAppend(&result_, indentation, var_name, ", (",
                               inputs_var, ",) = ", conversion, "\n");
          } else {
            strings::StrAppend(&result_, indentation, var_name, ", ",
                               inputs_var, " = ", conversion, "\n");
          }
        } else {
          const string inputs_var = strings::StrCat("_inputs_", attr.name());
          strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
                             " = ", conversion, "\n");
          // Convert from a flat list of eager tensors back to the
          // parameter variables.
          Unflatten(indentation, output_sizes, inputs_var, &result_);
          std::vector<string> p;
          for (int j : arg_list->second) {
            p.emplace_back(param_names_[j].GetRenameTo());
          }
          strings::StrAppend(&result_, indentation, VectorToTuple(p), " = ",
                             inputs_var, "\n");
        }
      } else if (attr.type() == "list(type)") {
        // NOTE: We ignore default values for these attrs, since it is
        // unclear how you would use it, and the one use case is
        // parse_single_sequence_example which only needs it for
        // backwards compatibility.
        const string var_name = AttrVarName(attr.name(), &attr_expressions_);
        string inputs_var;
        string conversion;
        if (arg_list->second.size() > 1) {
          // If you have more than one list(tensor) argument, their types
          // have to match.
          std::vector<string> lists;
          for (auto iter = arg_list->second.begin();
               iter != arg_list->second.end(); ++iter) {
            lists.push_back(param_names_[*iter].GetRenameTo());
          }
          inputs_var = VectorToTuple(lists);
          conversion = "_execute.args_to_mixed_eager_tensors";
        } else {
          // For one list(tensor) argument, we just convert every
          // element of the list to an eager tensor.
          inputs_var = param_names_[arg_list->second.front()].GetRenameTo();
          conversion = "_execute.convert_to_mixed_eager_tensors";
        }
        strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
                           " = ", conversion, "(", inputs_var, ", ctx)\n");
      }
    }
  }
}

void GenEagerPythonOp::AddEagerInputCasts(const string& indentation) {
  // Cast remaining args to eager tensors
  for (int i = 0; i < op_def_.input_arg_size(); ++i) {
    const auto& arg(op_def_.input_arg(i));
    if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue;
    const string& param = param_names_[i].GetRenameTo();
    const string fn = arg.number_attr().empty() ? "" : "n_";
    const string dtype =
        python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
    strings::StrAppend(&result_, indentation, param, " = _ops.convert_", fn,
                       "to_tensor(", param, ", ", dtype, ")\n");
  }
}
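
// Illustrative output (not part of the original source; arg name assumed):
// an input declared with a fixed dtype, e.g. "delta: float", becomes
//   delta = _ops.convert_to_tensor(delta, _dtypes.float32)
// while an input with a number_attr uses convert_n_to_tensor instead.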

void GenEagerPythonOp::AddEagerAttrs(const string& indentation) {
  // Compute eager attrs
  if (op_def_.attr_size() > 0) {
    string attr_values;
    for (int i = 0; i < op_def_.attr_size(); ++i) {
      if (i > 0) strings::StrAppend(&attr_values, ", ");
      const auto& attr_name(op_def_.attr(i).name());
      strings::StrAppend(&attr_values, "\"", attr_name, "\", ",
                         attr_expressions_[attr_name]);
    }
    strings::StrAppend(&attr_values, ")");
    strings::StrAppend(
        &result_,
        WordWrap(indentation, strings::StrCat("_attrs = (", attr_values),
                 kRightMargin),
        "\n");
  } else {
    strings::StrAppend(&result_, indentation, "_attrs = None\n");
  }
}

void GenEagerPythonOp::AddEagerExecute(const string& indentation,
                                       const string& num_outputs_expr) {
  const string return_prefix =
      strings::StrCat(indentation, "_result = _execute.execute(");
  const string return_args = strings::StrCat(
      "b\"", op_def_.name(), "\", ", num_outputs_expr,
      ", inputs=_inputs_flat, attrs=_attrs, ctx=ctx, name=name)");
  strings::StrAppend(&result_,
                     // Wrap the arguments, and indent to the (.
                     WordWrap(return_prefix, return_args, kRightMargin), "\n");
}
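
// Illustrative output (not part of the original source; op "Relu" with a
// single fixed output assumed), word-wrapped at kRightMargin, roughly:
//   _attrs = ("T", _attr_T)
//   _result = _execute.execute(b"Relu", 1, inputs=_inputs_flat,
//                              attrs=_attrs, ctx=ctx, name=name)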

void GenEagerPythonOp::AddDispatch(const string& prefix) {
  if (api_def_.visibility() != ApiDef::VISIBLE) return;

  strings::StrAppend(&result_, prefix, "except (TypeError, ValueError):\n");
  strings::StrAppend(&result_, prefix, "  result = _dispatch.dispatch(\n");
  AddBodyNoReturn(strings::StrCat(prefix, "        ", function_name_,
                                  ", "
                                  "(), dict("));
  strings::StrAppend(&result_, prefix, "      )\n");
  strings::StrAppend(&result_, prefix,
                     "  if result is not "
                     "_dispatch.OpDispatcher.NOT_SUPPORTED:\n");
  strings::StrAppend(&result_, prefix, "    return result\n");
  strings::StrAppend(&result_, prefix, "  raise\n");
}
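
// Illustrative output (not part of the original source; function name
// assumed, argument dict elided) with prefix = "  ":
//   except (TypeError, ValueError):
//     result = _dispatch.dispatch(
//           some_op, (), dict(...)
//         )
//     if result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
//       return result
//     raise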

void GenEagerPythonOp::AddRawOpExport(const string& parameters) {
  // Example:
  //
  // Identity = tf_export("raw_ops.Identity")(_ops.to_raw_op(identity))
  const string raw_function_name =
      python_op_gen_internal::AvoidPythonReserved(op_def_.name());
  strings::StrAppend(&result_, raw_function_name, " = tf_export(\"raw_ops.",
                     raw_function_name, "\")", "(_ops.to_raw_op(",
                     function_name_, "))\n");
}

string GetPythonOpsImpl(
    const OpList& ops, const ApiDefMap& api_defs,
    const std::vector<string>& hidden_ops, const string& source_file_name = "",
    const std::unordered_set<string> type_annotate_ops = {}) {
  string result;
  // Header
  // TODO(josh11b): Mention the library for which wrappers are being generated.
  strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops.

This file is MACHINE GENERATED! Do not edit.
)");

  // Mention the original source file so someone tracing back through
  // generated Python code will know where to look next.
  if (!source_file_name.empty()) {
    strings::StrAppend(&result, "Original C++ source file: ");
    strings::StrAppend(&result, source_file_name);
    strings::StrAppend(&result, "\n");
  }

  strings::StrAppend(&result, R"("""

import collections

from tensorflow.python import pywrap_tfe as pywrap_tfe
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.eager import execute as _execute
from tensorflow.python.framework import dtypes as _dtypes

from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.deprecation import deprecated_endpoints
from tensorflow.python.util import dispatch as _dispatch
from tensorflow.python.util.tf_export import tf_export

from typing import TypeVar
)");

  for (const auto& op_def : ops.op()) {
    const auto* api_def = api_defs.GetApiDef(op_def.name());

    if (api_def->visibility() == ApiDef::SKIP) {
      continue;
    }
    // An op is hidden if either its ApiDef visibility is HIDDEN
    // or it is in the hidden_ops list.
    bool is_hidden = api_def->visibility() == ApiDef::HIDDEN;
    bool hidden_by_api_def = is_hidden;
    if (!is_hidden) {
      for (const string& hidden : hidden_ops) {
        if (op_def.name() == hidden) {
          is_hidden = true;
          break;
        }
      }
    }

    string function_name;
    python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(),
                                                    &function_name);
    bool is_reserved = python_op_gen_internal::IsPythonReserved(function_name);

    // Prefix an op with underscore if the op is listed in hidden_ops, or its
    // name is reserved, or it is one of the exceptions in
    // IsOpWithUnderscorePrefix. Do not add underscores to ops set to HIDDEN
    // in ApiDef otherwise.
    // TODO(annarev): don't prefix with underscores even if op is in hidden_ops.
    if (is_hidden) {
      if (!hidden_by_api_def || is_reserved ||
          python_op_gen_internal::IsOpWithUnderscorePrefix(function_name)) {
        function_name = strings::StrCat("_", function_name);
      }
    } else if (is_reserved) {
      // When users create custom python wrappers, they may link in the
      // default op registry by accident, and because they can't
      // enumerate all 'hidden' symbols, this guard is to prevent
      // instantiating a python reserved word in their wrapper.
      continue;
    }

    auto iter = type_annotate_ops.find(op_def.name());
    bool add_type_annotations = iter != type_annotate_ops.end();

    strings::StrAppend(&result,
                       GetEagerPythonOp(op_def, *api_def, function_name,
                                        add_type_annotations));
  }

  return result;
}

}  // namespace

string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
                    const std::vector<string>& hidden_ops,
                    const string& source_file_name,
                    const std::unordered_set<string> type_annotate_ops) {
  return GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name,
                          type_annotate_ops);
}

void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
                    const std::vector<string>& hidden_ops,
                    const string& source_file_name,
                    const std::unordered_set<string> type_annotate_ops) {
  printf("%s", GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name,
                                type_annotate_ops)
                   .c_str());
}

string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
  OpList ops;
  ops.ParseFromArray(op_list_buf, op_list_len);

  ApiDefMap api_def_map(ops);
  return GetPythonOpsImpl(ops, api_def_map, {});
}

string GetArgAnnotation(
    const OpDef::ArgDef& arg,
    const std::unordered_map<string, string>& type_annotations) {
  if (!arg.type_attr().empty()) {
    // Get the correct TypeVar if arg maps to an attr
    return "_ops.Tensor[" + type_annotations.at(arg.type_attr()) + "]";
  } else {
    // Get the dtype of the Tensor
    const string py_dtype =
        python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
    return "_ops.Tensor[" + dtype_type.at(py_dtype) + "]";
  }

  return "Any";
}
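
// Illustrative examples (not part of the original source; op and attr names
// assumed): an arg typed via a "type" attr T on op "Identity" yields
// "_ops.Tensor[TV_Identity_T]", while an arg with a fixed dtype such as
// DT_INT32 yields "_ops.Tensor[_dtypes.Int32]".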

}  // namespace tensorflow