• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/python/framework/python_op_gen.h"
16 
17 #include <stdio.h>
18 #include <sstream>
19 #include <unordered_map>
20 #include "tensorflow/core/framework/api_def.pb.h"
21 #include "tensorflow/core/framework/attr_value.pb.h"
22 #include "tensorflow/core/framework/op.h"
23 #include "tensorflow/core/framework/op_def.pb_text.h"
24 #include "tensorflow/core/framework/op_def.pb.h"
25 #include "tensorflow/core/framework/op_def_util.h"
26 #include "tensorflow/core/framework/op_gen_lib.h"
27 #include "tensorflow/core/framework/tensor.pb_text.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/framework/types.pb.h"
30 #include "tensorflow/core/lib/gtl/map_util.h"
31 #include "tensorflow/core/lib/gtl/stl_util.h"
32 #include "tensorflow/core/lib/strings/str_util.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/core/lib/strings/stringprintf.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/macros.h"
37 #include "tensorflow/core/platform/types.h"
38 #include "tensorflow/python/framework/python_op_gen_internal.h"
39 
40 namespace tensorflow {
41 namespace {
42 
// Line width passed to WordWrap() when wrapping generated Python code.
constexpr int kRightMargin = 78;

// Appended to an op's Python function name to name its eager slow-path
// ("fallback") function.
constexpr char kEagerFallbackSuffix[] = "_eager_fallback";
46 
// Returns the Python local-variable name ("_attr_<attr_name>") used for an
// inferred attr. When `attr_expressions` is non-null, also records that name
// as the expression for `attr_name` (overwriting any previous entry).
string AttrVarName(const string& attr_name,
                   std::unordered_map<string, string>* attr_expressions) {
  string var_name = strings::StrCat("_attr_", attr_name);
  if (attr_expressions == nullptr) return var_name;
  (*attr_expressions)[attr_name] = var_name;
  return var_name;
}
53 
AddInferredAttr(const string & indentation,const string & attr_name,const string & value_expression,string * result,std::unordered_map<string,string> * attr_expressions)54 void AddInferredAttr(const string& indentation, const string& attr_name,
55                      const string& value_expression, string* result,
56                      std::unordered_map<string, string>* attr_expressions) {
57   strings::StrAppend(result, indentation,
58                      AttrVarName(attr_name, attr_expressions), " = ",
59                      value_expression, "\n");
60 }
61 
VectorToTuple(const std::vector<string> & l)62 string VectorToTuple(const std::vector<string>& l) {
63   if (l.size() == 1) return strings::StrCat("(", l.front(), ",)");
64   string ret = "(";
65   for (int i = 0; i < l.size(); ++i) {
66     if (i > 0) {
67       strings::StrAppend(&ret, ", ");
68     }
69     strings::StrAppend(&ret, l[i]);
70   }
71   strings::StrAppend(&ret, ")");
72   return ret;
73 }
74 
Unflatten(const string & prefix,const std::vector<string> & output_sizes,const string & var,string * result)75 void Unflatten(const string& prefix, const std::vector<string>& output_sizes,
76                const string& var, string* result) {
77   for (int i = 0; i < output_sizes.size(); ++i) {
78     if (!output_sizes[i].empty()) {
79       strings::StrAppend(result, prefix, var, " = ");
80       if (i > 0) strings::StrAppend(result, var, "[:", i, "] + ");
81       if (i + 1 < output_sizes.size()) {
82         // Special case i == 0 to avoid "0 +" in the generated code.
83         if (i == 0) {
84           strings::StrAppend(result, "[", var, "[:", output_sizes[i], "]] + ",
85                              var, "[", output_sizes[i], ":]");
86         } else {
87           strings::StrAppend(result, "[", var, "[", i, ":", i, " + ",
88                              output_sizes[i], "]] + ", var, "[", i, " + ",
89                              output_sizes[i], ":]");
90         }
91       } else {
92         strings::StrAppend(result, "[", var, "[", i, ":]]");
93       }
94       strings::StrAppend(result, "\n");
95     }
96   }
97 }
98 
TensorPBString(const TensorProto & pb)99 string TensorPBString(const TensorProto& pb) {
100   // Note: This gets used in the argument list, and so must survive naive
101   // word wrapping.
102   return strings::StrCat("\"\"\"", ProtoShortDebugString(pb), "\"\"\"");
103 }
104 
105 class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
106  public:
GenEagerPythonOp(const OpDef & op_def,const ApiDef & api_def,const string & function_name)107   GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
108                    const string& function_name)
109       : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name) {
110     op_name_ = function_name_;
111     str_util::ConsumePrefix(&op_name_, "_");
112   }
~GenEagerPythonOp()113   ~GenEagerPythonOp() override {}
114 
115   string Code() override;
116 
117  protected:
118   void HandleGraphMode(const string& function_setup);
119 
120   string GetEagerNotAllowedError();
121   void ExpectListArg(const string& indentation, const string& arg_name,
122                      string* output);
123   bool GetEagerFunctionSetup(const string& indentation, string* function_setup);
124   void GetOutputSizesAndNumOutputsExpr(std::vector<string>* output_sizes,
125                                        string* num_outputs_expr);
126 
127   void AddEagerFunctionTeardown(const string& indentation,
128                                 const std::vector<string>& output_sizes,
129                                 bool execute_record_gradient);
130 
131   bool AddEagerFastPathAndGraphCode(const string& parameters,
132                                     const std::vector<string>& output_sizes,
133                                     const string& eager_not_allowed_error);
134   bool AddEagerFallbackCode(const string& parameters,
135                             const std::vector<string>& output_sizes,
136                             const string& num_outputs_expr,
137                             const string& eager_not_allowed_error);
138   void AddEagerFastPathExecute();
139 
140   void AddEagerInferredAttrs(const string& indentation);
141   void AddEagerInputCasts(const string& indentation);
142   void AddEagerAttrs(const string& indentation);
143   void AddEagerExecute(const string& indentation,
144                        const string& num_outputs_expr);
145   void AddDispatch(const string& prefix);
146 
147   void AddRawOpExport(const string& parameters);
148 
AddAttrForArg(const string & attr,int arg_index)149   void AddAttrForArg(const string& attr, int arg_index) {
150     gtl::InsertIfNotPresent(&inferred_attrs_, attr,
151                             op_def_.input_arg(arg_index).name());
152     auto iter = attr_to_args_.find(attr);
153     if (iter == attr_to_args_.end()) {
154       attr_to_args_.insert(AttrToArgMap::value_type(attr, {arg_index}));
155     } else {
156       iter->second.push_back(arg_index);
157     }
158   }
159 
160   // Returns a string expression representing a flattened list of all
161   // the inputs given by `*input_indices` (or all inputs if
162   // `input_indices` is nullptr).  `*output_sizes` can be used to unflatten.
163   string FlattenInputs(const std::vector<int>* input_indices,
164                        std::vector<string>* output_sizes) const;
165 
166   StringPiece op_name_;
167   typedef std::unordered_map<string, std::vector<int>> AttrToArgMap;
168   AttrToArgMap attr_to_args_;
169   std::unordered_map<string, string> attr_expressions_;
170   // This has all the input args followed by those attrs that don't have
171   // defaults.
172   std::vector<python_op_gen_internal::ParamNames> params_no_default_;
173   // The parameters with defaults (these have to be listed after those without).
174   // No input args are included, just attrs.
175   std::vector<std::pair<python_op_gen_internal::ParamNames, string>>
176       params_with_default_;
177 };
178 
// Returns the generated Python source for one op (see GenEagerPythonOp).
string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
                        const string& function_name) {
  GenEagerPythonOp generator(op_def, api_def, function_name);
  return generator.Code();
}
183 
FlattenInputs(const std::vector<int> * input_indices,std::vector<string> * output_sizes) const184 string GenEagerPythonOp::FlattenInputs(
185     const std::vector<int>* input_indices,
186     std::vector<string>* output_sizes) const {
187   string inputs;
188   enum { STARTING, WAS_LIST_INPUT, WAS_SOLO_INPUT } inputs_state = STARTING;
189   const int n = input_indices != nullptr ? input_indices->size()
190                                          : op_def_.input_arg_size();
191   for (int j = 0; j < n; ++j) {
192     const int i = input_indices ? (*input_indices)[j] : j;
193     const auto& arg(op_def_.input_arg(i));
194     const bool is_list =
195         !arg.type_list_attr().empty() || !arg.number_attr().empty();
196     if (is_list) {
197       if (inputs_state == WAS_SOLO_INPUT) {
198         strings::StrAppend(&inputs, "] + ");
199       } else if (inputs_state == WAS_LIST_INPUT) {
200         strings::StrAppend(&inputs, " + ");
201       }
202       strings::StrAppend(&inputs, "list(", param_names_[i].GetRenameTo(), ")");
203       inputs_state = WAS_LIST_INPUT;
204       if (output_sizes != nullptr) {
205         if (!arg.number_attr().empty()) {
206           output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr));
207         } else {
208           output_sizes->emplace_back(
209               strings::StrCat("len(", param_names_[i].GetRenameTo(), ")"));
210         }
211       }
212     } else {
213       if (inputs_state == WAS_SOLO_INPUT) {
214         strings::StrAppend(&inputs, ", ");
215       } else if (inputs_state == WAS_LIST_INPUT) {
216         strings::StrAppend(&inputs, " + [");
217       } else {
218         strings::StrAppend(&inputs, "[");
219       }
220       strings::StrAppend(&inputs, param_names_[i].GetRenameTo());
221       inputs_state = WAS_SOLO_INPUT;
222       if (output_sizes != nullptr) output_sizes->emplace_back();
223     }
224   }
225   if (inputs_state == STARTING) return "[]";
226   if (inputs_state == WAS_SOLO_INPUT) {
227     strings::StrAppend(&inputs, "]");
228   }
229   return inputs;
230 }
231 
Code()232 string GenEagerPythonOp::Code() {
233   if (api_def_.visibility() == ApiDef::SKIP) {
234     return "";
235   }
236 
237   for (int i = 0; i < api_def_.arg_order_size(); ++i) {
238     const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
239     const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
240     params_no_default_.emplace_back(api_def_arg.name(),
241                                     api_def_arg.rename_to());
242     if (!arg.type_attr().empty()) {
243       AddAttrForArg(arg.type_attr(), i);
244     } else if (!arg.type_list_attr().empty()) {
245       AddAttrForArg(arg.type_list_attr(), i);
246     }
247     if (!arg.number_attr().empty()) {
248       AddAttrForArg(arg.number_attr(), i);
249     }
250   }
251   for (int i = 0; i < op_def_.attr_size(); ++i) {
252     const auto& attr(op_def_.attr(i));
253     const auto& api_def_attr(api_def_.attr(i));
254     // Do not add inferred attrs to the Python function signature.
255     if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
256       if (api_def_attr.has_default_value()) {
257         if (attr.type() == "tensor") {
258           params_with_default_.emplace_back(
259               python_op_gen_internal::ParamNames(api_def_attr.name(),
260                                                  api_def_attr.rename_to()),
261               strings::StrCat(
262                   "_execute.make_tensor(",
263                   TensorPBString(api_def_attr.default_value().tensor()), ", \"",
264                   api_def_attr.rename_to(), "\")"));
265         } else if (attr.type() == "list(tensor)") {
266           std::vector<string> pbtxt;
267           for (const auto& pb : api_def_attr.default_value().list().tensor()) {
268             pbtxt.emplace_back(TensorPBString(pb));
269           }
270           params_with_default_.emplace_back(
271               python_op_gen_internal::ParamNames(api_def_attr.name(),
272                                                  api_def_attr.rename_to()),
273               strings::StrCat("[_execute.make_tensor(_pb, \"",
274                               api_def_attr.rename_to(), "\") for _pb in ",
275                               VectorToTuple(pbtxt), "]"));
276         } else {
277           params_with_default_.emplace_back(
278               python_op_gen_internal::ParamNames(api_def_attr.name(),
279                                                  api_def_attr.rename_to()),
280               python_op_gen_internal::AttrValueToPython(
281                   attr.type(), api_def_attr.default_value(), "_dtypes."));
282         }
283       } else {
284         params_no_default_.emplace_back(api_def_attr.name(),
285                                         api_def_attr.rename_to());
286       }
287     }
288   }
289 
290   // Save the list of attr parameters (attrs that won't be inferred),
291   // those with defaults go at the end.
292   // Get the attrs in the order we want by taking the attrs without defaults
293   // from the end of params_no_default_, and adding params_no_default_.
294   attrs_.reserve(params_no_default_.size() - op_def_.input_arg_size() +
295                  params_with_default_.size());
296   for (int i = op_def_.input_arg_size(); i < params_no_default_.size(); ++i) {
297     attrs_.push_back(params_no_default_[i].GetName());
298   }
299   for (const auto& p : params_with_default_) {
300     attrs_.push_back(p.first.GetName());
301   }
302 
303   // TODO(slebedev): call AvoidPythonReserved on each param?
304   param_names_.reserve(params_no_default_.size() + params_with_default_.size());
305   param_names_.insert(param_names_.begin(), params_no_default_.begin(),
306                       params_no_default_.end());
307   for (const auto& param_and_default : params_with_default_) {
308     param_names_.push_back(param_and_default.first);
309   }
310 
311   string parameters;
312   for (const auto& param : params_no_default_) {
313     if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
314     strings::StrAppend(&parameters, param.GetRenameTo());
315   }
316   for (const auto& param_and_default : params_with_default_) {
317     if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
318     strings::StrAppend(&parameters, param_and_default.first.GetRenameTo(), "=",
319                        param_and_default.second);
320   }
321   strings::StrAppend(&parameters, parameters.empty() ? "" : ", ", "name=None");
322 
323   // Add attr_expressions_ for attrs that are params.
324   for (int i = 0; i < attrs_.size(); ++i) {
325     const string& attr_name = attrs_[i];
326     const string& attr_api_name =
327         param_names_[i + op_def_.input_arg_size()].GetRenameTo();
328     attr_expressions_[attr_name] = attr_api_name;
329   }
330   // Add attr_expressions_ for attrs that are inferred.
331   for (int i = 0; i < op_def_.attr_size(); ++i) {
332     const auto& attr(op_def_.attr(i));
333     if (attr.type() == "int") {
334       auto arg_list = attr_to_args_.find(attr.name());
335       if (arg_list != attr_to_args_.end()) {
336         AttrVarName(attr.name(), &attr_expressions_);
337       }
338     }
339   }
340 
341   string num_outputs_expr;
342   std::vector<string> output_sizes(num_outs_);
343   GetOutputSizesAndNumOutputsExpr(&output_sizes, &num_outputs_expr);
344 
345   string eager_not_allowed_error = GetEagerNotAllowedError();
346 
347   if (!AddEagerFastPathAndGraphCode(parameters, output_sizes,
348                                     eager_not_allowed_error)) {
349     return result_;
350   }
351 
352   if (!AddEagerFallbackCode(parameters, output_sizes, num_outputs_expr,
353                             eager_not_allowed_error)) {
354     return result_;
355   }
356 
357   return prelude_ + result_;
358 }
359 
HandleGraphMode(const string & function_setup)360 void GenEagerPythonOp::HandleGraphMode(const string& function_setup) {
361   strings::StrAppend(&result_, "  # Add nodes to the TensorFlow graph.\n");
362   strings::StrAppend(&result_, function_setup);
363   if (api_def_.visibility() == ApiDef::VISIBLE) {
364     strings::StrAppend(&result_, "  try:\n  ");
365   }
366   strings::StrAppend(&result_, "  _, _, _op = _op_def_lib._apply_op_helper(\n");
367   AddBodyNoReturn(strings::StrCat("        \"", op_def_.name(), "\", "));
368   AddDispatch("  ");
369 
370   if (num_outs_ > 0) {
371     strings::StrAppend(&result_, "  _result = _op.outputs[:]\n");
372     // Special case handling for stateful op with single list output
373     // that might be empty.
374     if (num_outs_ == 1 && op_def_.is_stateful() &&
375         (!op_def_.output_arg(0).number_attr().empty() ||
376          !op_def_.output_arg(0).type_list_attr().empty())) {
377       // TODO(josh11b): Can skip this if the number_attr/type_list_attr has
378       // a constraint indicating that this can never be empty.
379       strings::StrAppend(&result_,
380                          "  if not _result:\n"
381                          "    return _op\n");
382     }
383     strings::StrAppend(&result_, "  _inputs_flat = _op.inputs\n");
384 
385     // Compute graph-mode attrs.
386     if (op_def_.attr_size() > 0) {
387       string attr_values;
388       for (int i = 0; i < op_def_.attr_size(); ++i) {
389         if (i > 0) strings::StrAppend(&attr_values, ", ");
390         const auto& attr_name(op_def_.attr(i).name());
391         strings::StrAppend(&attr_values, "\"", attr_name, "\", _op.get_attr(\"",
392                            attr_name, "\")");
393       }
394       strings::StrAppend(&attr_values, ")");
395       strings::StrAppend(
396           &result_, WordWrap("  _attrs = (", attr_values, kRightMargin), "\n");
397     } else {
398       strings::StrAppend(&result_, "  _attrs = None\n");
399     }
400   } else {
401     strings::StrAppend(&result_, "  return _op\n");
402   }
403 }
404 
GetEagerNotAllowedError()405 string GenEagerPythonOp::GetEagerNotAllowedError() {
406   bool eager_allowed = true;
407   string ref_arg;
408   for (int i = 0; i < op_def_.input_arg_size(); ++i) {
409     const auto& arg = op_def_.input_arg(i);
410     if (arg.is_ref()) {
411       eager_allowed = false;
412       DCHECK_EQ(op_def_.input_arg(i).name(), api_def_.in_arg(i).name());
413       ref_arg = api_def_.in_arg(i).rename_to();
414     }
415   }
416   for (int i = 0; i < op_def_.output_arg_size(); ++i) {
417     const auto& arg = op_def_.output_arg(i);
418     if (arg.is_ref()) {
419       eager_allowed = false;
420       DCHECK_EQ(op_def_.output_arg(i).name(), api_def_.out_arg(i).name());
421       ref_arg = api_def_.out_arg(i).rename_to();
422     }
423   }
424 
425   if (eager_allowed) return "";
426 
427   return strings::StrCat("raise RuntimeError(\"", op_name_,
428                          " op does not support eager execution. ", "Arg '",
429                          ref_arg, "' is a ref.\")\n");
430 }
431 
ExpectListArg(const string & indentation,const string & arg_name,string * output)432 void GenEagerPythonOp::ExpectListArg(const string& indentation,
433                                      const string& arg_name, string* output) {
434   strings::StrAppend(output, indentation, "if not isinstance(", arg_name,
435                      ", (list, tuple)):\n", indentation, "  raise TypeError(\n",
436                      indentation, "      \"Expected list for '", arg_name,
437                      "' argument to \"\n", indentation, "      \"'", op_name_,
438                      "' Op, not %r.\" % ", arg_name, ")\n");
439 }
440 
GetEagerFunctionSetup(const string & indentation,string * function_setup)441 bool GenEagerPythonOp::GetEagerFunctionSetup(const string& indentation,
442                                              string* function_setup) {
443   // Validate list inputs, infer length attrs.
444   for (int i = 0; i < op_def_.attr_size(); ++i) {
445     const auto& attr(op_def_.attr(i));
446     if (attr.type() == "int") {
447       auto arg_list = attr_to_args_.find(attr.name());
448       if (arg_list != attr_to_args_.end()) {
449         // Inferred int attrs are the lengths of inputs. Validate those
450         // inputs are lists and have the same length.
451         for (auto iter = arg_list->second.begin();
452              iter != arg_list->second.end(); ++iter) {
453           const string& arg_api_name = param_names_[*iter].GetRenameTo();
454           ExpectListArg(indentation, arg_api_name, function_setup);
455           if (iter == arg_list->second.begin()) {
456             AddInferredAttr(indentation, attr.name(),
457                             strings::StrCat("len(", arg_api_name, ")"),
458                             function_setup, &attr_expressions_);
459           } else {
460             const auto& attr_var = attr_expressions_[attr.name()];
461             strings::StrAppend(
462                 function_setup, indentation, "if len(", arg_api_name,
463                 ") != ", attr_var, ":\n", indentation, "  raise ValueError(\n",
464                 indentation, "      \"List argument '", arg_api_name, "' to '",
465                 op_name_, "' Op with length %d \"\n", indentation,
466                 "      \"must match length %d of argument '",
467                 inferred_attrs_[attr.name()], "'.\" %\n", indentation,
468                 "      (len(", arg_api_name, "), ", attr_var, "))\n");
469           }
470         }
471       }
472     }
473   }
474 
475   for (int i = 0; i < attrs_.size(); ++i) {
476     const string& attr_name = attrs_[i];
477     const auto& param = param_names_[i + op_def_.input_arg_size()];
478     const auto& attr = *FindAttr(attr_name, op_def_);
479     const string& attr_api_name = param.GetRenameTo();
480     StringPiece attr_type = attr.type();
481     attr_expressions_[attr_name] = attr_api_name;
482     const int default_index = i - (attrs_.size() - params_with_default_.size());
483     if (default_index >= 0) {
484       const string& default_value = params_with_default_[default_index].second;
485       strings::StrAppend(function_setup, indentation, "if ", attr_api_name,
486                          " is None:\n");
487       strings::StrAppend(function_setup, indentation, "  ", attr_api_name,
488                          " = ", default_value, "\n");
489     }
490     if (str_util::StartsWith(attr_type, "list(")) {
491       ExpectListArg(indentation, attr_api_name, function_setup);
492     }
493 
494     if (attr_type == "string") {
495       strings::StrAppend(function_setup, indentation, attr_api_name,
496                          " = _execute.make_str(", attr_api_name, ", \"",
497                          attr_api_name, "\")\n");
498     } else if (attr_type == "list(string)") {
499       strings::StrAppend(function_setup, indentation, attr_api_name,
500                          " = [_execute.make_str(_s, \"", attr_api_name,
501                          "\") for _s in ", attr_api_name, "]\n");
502     } else if (attr_type == "int") {
503       strings::StrAppend(function_setup, indentation, attr_api_name,
504                          " = _execute.make_int(", attr_api_name, ", \"",
505                          attr_api_name, "\")\n");
506     } else if (attr_type == "list(int)") {
507       strings::StrAppend(function_setup, indentation, attr_api_name,
508                          " = [_execute.make_int(_i, \"", attr_api_name,
509                          "\") for _i in ", attr_api_name, "]\n");
510     } else if (attr_type == "float") {
511       strings::StrAppend(function_setup, indentation, attr_api_name,
512                          " = _execute.make_float(", attr_api_name, ", \"",
513                          attr_api_name, "\")\n");
514     } else if (attr_type == "list(float)") {
515       strings::StrAppend(function_setup, indentation, attr_api_name,
516                          " = [_execute.make_float(_f, \"", attr_api_name,
517                          "\") for _f in ", attr_api_name, "]\n");
518     } else if (attr_type == "bool") {
519       strings::StrAppend(function_setup, indentation, attr_api_name,
520                          " = _execute.make_bool(", attr_api_name, ", \"",
521                          attr_api_name, "\")\n");
522     } else if (attr_type == "list(bool)") {
523       strings::StrAppend(function_setup, indentation, attr_api_name,
524                          " = [_execute.make_bool(_b, \"", attr_api_name,
525                          "\") for _b in ", attr_api_name, "]\n");
526     } else if (attr_type == "type") {
527       strings::StrAppend(function_setup, indentation, attr_api_name,
528                          " = _execute.make_type(", attr_api_name, ", \"",
529                          attr_api_name, "\")\n");
530     } else if (attr_type == "list(type)") {
531       strings::StrAppend(function_setup, indentation, attr_api_name,
532                          " = [_execute.make_type(_t, \"", attr_api_name,
533                          "\") for _t in ", attr_api_name, "]\n");
534     } else if (attr_type == "shape") {
535       strings::StrAppend(function_setup, indentation, attr_api_name,
536                          " = _execute.make_shape(", attr_api_name, ", \"",
537                          attr_api_name, "\")\n");
538     } else if (attr_type == "list(shape)") {
539       strings::StrAppend(function_setup, indentation, attr_api_name,
540                          " = [_execute.make_shape(_s, \"", attr_api_name,
541                          "\") for _s in ", attr_api_name, "]\n");
542     } else if (attr_type == "tensor") {
543       strings::StrAppend(function_setup, indentation, attr_api_name,
544                          " = _execute.make_tensor(", attr_api_name, ", \"",
545                          attr_api_name, "\")\n");
546     } else if (attr_type == "list(tensor)") {
547       strings::StrAppend(function_setup, indentation, attr_api_name,
548                          " = [_execute.make_tensor(_t, \"", attr_api_name,
549                          "\") for _t in ", attr_api_name, "]\n");
550     } else if (attr_type != "func" && attr_type != "list(func)") {
551       *function_setup =
552           strings::StrCat("# No definition for ", function_name_,
553                           " since we don't support attrs with type\n"
554                           "# '",
555                           attr_type, "' right now.\n\n");
556       return false;
557     }
558   }
559   return true;
560 }
561 
562 // If output i is list output, output_sizes[i] will be set to a
563 // string with the python expression that will evaluate to its
564 // length. output_sizes[i] is empty for non-list outputs.
GetOutputSizesAndNumOutputsExpr(std::vector<string> * output_sizes,string * num_outputs_expr)565 void GenEagerPythonOp::GetOutputSizesAndNumOutputsExpr(
566     std::vector<string>* output_sizes, string* num_outputs_expr) {
567   // Expression representing the number of outputs.
568   int num_fixed_outputs = 0;
569   for (int i = 0; i < num_outs_; ++i) {
570     const auto& arg(op_def_.output_arg(i));
571     if (!arg.number_attr().empty()) {
572       if (!num_outputs_expr->empty()) {
573         strings::StrAppend(num_outputs_expr, " + ");
574       }
575       (*output_sizes)[i] = attr_expressions_[arg.number_attr()];
576       strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
577     } else if (!arg.type_list_attr().empty()) {
578       if (!num_outputs_expr->empty()) {
579         strings::StrAppend(num_outputs_expr, " + ");
580       }
581       // Have to be careful to use an expression that works in both
582       // graph and eager paths here.
583       const auto iter = inferred_attrs_.find(arg.type_list_attr());
584       if (iter == inferred_attrs_.end()) {
585         (*output_sizes)[i] = strings::StrCat(
586             "len(", attr_expressions_[arg.type_list_attr()], ")");
587       } else {
588         (*output_sizes)[i] = strings::StrCat("len(", iter->second, ")");
589       }
590       strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
591     } else {
592       ++num_fixed_outputs;
593     }
594   }
595   if (num_fixed_outputs > 0) {
596     if (!num_outputs_expr->empty()) {
597       strings::StrAppend(num_outputs_expr, " + ");
598     }
599     strings::StrAppend(num_outputs_expr, num_fixed_outputs);
600   } else if (num_outputs_expr->empty()) {
601     *num_outputs_expr = "0";
602   }
603 }
604 
AddEagerFunctionTeardown(const string & indentation,const std::vector<string> & output_sizes,bool execute_record_gradient)605 void GenEagerPythonOp::AddEagerFunctionTeardown(
606     const string& indentation, const std::vector<string>& output_sizes,
607     bool execute_record_gradient) {
608   if (num_outs_ > 0) {
609     if (execute_record_gradient) {
610       strings::StrAppend(&result_, indentation, "_execute.record_gradient(\n",
611                          "      \"", op_def_.name(),
612                          "\", _inputs_flat, _attrs, _result, name)\n");
613     }
614     if (num_outs_ == 1 && !output_sizes[0].empty()) {
615       // Single list result.
616     } else if (num_outs_ == 1) {
617       // Execute returns a single-element list which we need to destructure.
618       strings::StrAppend(&result_, indentation, "_result, = _result\n");
619     } else {
620       // Have multiple outputs, so we will need to reformat the return
621       // value of execute() to be a list with one entry per op output
622       // (that entry will be a list of tensors if that output is of list
623       // type).
624       // For list outputs, convert the right subrange of _result into a list.
625       Unflatten(indentation, output_sizes, "_result", &result_);
626       // Convert to a named tuple.
627       strings::StrAppend(&result_, indentation, "_result = _", op_def_.name(),
628                          "Output._make(_result)\n");
629     }
630   } else {
631     strings::StrAppend(&result_, indentation, "_result = None\n");
632   }
633   strings::StrAppend(&result_, indentation, "return _result\n\n");
634 }
635 
AddEagerFastPathAndGraphCode(const string & parameters,const std::vector<string> & output_sizes,const string & eager_not_allowed_error)636 bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
637     const string& parameters, const std::vector<string>& output_sizes,
638     const string& eager_not_allowed_error) {
639   if (api_def_.visibility() == ApiDef::VISIBLE) {
640     strings::StrAppend(&result_, "@_dispatch.add_dispatch_list\n");
641   }
642 
643   AddExport();
644   AddDefLine(function_name_, parameters);
645   AddDocStringDescription();
646   AddDocStringArgs();
647   AddDocStringInputs();
648   AddDocStringAttrs();
649   AddDocStringNameArg();
650   AddOutputGlobals();  // Added to prelude_
651   AddDocStringOutputs();
652   strings::StrAppend(&result_, "  \"\"\"\n");
653 
654   strings::StrAppend(
655       &result_,
656       "  _ctx = _context._context or _context.context()\n"
657       "  if _ctx is not None and _ctx._thread_local_data.is_eager:",
658       "\n");
659   if (eager_not_allowed_error.empty()) {
660     AddEagerFastPathExecute();
661   } else {
662     strings::StrAppend(&result_, "    ", eager_not_allowed_error);
663   }
664 
665   // Handle graph-mode case
666   string function_setup;
667   if (!GetEagerFunctionSetup("  ", &function_setup)) {
668     result_ = function_setup;
669     return false;
670   }
671   HandleGraphMode(function_setup);
672   AddEagerFunctionTeardown("  ", output_sizes,
673                            true /* execute_record_gradient */);
674 
675   AddRawOpExport(parameters);
676   strings::StrAppend(&result_, "\n\n");
677   return true;
678 }
679 
AddEagerFallbackCode(const string & parameters,const std::vector<string> & output_sizes,const string & num_outputs_expr,const string & eager_not_allowed_error)680 bool GenEagerPythonOp::AddEagerFallbackCode(
681     const string& parameters, const std::vector<string>& output_sizes,
682     const string& num_outputs_expr, const string& eager_not_allowed_error) {
683   AddDefLine(
684       strings::StrCat(function_name_, kEagerFallbackSuffix),
685       strings::StrCat(parameters, parameters.empty() ? "" : ", ", "ctx=None"));
686 
687   if (!eager_not_allowed_error.empty()) {
688     strings::StrAppend(&result_, "  ", eager_not_allowed_error);
689     return true;
690   }
691 
692   strings::StrAppend(
693       &result_, "  r\"\"\"This is the slowpath function for Eager mode.\n");
694   strings::StrAppend(&result_, "  This is for function ", function_name_,
695                      "\n  \"\"\"\n");
696 
697   strings::StrAppend(&result_, "  _ctx = ctx if ctx else _context.context()\n");
698 
699   string function_setup;
700   if (!GetEagerFunctionSetup("  ", &function_setup)) {
701     result_ = function_setup;
702     return false;
703   }
704   strings::StrAppend(&result_, function_setup);
705 
706   AddEagerInferredAttrs("  ");
707   AddEagerInputCasts("  ");
708   strings::StrAppend(
709       &result_, "  _inputs_flat = ", FlattenInputs(nullptr, nullptr), "\n");
710   AddEagerAttrs("  ");
711   AddEagerExecute("  ", num_outputs_expr);
712 
713   AddEagerFunctionTeardown("  ", output_sizes,
714                            true /* execute_record_gradient */);
715 
716   return true;
717 }
718 
AddEagerFastPathExecute()719 void GenEagerPythonOp::AddEagerFastPathExecute() {
720   string fastpath_execute_params = strings::StrCat(
721       "_ctx._context_handle, _ctx._thread_local_data.device_name, \"",
722       op_def_.name(), "\", ", "name, _ctx._post_execution_callbacks");
723   string fallback_params;
724 
725   for (int i = 0; i < api_def_.in_arg_size(); i++) {
726     const string param_name = param_names_[i].GetRenameTo();
727     strings::StrAppend(&fastpath_execute_params, ", ", param_name);
728     if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
729     strings::StrAppend(&fallback_params, param_name);
730   }
731 
732   for (const auto& attr : api_def_.attr()) {
733     if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
734       strings::StrAppend(&fastpath_execute_params, ", \"", attr.name(), "\", ",
735                          attr.rename_to());
736 
737       if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
738       strings::StrAppend(&fallback_params, attr.rename_to(), "=",
739                          attr.rename_to());
740     }
741   }
742 
743   if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
744   strings::StrAppend(&fallback_params, "name=name");
745 
746   strings::StrAppend(&result_, "    try:\n");
747   strings::StrAppend(
748       &result_, "      ",
749       "_result = _pywrap_tensorflow.TFE_Py_FastPathExecute(\n",
750       WordWrap(strings::StrCat("        "),
751                strings::StrCat(fastpath_execute_params, ")"), kRightMargin),
752       "\n");
753 
754   if (op_def_.output_arg_size() > 1) {
755     const string output_tuple_name =
756         strings::StrCat("_", op_def_.name(), "Output");
757     strings::StrAppend(&result_, "      ", "_result = ", output_tuple_name,
758                        "._make(_result)\n");
759   }
760   strings::StrAppend(&result_, "      ", "return _result\n");
761 
762   // Handle fallback.
763   if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
764   strings::StrAppend(&fallback_params, "ctx=_ctx");
765   strings::StrAppend(&result_, "    ", "except _core._FallbackException:\n");
766   strings::StrAppend(&result_, "      try:\n");
767   strings::StrAppend(
768       &result_, "        ", "return ", function_name_, kEagerFallbackSuffix,
769       "(\n",
770       WordWrap(strings::StrCat("            "),
771                strings::StrCat(fallback_params, ")"), kRightMargin),
772       "\n");
773   strings::StrAppend(&result_, "      except _core._SymbolicException:\n");
774   strings::StrAppend(&result_,
775                      "        pass  # Add nodes to the TensorFlow graph.\n");
776   AddDispatch("      ");
777 
778   // Any errors thrown from execute need to be unwrapped from
779   // _NotOkStatusException.
780   strings::StrAppend(&result_, "    ",
781                      "except _core._NotOkStatusException as e:\n");
782   strings::StrAppend(&result_, "      ", "if name is not None:\n");
783   strings::StrAppend(&result_, "        ",
784                      "message = e.message + \" name: \" + name\n");
785   strings::StrAppend(&result_, "      ", "else:\n");
786   strings::StrAppend(&result_, "        ", "message = e.message\n");
787   strings::StrAppend(
788       &result_, "      ",
789       "_six.raise_from(_core._status_to_exception(e.code, message), None)\n");
790 }
791 
AddEagerInferredAttrs(const string & indentation)792 void GenEagerPythonOp::AddEagerInferredAttrs(const string& indentation) {
793   // Figure out values for inferred attrs, and cast to eager tensors.
794   for (int i = 0; i < op_def_.attr_size(); ++i) {
795     const auto& attr(op_def_.attr(i));
796     const auto& api_def_attr(api_def_.attr(i));
797     auto arg_list = attr_to_args_.find(attr.name());
798     if (arg_list != attr_to_args_.end()) {
799       if (attr.type() == "type") {
800         std::vector<string> output_sizes;
801         const string flattened =
802             FlattenInputs(&arg_list->second, &output_sizes);
803         string conversion = strings::StrCat("_execute.args_to_matching_eager(",
804                                             flattened, ", _ctx");
805         if (attr.has_default_value()) {
806           strings::StrAppend(
807               &conversion, ", ",
808               python_op_gen_internal::AttrValueToPython(
809                   attr.type(), api_def_attr.default_value(), "_dtypes."));
810         }
811         strings::StrAppend(&conversion, ")");
812         const string var_name = AttrVarName(attr.name(), &attr_expressions_);
813         if (output_sizes.size() == 1) {
814           // Avoid creating a temporary variable in the case where
815           // we can easily assign to the right value directly.
816           const string inputs_var =
817               param_names_[arg_list->second.front()].GetRenameTo();
818           if (output_sizes.front().empty()) {
819             strings::StrAppend(&result_, indentation, var_name, ", (",
820                                inputs_var, ",) = ", conversion, "\n");
821           } else {
822             strings::StrAppend(&result_, indentation, var_name, ", ",
823                                inputs_var, " = ", conversion, "\n");
824           }
825         } else {
826           const string inputs_var = strings::StrCat("_inputs_", attr.name());
827           strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
828                              " = ", conversion, "\n");
829           // Convert from a flat list of eager tensors back to the
830           // parameter variables.
831           Unflatten(indentation, output_sizes, inputs_var, &result_);
832           std::vector<string> p;
833           for (int j : arg_list->second) {
834             p.emplace_back(param_names_[j].GetRenameTo());
835           }
836           strings::StrAppend(&result_, indentation, VectorToTuple(p), " = ",
837                              inputs_var, "\n");
838         }
839       } else if (attr.type() == "list(type)") {
840         // NOTE: We ignore default values for these attrs, since it is
841         // unclear how you would use it, and the one use case is
842         // parse_single_sequence_example which only needs it for
843         // backwards compatibility.
844         const string var_name = AttrVarName(attr.name(), &attr_expressions_);
845         string inputs_var;
846         string conversion;
847         if (arg_list->second.size() > 1) {
848           // If you have more than one list(tensor) argument, their types
849           // have to match.
850           std::vector<string> lists;
851           for (auto iter = arg_list->second.begin();
852                iter != arg_list->second.end(); ++iter) {
853             lists.push_back(param_names_[*iter].GetRenameTo());
854           }
855           inputs_var = VectorToTuple(lists);
856           conversion = "_execute.args_to_mixed_eager_tensors";
857         } else {
858           // For one list(tensor) argument, we just convert every
859           // element of the list to an eager tensor.
860           inputs_var = param_names_[arg_list->second.front()].GetRenameTo();
861           conversion = "_execute.convert_to_mixed_eager_tensors";
862         }
863         strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
864                            " = ", conversion, "(", inputs_var, ", _ctx)\n");
865       }
866     }
867   }
868 }
869 
AddEagerInputCasts(const string & indentation)870 void GenEagerPythonOp::AddEagerInputCasts(const string& indentation) {
871   // Cast remaining args to eager tensors
872   for (int i = 0; i < op_def_.input_arg_size(); ++i) {
873     const auto& arg(op_def_.input_arg(i));
874     if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue;
875     const string& param = param_names_[i].GetRenameTo();
876     const string fn = arg.number_attr().empty() ? "" : "n_";
877     const string dtype =
878         python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
879     strings::StrAppend(&result_, indentation, param, " = _ops.convert_", fn,
880                        "to_tensor(", param, ", ", dtype, ")\n");
881   }
882 }
883 
AddEagerAttrs(const string & indentation)884 void GenEagerPythonOp::AddEagerAttrs(const string& indentation) {
885   // Compute eager attrs
886   if (op_def_.attr_size() > 0) {
887     string attr_values;
888     for (int i = 0; i < op_def_.attr_size(); ++i) {
889       if (i > 0) strings::StrAppend(&attr_values, ", ");
890       const auto& attr_name(op_def_.attr(i).name());
891       strings::StrAppend(&attr_values, "\"", attr_name, "\", ",
892                          attr_expressions_[attr_name]);
893     }
894     strings::StrAppend(&attr_values, ")");
895     strings::StrAppend(
896         &result_,
897         WordWrap(indentation, strings::StrCat("_attrs = (", attr_values),
898                  kRightMargin),
899         "\n");
900   } else {
901     strings::StrAppend(&result_, indentation, "_attrs = None\n");
902   }
903 }
904 
AddEagerExecute(const string & indentation,const string & num_outputs_expr)905 void GenEagerPythonOp::AddEagerExecute(const string& indentation,
906                                        const string& num_outputs_expr) {
907   const string return_prefix =
908       strings::StrCat(indentation, "_result = _execute.execute(");
909   const string return_args = strings::StrCat(
910       "b\"", op_def_.name(), "\", ", num_outputs_expr,
911       ", inputs=_inputs_flat, attrs=_attrs, ctx=_ctx, name=name)");
912   strings::StrAppend(&result_,
913                      // Wrap the arguments, and indent to the (.
914                      WordWrap(return_prefix, return_args, kRightMargin), "\n");
915 }
916 
AddDispatch(const string & prefix)917 void GenEagerPythonOp::AddDispatch(const string& prefix) {
918   if (api_def_.visibility() != ApiDef::VISIBLE) return;
919 
920   strings::StrAppend(&result_, prefix, "except (TypeError, ValueError):\n");
921   strings::StrAppend(&result_, prefix, "  result = _dispatch.dispatch(\n");
922   AddBodyNoReturn(strings::StrCat(prefix, "        ", function_name_, ", "));
923   strings::StrAppend(&result_, prefix,
924                      "  if result is not "
925                      "_dispatch.OpDispatcher.NOT_SUPPORTED:\n");
926   strings::StrAppend(&result_, prefix, "    return result\n");
927   strings::StrAppend(&result_, prefix, "  raise\n");
928 }
929 
AddRawOpExport(const string & parameters)930 void GenEagerPythonOp::AddRawOpExport(const string& parameters) {
931   string arguments;
932   for (const auto& param_names : param_names_) {
933     const string renamed = param_names.GetRenameTo();
934     strings::StrAppend(&arguments, arguments.empty() ? "" : ", ", renamed, "=",
935                        renamed);
936   }
937   strings::StrAppend(&arguments, arguments.empty() ? "" : ", ", "name=name");
938 
939   const string raw_function_name =
940       python_op_gen_internal::AvoidPythonReserved(op_def_.name());
941 
942   strings::StrAppend(&result_, "def ", raw_function_name, "(", parameters,
943                      "):\n");
944   strings::StrAppend(&result_, "  return ", function_name_, "(", arguments,
945                      ")\n");
946 
947   // Copy the __doc__ from the original op and apply the decorators.
948   strings::StrAppend(&result_, raw_function_name, ".__doc__", " = ",
949                      function_name_, ".__doc__\n");
950   strings::StrAppend(&result_, raw_function_name, " = ",
951                      "_doc_controls.do_not_generate_docs(_kwarg_only(",
952                      raw_function_name, "))\n");
953 
954   // Export.
955   strings::StrAppend(&result_, "tf_export(\"raw_ops.", raw_function_name,
956                      "\")(", raw_function_name, ")\n");
957 }
958 
GetPythonOps(const OpList & ops,const ApiDefMap & api_defs,const std::vector<string> & hidden_ops,bool require_shapes,const string & source_file_name="")959 string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
960                     const std::vector<string>& hidden_ops, bool require_shapes,
961                     const string& source_file_name = "") {
962   string result;
963   // Header
964   // TODO(josh11b): Mention the library for which wrappers are being generated.
965   strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops.
966 
967 This file is MACHINE GENERATED! Do not edit.
968 )");
969 
970   // Mention the original source file so someone tracing back through
971   // generated Python code will know where to look next.
972   if (!source_file_name.empty()) {
973     strings::StrAppend(&result, "Original C++ source file: ");
974     strings::StrAppend(&result, source_file_name);
975     strings::StrAppend(&result, "\n");
976   }
977 
978   strings::StrAppend(&result, R"("""
979 
980 import collections as _collections
981 import six as _six
982 
983 from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
984 from tensorflow.python.eager import context as _context
985 from tensorflow.python.eager import core as _core
986 from tensorflow.python.eager import execute as _execute
987 from tensorflow.python.framework import dtypes as _dtypes
988 from tensorflow.python.framework import errors as _errors
989 from tensorflow.python.framework import tensor_shape as _tensor_shape
990 
991 from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
992 # Needed to trigger the call to _set_call_cpp_shape_fn.
993 from tensorflow.python.framework import common_shapes as _common_shapes
994 from tensorflow.python.framework import op_def_registry as _op_def_registry
995 from tensorflow.python.framework import ops as _ops
996 from tensorflow.python.framework import op_def_library as _op_def_library
997 from tensorflow.python.util.deprecation import deprecated_endpoints
998 from tensorflow.python.util import dispatch as _dispatch
999 from tensorflow.python.util.tf_export import tf_export
1000 from tensorflow.python.util.tf_export import kwarg_only as _kwarg_only
1001 from tensorflow.tools.docs import doc_controls as _doc_controls
1002 
1003 )");
1004 
1005   // We'll make a copy of ops that filters out descriptions.
1006   OpList cleaned_ops;
1007   auto out = cleaned_ops.mutable_op();
1008   out->Reserve(ops.op_size());
1009   for (const auto& op_def : ops.op()) {
1010     const auto* api_def = api_defs.GetApiDef(op_def.name());
1011 
1012     if (api_def->visibility() == ApiDef::SKIP) {
1013       continue;
1014     }
1015     // An op is hidden if either its ApiDef visibility is HIDDEN
1016     // or it is in the hidden_ops list.
1017     bool is_hidden = api_def->visibility() == ApiDef::HIDDEN;
1018     bool hidden_by_api_def = is_hidden;
1019     if (!is_hidden) {
1020       for (const string& hidden : hidden_ops) {
1021         if (op_def.name() == hidden) {
1022           is_hidden = true;
1023           break;
1024         }
1025       }
1026     }
1027 
1028     string function_name;
1029     python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(),
1030                                                     &function_name);
1031     bool is_reserved = python_op_gen_internal::IsPythonReserved(function_name);
1032 
1033     // Prefix an op with underscore if the op is listed in hidden_ops or
1034     // name is reserved or it is of the exceptions in IsOpWithUnderscorePrefix.
1035     // Do not add underscores to ops set to HIDDEN in ApiDef otherwise.
1036     // TODO(annarev): don't prefix with underscores even if op is in hidden_ops.
1037     if (is_hidden) {
1038       if (!hidden_by_api_def || is_reserved ||
1039           python_op_gen_internal::IsOpWithUnderscorePrefix(function_name)) {
1040         function_name = strings::StrCat("_", function_name);
1041       }
1042     } else if (is_reserved) {
1043       // When users create custom python wrappers, they may link in the
1044       // default op registry by accident, and because they can't
1045       // enumerate all 'hidden' symbols, this guard is to prevent
1046       // instantiating a python reserved word in their wrapper.
1047       continue;
1048     }
1049 
1050     strings::StrAppend(&result,
1051                        GetEagerPythonOp(op_def, *api_def, function_name));
1052 
1053     if (!require_shapes) {
1054       strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(),
1055                          "\")(None)\n\n");
1056     }
1057 
1058     auto added = out->Add();
1059     *added = op_def;
1060     RemoveNonDeprecationDescriptionsFromOpDef(added);
1061   }
1062 
1063   result.append(R"(def _InitOpDefLibrary(op_list_proto_bytes):
1064   op_list = _op_def_pb2.OpList()
1065   op_list.ParseFromString(op_list_proto_bytes)
1066   _op_def_registry.register_op_list(op_list)
1067   op_def_lib = _op_def_library.OpDefLibrary()
1068   op_def_lib.add_op_list(op_list)
1069   return op_def_lib
1070 )");
1071 
1072   result.append("# ");
1073   auto ops_text = ProtoDebugString(cleaned_ops);
1074   str_util::StripTrailingWhitespace(&ops_text);
1075   result.append(str_util::StringReplace(ops_text, "\n", "\n# ", true));
1076   result.append("\n");
1077   strings::Appendf(&result, "_op_def_lib = _InitOpDefLibrary(b\"%s\")\n",
1078                    str_util::CEscape(cleaned_ops.SerializeAsString()).c_str());
1079   return result;
1080 }
1081 
1082 }  // namespace
1083 
PrintPythonOps(const OpList & ops,const ApiDefMap & api_defs,const std::vector<string> & hidden_ops,bool require_shapes,const string & source_file_name)1084 void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
1085                     const std::vector<string>& hidden_ops, bool require_shapes,
1086                     const string& source_file_name) {
1087   printf("%s", GetPythonOps(ops, api_defs, hidden_ops, require_shapes,
1088                             source_file_name)
1089                    .c_str());
1090 }
1091 
GetPythonWrappers(const char * op_list_buf,size_t op_list_len)1092 string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
1093   string op_list_str(op_list_buf, op_list_len);
1094   OpList ops;
1095   ops.ParseFromString(op_list_str);
1096 
1097   ApiDefMap api_def_map(ops);
1098   return GetPythonOps(ops, api_def_map, {}, false);
1099 }
1100 
1101 }  // namespace tensorflow
1102