/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/framework/python_op_gen.h"

#include <stdio.h>
#include <sstream>
#include <unordered_map>
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb_text.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/framework/python_op_gen_internal.h"

namespace tensorflow {
namespace {

const int kRightMargin = 78;

constexpr char kEagerFallbackSuffix[] = "_eager_fallback";

string AttrVarName(const string& attr_name,
                   std::unordered_map<string, string>* attr_expressions) {
  const string var = strings::StrCat("_attr_", attr_name);
  if (attr_expressions != nullptr) (*attr_expressions)[attr_name] = var;
  return var;
}

void AddInferredAttr(const string& indentation, const string& attr_name,
                     const string& value_expression, string* result,
                     std::unordered_map<string, string>* attr_expressions) {
  strings::StrAppend(result, indentation,
                     AttrVarName(attr_name, attr_expressions), " = ",
                     value_expression, "\n");
}
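// Illustrative sketch: AddInferredAttr("  ", "T", "_x.dtype", &result,
// &attr_expressions) would append a line like "  _attr_T = _x.dtype" and
// record the variable name "_attr_T" under the key "T" in *attr_expressions
// ("T" and "_x.dtype" are placeholder names, not taken from a real op).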

string VectorToTuple(const std::vector<string>& l) {
  if (l.size() == 1) return strings::StrCat("(", l.front(), ",)");
  string ret = "(";
  for (int i = 0; i < l.size(); ++i) {
    if (i > 0) {
      strings::StrAppend(&ret, ", ");
    }
    strings::StrAppend(&ret, l[i]);
  }
  strings::StrAppend(&ret, ")");
  return ret;
}
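// For example, VectorToTuple({"x"}) returns "(x,)" and VectorToTuple({"x", "y"})
// returns "(x, y)", so the result always parses as a Python tuple (element
// names here are illustrative only).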

void Unflatten(const string& prefix, const std::vector<string>& output_sizes,
               const string& var, string* result) {
  for (int i = 0; i < output_sizes.size(); ++i) {
    if (!output_sizes[i].empty()) {
      strings::StrAppend(result, prefix, var, " = ");
      if (i > 0) strings::StrAppend(result, var, "[:", i, "] + ");
      if (i + 1 < output_sizes.size()) {
        // Special case i == 0 to avoid "0 +" in the generated code.
        if (i == 0) {
          strings::StrAppend(result, "[", var, "[:", output_sizes[i], "]] + ",
                             var, "[", output_sizes[i], ":]");
        } else {
          strings::StrAppend(result, "[", var, "[", i, ":", i, " + ",
                             output_sizes[i], "]] + ", var, "[", i, " + ",
                             output_sizes[i], ":]");
        }
      } else {
        strings::StrAppend(result, "[", var, "[", i, ":]]");
      }
      strings::StrAppend(result, "\n");
    }
  }
}
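// Illustrative sketch: with output_sizes == {"", "_attr_N", ""} and
// var == "_result", Unflatten emits a Python line roughly like
//   _result = _result[:1] + [_result[1:1 + _attr_N]] + _result[1 + _attr_N:]
// i.e. the list-valued output at index 1 is regrouped into a sublist, while
// outputs with an empty size expression are left flat.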

string TensorPBString(const TensorProto& pb) {
  // Note: This gets used in the argument list, and so must survive naive
  // word wrapping.
  return strings::StrCat("\"\"\"", ProtoShortDebugString(pb), "\"\"\"");
}
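// As a rough illustration, a scalar float default might render as
//   """dtype: DT_FLOAT tensor_shape { } float_val: 1"""
// (the exact text depends on ProtoShortDebugString); the triple quotes keep
// the embedded spaces intact even after word wrapping.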

class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
 public:
  GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
                   const string& function_name)
      : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name) {
    op_name_ = function_name_;
    str_util::ConsumePrefix(&op_name_, "_");
  }
  ~GenEagerPythonOp() override {}

  string Code() override;

 protected:
  void HandleGraphMode(const string& function_setup);

  string GetEagerNotAllowedError();
  void ExpectListArg(const string& indentation, const string& arg_name,
                     string* output);
  bool GetEagerFunctionSetup(const string& indentation, string* function_setup);
  void GetOutputSizesAndNumOutputsExpr(std::vector<string>* output_sizes,
                                       string* num_outputs_expr);

  void AddEagerFunctionTeardown(const string& indentation,
                                const std::vector<string>& output_sizes,
                                bool execute_record_gradient);

  bool AddEagerFastPathAndGraphCode(const string& parameters,
                                    const std::vector<string>& output_sizes,
                                    const string& eager_not_allowed_error);
  bool AddEagerFallbackCode(const string& parameters,
                            const std::vector<string>& output_sizes,
                            const string& num_outputs_expr,
                            const string& eager_not_allowed_error);
  void AddEagerFastPathExecute();

  void AddEagerInferredAttrs(const string& indentation);
  void AddEagerInputCasts(const string& indentation);
  void AddEagerAttrs(const string& indentation);
  void AddEagerExecute(const string& indentation,
                       const string& num_outputs_expr);
  void AddDispatch(const string& prefix);

  void AddRawOpExport(const string& parameters);

  void AddAttrForArg(const string& attr, int arg_index) {
    gtl::InsertIfNotPresent(&inferred_attrs_, attr,
                            op_def_.input_arg(arg_index).name());
    auto iter = attr_to_args_.find(attr);
    if (iter == attr_to_args_.end()) {
      attr_to_args_.insert(AttrToArgMap::value_type(attr, {arg_index}));
    } else {
      iter->second.push_back(arg_index);
    }
  }

  // Returns a string expression representing a flattened list of all
  // the inputs given by `*input_indices` (or all inputs if
  // `input_indices` is nullptr). `*output_sizes` can be used to unflatten.
  string FlattenInputs(const std::vector<int>* input_indices,
                       std::vector<string>* output_sizes) const;

  StringPiece op_name_;
  typedef std::unordered_map<string, std::vector<int>> AttrToArgMap;
  AttrToArgMap attr_to_args_;
  std::unordered_map<string, string> attr_expressions_;
  // This has all the input args followed by those attrs that don't have
  // defaults.
  std::vector<python_op_gen_internal::ParamNames> params_no_default_;
  // The parameters with defaults (these have to be listed after those without).
  // No input args are included, just attrs.
  std::vector<std::pair<python_op_gen_internal::ParamNames, string>>
      params_with_default_;
};

string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
                        const string& function_name) {
  return GenEagerPythonOp(op_def, api_def, function_name).Code();
}
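// A rough, illustrative sketch of what Code() produces for a visible op
// (names such as "my_op" and the parameter "x" are placeholders, and details
// like word wrapping and the nested exception handlers are omitted):
//
//   @_dispatch.add_dispatch_list
//   @tf_export("my_op")
//   def my_op(x, name=None):
//     _ctx = _context._context or _context.context()
//     if _ctx is not None and _ctx._thread_local_data.is_eager:
//       try:
//         _result = _pywrap_tensorflow.TFE_Py_FastPathExecute(...)
//         return _result
//       except _core._FallbackException:
//         return my_op_eager_fallback(x, name=name, ctx=_ctx)
//     # Graph mode: falls through to _op_def_lib._apply_op_helper(...)
//
//   def my_op_eager_fallback(x, name=None, ctx=None):
//     ...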

string GenEagerPythonOp::FlattenInputs(
    const std::vector<int>* input_indices,
    std::vector<string>* output_sizes) const {
  string inputs;
  enum { STARTING, WAS_LIST_INPUT, WAS_SOLO_INPUT } inputs_state = STARTING;
  const int n = input_indices != nullptr ? input_indices->size()
                                         : op_def_.input_arg_size();
  for (int j = 0; j < n; ++j) {
    const int i = input_indices ? (*input_indices)[j] : j;
    const auto& arg(op_def_.input_arg(i));
    const bool is_list =
        !arg.type_list_attr().empty() || !arg.number_attr().empty();
    if (is_list) {
      if (inputs_state == WAS_SOLO_INPUT) {
        strings::StrAppend(&inputs, "] + ");
      } else if (inputs_state == WAS_LIST_INPUT) {
        strings::StrAppend(&inputs, " + ");
      }
      strings::StrAppend(&inputs, "list(", param_names_[i].GetRenameTo(), ")");
      inputs_state = WAS_LIST_INPUT;
      if (output_sizes != nullptr) {
        if (!arg.number_attr().empty()) {
          output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr));
        } else {
          output_sizes->emplace_back(
              strings::StrCat("len(", param_names_[i].GetRenameTo(), ")"));
        }
      }
    } else {
      if (inputs_state == WAS_SOLO_INPUT) {
        strings::StrAppend(&inputs, ", ");
      } else if (inputs_state == WAS_LIST_INPUT) {
        strings::StrAppend(&inputs, " + [");
      } else {
        strings::StrAppend(&inputs, "[");
      }
      strings::StrAppend(&inputs, param_names_[i].GetRenameTo());
      inputs_state = WAS_SOLO_INPUT;
      if (output_sizes != nullptr) output_sizes->emplace_back();
    }
  }
  if (inputs_state == STARTING) return "[]";
  if (inputs_state == WAS_SOLO_INPUT) {
    strings::StrAppend(&inputs, "]");
  }
  return inputs;
}
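// Illustrative example: for an op with a solo input "x" followed by a list
// input "ys" whose length is the attr "N", FlattenInputs returns the Python
// expression "[x] + list(ys)" and fills *output_sizes with {"", "_attr_N"}
// (parameter and attr names here are placeholders).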

string GenEagerPythonOp::Code() {
  if (api_def_.visibility() == ApiDef::SKIP) {
    return "";
  }

  for (int i = 0; i < api_def_.arg_order_size(); ++i) {
    const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
    const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
    params_no_default_.emplace_back(api_def_arg.name(),
                                    api_def_arg.rename_to());
    if (!arg.type_attr().empty()) {
      AddAttrForArg(arg.type_attr(), i);
    } else if (!arg.type_list_attr().empty()) {
      AddAttrForArg(arg.type_list_attr(), i);
    }
    if (!arg.number_attr().empty()) {
      AddAttrForArg(arg.number_attr(), i);
    }
  }
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    const auto& api_def_attr(api_def_.attr(i));
    // Do not add inferred attrs to the Python function signature.
    if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
      if (api_def_attr.has_default_value()) {
        if (attr.type() == "tensor") {
          params_with_default_.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              strings::StrCat(
                  "_execute.make_tensor(",
                  TensorPBString(api_def_attr.default_value().tensor()), ", \"",
                  api_def_attr.rename_to(), "\")"));
        } else if (attr.type() == "list(tensor)") {
          std::vector<string> pbtxt;
          for (const auto& pb : api_def_attr.default_value().list().tensor()) {
            pbtxt.emplace_back(TensorPBString(pb));
          }
          params_with_default_.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              strings::StrCat("[_execute.make_tensor(_pb, \"",
                              api_def_attr.rename_to(), "\") for _pb in ",
                              VectorToTuple(pbtxt), "]"));
        } else {
          params_with_default_.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              python_op_gen_internal::AttrValueToPython(
                  attr.type(), api_def_attr.default_value(), "_dtypes."));
        }
      } else {
        params_no_default_.emplace_back(api_def_attr.name(),
                                        api_def_attr.rename_to());
      }
    }
  }

  // Save the list of attr parameters (attrs that won't be inferred),
  // those with defaults go at the end.
  // Get the attrs in the order we want by taking the attrs without defaults
  // from the end of params_no_default_, and adding params_no_default_.
  attrs_.reserve(params_no_default_.size() - op_def_.input_arg_size() +
                 params_with_default_.size());
  for (int i = op_def_.input_arg_size(); i < params_no_default_.size(); ++i) {
    attrs_.push_back(params_no_default_[i].GetName());
  }
  for (const auto& p : params_with_default_) {
    attrs_.push_back(p.first.GetName());
  }

  // TODO(slebedev): call AvoidPythonReserved on each param?
  param_names_.reserve(params_no_default_.size() + params_with_default_.size());
  param_names_.insert(param_names_.begin(), params_no_default_.begin(),
                      params_no_default_.end());
  for (const auto& param_and_default : params_with_default_) {
    param_names_.push_back(param_and_default.first);
  }

  string parameters;
  for (const auto& param : params_no_default_) {
    if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
    strings::StrAppend(&parameters, param.GetRenameTo());
  }
  for (const auto& param_and_default : params_with_default_) {
    if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
    strings::StrAppend(&parameters, param_and_default.first.GetRenameTo(), "=",
                       param_and_default.second);
  }
  strings::StrAppend(&parameters, parameters.empty() ? "" : ", ", "name=None");

  // Add attr_expressions_ for attrs that are params.
  for (int i = 0; i < attrs_.size(); ++i) {
    const string& attr_name = attrs_[i];
    const string& attr_api_name =
        param_names_[i + op_def_.input_arg_size()].GetRenameTo();
    attr_expressions_[attr_name] = attr_api_name;
  }
  // Add attr_expressions_ for attrs that are inferred.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    if (attr.type() == "int") {
      auto arg_list = attr_to_args_.find(attr.name());
      if (arg_list != attr_to_args_.end()) {
        AttrVarName(attr.name(), &attr_expressions_);
      }
    }
  }

  string num_outputs_expr;
  std::vector<string> output_sizes(num_outs_);
  GetOutputSizesAndNumOutputsExpr(&output_sizes, &num_outputs_expr);

  string eager_not_allowed_error = GetEagerNotAllowedError();

  if (!AddEagerFastPathAndGraphCode(parameters, output_sizes,
                                    eager_not_allowed_error)) {
    return result_;
  }

  if (!AddEagerFallbackCode(parameters, output_sizes, num_outputs_expr,
                            eager_not_allowed_error)) {
    return result_;
  }

  return prelude_ + result_;
}

void GenEagerPythonOp::HandleGraphMode(const string& function_setup) {
  strings::StrAppend(&result_, "  # Add nodes to the TensorFlow graph.\n");
  strings::StrAppend(&result_, function_setup);
  if (api_def_.visibility() == ApiDef::VISIBLE) {
    strings::StrAppend(&result_, "  try:\n  ");
  }
  strings::StrAppend(&result_, "  _, _, _op = _op_def_lib._apply_op_helper(\n");
  AddBodyNoReturn(strings::StrCat("        \"", op_def_.name(), "\", "));
  AddDispatch("  ");

  if (num_outs_ > 0) {
    strings::StrAppend(&result_, "  _result = _op.outputs[:]\n");
    // Special case handling for stateful op with single list output
    // that might be empty.
    if (num_outs_ == 1 && op_def_.is_stateful() &&
        (!op_def_.output_arg(0).number_attr().empty() ||
         !op_def_.output_arg(0).type_list_attr().empty())) {
      // TODO(josh11b): Can skip this if the number_attr/type_list_attr has
      // a constraint indicating that this can never be empty.
      strings::StrAppend(&result_,
                         "  if not _result:\n"
                         "    return _op\n");
    }
    strings::StrAppend(&result_, "  _inputs_flat = _op.inputs\n");

    // Compute graph-mode attrs.
    if (op_def_.attr_size() > 0) {
      string attr_values;
      for (int i = 0; i < op_def_.attr_size(); ++i) {
        if (i > 0) strings::StrAppend(&attr_values, ", ");
        const auto& attr_name(op_def_.attr(i).name());
        strings::StrAppend(&attr_values, "\"", attr_name, "\", _op.get_attr(\"",
                           attr_name, "\")");
      }
      strings::StrAppend(&attr_values, ")");
      strings::StrAppend(
          &result_, WordWrap("  _attrs = (", attr_values, kRightMargin), "\n");
    } else {
      strings::StrAppend(&result_, "  _attrs = None\n");
    }
  } else {
    strings::StrAppend(&result_, "  return _op\n");
  }
}
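// The graph-mode branch emitted above reads roughly like this for a
// hypothetical op "MyOp" with one input "x" and one attr "T" (word wrapping
// and the surrounding try/dispatch handling elided):
//   _, _, _op = _op_def_lib._apply_op_helper(
//         "MyOp", x=x, name=name)
//   _result = _op.outputs[:]
//   _inputs_flat = _op.inputs
//   _attrs = ("T", _op.get_attr("T"))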

string GenEagerPythonOp::GetEagerNotAllowedError() {
  bool eager_allowed = true;
  string ref_arg;
  for (int i = 0; i < op_def_.input_arg_size(); ++i) {
    const auto& arg = op_def_.input_arg(i);
    if (arg.is_ref()) {
      eager_allowed = false;
      DCHECK_EQ(op_def_.input_arg(i).name(), api_def_.in_arg(i).name());
      ref_arg = api_def_.in_arg(i).rename_to();
    }
  }
  for (int i = 0; i < op_def_.output_arg_size(); ++i) {
    const auto& arg = op_def_.output_arg(i);
    if (arg.is_ref()) {
      eager_allowed = false;
      DCHECK_EQ(op_def_.output_arg(i).name(), api_def_.out_arg(i).name());
      ref_arg = api_def_.out_arg(i).rename_to();
    }
  }

  if (eager_allowed) return "";

  return strings::StrCat("raise RuntimeError(\"", op_name_,
                         " op does not support eager execution. ", "Arg '",
                         ref_arg, "' is a ref.\")\n");
}
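// For a hypothetical op "assign" whose input "ref" is a reference type, the
// returned statement would read:
//   raise RuntimeError("assign op does not support eager execution. Arg 'ref' is a ref.")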

void GenEagerPythonOp::ExpectListArg(const string& indentation,
                                     const string& arg_name, string* output) {
  strings::StrAppend(output, indentation, "if not isinstance(", arg_name,
                     ", (list, tuple)):\n", indentation, "  raise TypeError(\n",
                     indentation, "      \"Expected list for '", arg_name,
                     "' argument to \"\n", indentation, "      \"'", op_name_,
                     "' Op, not %r.\" % ", arg_name, ")\n");
}
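// With indentation "  " and arg_name "values" for a hypothetical op "my_op",
// this emits roughly:
//   if not isinstance(values, (list, tuple)):
//     raise TypeError(
//         "Expected list for 'values' argument to "
//         "'my_op' Op, not %r." % values)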

bool GenEagerPythonOp::GetEagerFunctionSetup(const string& indentation,
                                             string* function_setup) {
  // Validate list inputs, infer length attrs.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    if (attr.type() == "int") {
      auto arg_list = attr_to_args_.find(attr.name());
      if (arg_list != attr_to_args_.end()) {
        // Inferred int attrs are the lengths of inputs. Validate those
        // inputs are lists and have the same length.
        for (auto iter = arg_list->second.begin();
             iter != arg_list->second.end(); ++iter) {
          const string& arg_api_name = param_names_[*iter].GetRenameTo();
          ExpectListArg(indentation, arg_api_name, function_setup);
          if (iter == arg_list->second.begin()) {
            AddInferredAttr(indentation, attr.name(),
                            strings::StrCat("len(", arg_api_name, ")"),
                            function_setup, &attr_expressions_);
          } else {
            const auto& attr_var = attr_expressions_[attr.name()];
            strings::StrAppend(
                function_setup, indentation, "if len(", arg_api_name,
                ") != ", attr_var, ":\n", indentation, "  raise ValueError(\n",
                indentation, "      \"List argument '", arg_api_name, "' to '",
                op_name_, "' Op with length %d \"\n", indentation,
                "      \"must match length %d of argument '",
                inferred_attrs_[attr.name()], "'.\" %\n", indentation,
                "      (len(", arg_api_name, "), ", attr_var, "))\n");
          }
        }
      }
    }
  }

  for (int i = 0; i < attrs_.size(); ++i) {
    const string& attr_name = attrs_[i];
    const auto& param = param_names_[i + op_def_.input_arg_size()];
    const auto& attr = *FindAttr(attr_name, op_def_);
    const string& attr_api_name = param.GetRenameTo();
    StringPiece attr_type = attr.type();
    attr_expressions_[attr_name] = attr_api_name;
    const int default_index = i - (attrs_.size() - params_with_default_.size());
    if (default_index >= 0) {
      const string& default_value = params_with_default_[default_index].second;
      strings::StrAppend(function_setup, indentation, "if ", attr_api_name,
                         " is None:\n");
      strings::StrAppend(function_setup, indentation, "  ", attr_api_name,
                         " = ", default_value, "\n");
    }
    if (str_util::StartsWith(attr_type, "list(")) {
      ExpectListArg(indentation, attr_api_name, function_setup);
    }

    if (attr_type == "string") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_str(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(string)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_str(_s, \"", attr_api_name,
                         "\") for _s in ", attr_api_name, "]\n");
    } else if (attr_type == "int") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_int(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(int)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_int(_i, \"", attr_api_name,
                         "\") for _i in ", attr_api_name, "]\n");
    } else if (attr_type == "float") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_float(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(float)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_float(_f, \"", attr_api_name,
                         "\") for _f in ", attr_api_name, "]\n");
    } else if (attr_type == "bool") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_bool(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(bool)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_bool(_b, \"", attr_api_name,
                         "\") for _b in ", attr_api_name, "]\n");
    } else if (attr_type == "type") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_type(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(type)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_type(_t, \"", attr_api_name,
                         "\") for _t in ", attr_api_name, "]\n");
    } else if (attr_type == "shape") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_shape(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(shape)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_shape(_s, \"", attr_api_name,
                         "\") for _s in ", attr_api_name, "]\n");
    } else if (attr_type == "tensor") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_tensor(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(tensor)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_tensor(_t, \"", attr_api_name,
                         "\") for _t in ", attr_api_name, "]\n");
    } else if (attr_type != "func" && attr_type != "list(func)") {
      *function_setup =
          strings::StrCat("# No definition for ", function_name_,
                          " since we don't support attrs with type\n"
                          "# '",
                          attr_type, "' right now.\n\n");
      return false;
    }
  }
  return true;
}
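// Illustrative output for a hypothetical int attr "axis" with default -1
// (indentation "  ", shown without the base indent):
//   if axis is None:
//     axis = -1
//   axis = _execute.make_int(axis, "axis")
// Inferred length attrs instead get an isinstance check plus a line such as
//   _attr_N = len(values)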

// If output i is list output, output_sizes[i] will be set to a
// string with the python expression that will evaluate to its
// length. output_sizes[i] is empty for non-list outputs.
void GenEagerPythonOp::GetOutputSizesAndNumOutputsExpr(
    std::vector<string>* output_sizes, string* num_outputs_expr) {
  // Expression representing the number of outputs.
  int num_fixed_outputs = 0;
  for (int i = 0; i < num_outs_; ++i) {
    const auto& arg(op_def_.output_arg(i));
    if (!arg.number_attr().empty()) {
      if (!num_outputs_expr->empty()) {
        strings::StrAppend(num_outputs_expr, " + ");
      }
      (*output_sizes)[i] = attr_expressions_[arg.number_attr()];
      strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
    } else if (!arg.type_list_attr().empty()) {
      if (!num_outputs_expr->empty()) {
        strings::StrAppend(num_outputs_expr, " + ");
      }
      // Have to be careful to use an expression that works in both
      // graph and eager paths here.
      const auto iter = inferred_attrs_.find(arg.type_list_attr());
      if (iter == inferred_attrs_.end()) {
        (*output_sizes)[i] = strings::StrCat(
            "len(", attr_expressions_[arg.type_list_attr()], ")");
      } else {
        (*output_sizes)[i] = strings::StrCat("len(", iter->second, ")");
      }
      strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
    } else {
      ++num_fixed_outputs;
    }
  }
  if (num_fixed_outputs > 0) {
    if (!num_outputs_expr->empty()) {
      strings::StrAppend(num_outputs_expr, " + ");
    }
    strings::StrAppend(num_outputs_expr, num_fixed_outputs);
  } else if (num_outputs_expr->empty()) {
    *num_outputs_expr = "0";
  }
}
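// Sketch with placeholder names: for outputs (scalar, list sized by inferred
// attr "N", list typed by inferred attr "T" coming from input "values"), this
// would yield
//   output_sizes      == {"", "_attr_N", "len(values)"}
//   *num_outputs_expr == "_attr_N + len(values) + 1"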

void GenEagerPythonOp::AddEagerFunctionTeardown(
    const string& indentation, const std::vector<string>& output_sizes,
    bool execute_record_gradient) {
  if (num_outs_ > 0) {
    if (execute_record_gradient) {
      strings::StrAppend(&result_, indentation, "_execute.record_gradient(\n",
                         "      \"", op_def_.name(),
                         "\", _inputs_flat, _attrs, _result, name)\n");
    }
    if (num_outs_ == 1 && !output_sizes[0].empty()) {
      // Single list result.
    } else if (num_outs_ == 1) {
      // Execute returns a single-element list which we need to destructure.
      strings::StrAppend(&result_, indentation, "_result, = _result\n");
    } else {
      // Have multiple outputs, so we will need to reformat the return
      // value of execute() to be a list with one entry per op output
      // (that entry will be a list of tensors if that output is of list
      // type).
      // For list outputs, convert the right subrange of _result into a list.
      Unflatten(indentation, output_sizes, "_result", &result_);
      // Convert to a named tuple.
      strings::StrAppend(&result_, indentation, "_result = _", op_def_.name(),
                         "Output._make(_result)\n");
    }
  } else {
    strings::StrAppend(&result_, indentation, "_result = None\n");
  }
  strings::StrAppend(&result_, indentation, "return _result\n\n");
}
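// For a hypothetical two-output op "MyOp", the teardown emitted at
// indentation "  " reads roughly:
//   _execute.record_gradient(
//       "MyOp", _inputs_flat, _attrs, _result, name)
//   _result = _MyOpOutput._make(_result)
//   return _result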

bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
    const string& parameters, const std::vector<string>& output_sizes,
    const string& eager_not_allowed_error) {
  if (api_def_.visibility() == ApiDef::VISIBLE) {
    strings::StrAppend(&result_, "@_dispatch.add_dispatch_list\n");
  }

  AddExport();
  AddDefLine(function_name_, parameters);
  AddDocStringDescription();
  AddDocStringArgs();
  AddDocStringInputs();
  AddDocStringAttrs();
  AddDocStringNameArg();
  AddOutputGlobals();  // Added to prelude_
  AddDocStringOutputs();
  strings::StrAppend(&result_, "  \"\"\"\n");

  strings::StrAppend(
      &result_,
      "  _ctx = _context._context or _context.context()\n"
      "  if _ctx is not None and _ctx._thread_local_data.is_eager:",
      "\n");
  if (eager_not_allowed_error.empty()) {
    AddEagerFastPathExecute();
  } else {
    strings::StrAppend(&result_, "    ", eager_not_allowed_error);
  }

  // Handle graph-mode case
  string function_setup;
  if (!GetEagerFunctionSetup("  ", &function_setup)) {
    result_ = function_setup;
    return false;
  }
  HandleGraphMode(function_setup);
  AddEagerFunctionTeardown("  ", output_sizes,
                           true /* execute_record_gradient */);

  AddRawOpExport(parameters);
  strings::StrAppend(&result_, "\n\n");
  return true;
}

bool GenEagerPythonOp::AddEagerFallbackCode(
    const string& parameters, const std::vector<string>& output_sizes,
    const string& num_outputs_expr, const string& eager_not_allowed_error) {
  AddDefLine(
      strings::StrCat(function_name_, kEagerFallbackSuffix),
      strings::StrCat(parameters, parameters.empty() ? "" : ", ", "ctx=None"));

  if (!eager_not_allowed_error.empty()) {
    strings::StrAppend(&result_, "  ", eager_not_allowed_error);
    return true;
  }

  strings::StrAppend(
      &result_, "  r\"\"\"This is the slowpath function for Eager mode.\n");
  strings::StrAppend(&result_, "  This is for function ", function_name_,
                     "\n  \"\"\"\n");

  strings::StrAppend(&result_, "  _ctx = ctx if ctx else _context.context()\n");

  string function_setup;
  if (!GetEagerFunctionSetup("  ", &function_setup)) {
    result_ = function_setup;
    return false;
  }
  strings::StrAppend(&result_, function_setup);

  AddEagerInferredAttrs("  ");
  AddEagerInputCasts("  ");
  strings::StrAppend(
      &result_, "  _inputs_flat = ", FlattenInputs(nullptr, nullptr), "\n");
  AddEagerAttrs("  ");
  AddEagerExecute("  ", num_outputs_expr);

  AddEagerFunctionTeardown("  ", output_sizes,
                           true /* execute_record_gradient */);

  return true;
}

void GenEagerPythonOp::AddEagerFastPathExecute() {
  string fastpath_execute_params = strings::StrCat(
      "_ctx._context_handle, _ctx._thread_local_data.device_name, \"",
      op_def_.name(), "\", ", "name, _ctx._post_execution_callbacks");
  string fallback_params;

  for (int i = 0; i < api_def_.in_arg_size(); i++) {
    const string param_name = param_names_[i].GetRenameTo();
    strings::StrAppend(&fastpath_execute_params, ", ", param_name);
    if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
    strings::StrAppend(&fallback_params, param_name);
  }

  for (const auto& attr : api_def_.attr()) {
    if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
      strings::StrAppend(&fastpath_execute_params, ", \"", attr.name(), "\", ",
                         attr.rename_to());

      if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
      strings::StrAppend(&fallback_params, attr.rename_to(), "=",
                         attr.rename_to());
    }
  }

  if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
  strings::StrAppend(&fallback_params, "name=name");

  strings::StrAppend(&result_, "    try:\n");
  strings::StrAppend(
      &result_, "      ",
      "_result = _pywrap_tensorflow.TFE_Py_FastPathExecute(\n",
      WordWrap(strings::StrCat("        "),
               strings::StrCat(fastpath_execute_params, ")"), kRightMargin),
      "\n");

  if (op_def_.output_arg_size() > 1) {
    const string output_tuple_name =
        strings::StrCat("_", op_def_.name(), "Output");
    strings::StrAppend(&result_, "      ", "_result = ", output_tuple_name,
                       "._make(_result)\n");
  }
  strings::StrAppend(&result_, "      ", "return _result\n");

  // Handle fallback.
  if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
  strings::StrAppend(&fallback_params, "ctx=_ctx");
  strings::StrAppend(&result_, "    ", "except _core._FallbackException:\n");
  strings::StrAppend(&result_, "      try:\n");
  strings::StrAppend(
      &result_, "        ", "return ", function_name_, kEagerFallbackSuffix,
      "(\n",
      WordWrap(strings::StrCat("            "),
               strings::StrCat(fallback_params, ")"), kRightMargin),
      "\n");
  strings::StrAppend(&result_, "      except _core._SymbolicException:\n");
  strings::StrAppend(&result_,
                     "        pass  # Add nodes to the TensorFlow graph.\n");
  AddDispatch("      ");

  // Any errors thrown from execute need to be unwrapped from
  // _NotOkStatusException.
  strings::StrAppend(&result_, "    ",
                     "except _core._NotOkStatusException as e:\n");
  strings::StrAppend(&result_, "      ", "if name is not None:\n");
  strings::StrAppend(&result_, "        ",
                     "message = e.message + \" name: \" + name\n");
  strings::StrAppend(&result_, "      ", "else:\n");
  strings::StrAppend(&result_, "        ", "message = e.message\n");
  strings::StrAppend(
      &result_, "      ",
      "_six.raise_from(_core._status_to_exception(e.code, message), None)\n");
}
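// The emitted fast path, for a hypothetical op "MyOp" with input "x" and a
// non-inferred attr "axis", looks roughly like this (indentation simplified
// and the _SymbolicException/dispatch handlers omitted):
//   try:
//     _result = _pywrap_tensorflow.TFE_Py_FastPathExecute(
//       _ctx._context_handle, _ctx._thread_local_data.device_name, "MyOp",
//       name, _ctx._post_execution_callbacks, x, "axis", axis)
//     return _result
//   except _core._FallbackException:
//     return my_op_eager_fallback(x, axis=axis, name=name, ctx=_ctx)
// with the _NotOkStatusException handler re-raising via _six.raise_from.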

void GenEagerPythonOp::AddEagerInferredAttrs(const string& indentation) {
  // Figure out values for inferred attrs, and cast to eager tensors.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    const auto& api_def_attr(api_def_.attr(i));
    auto arg_list = attr_to_args_.find(attr.name());
    if (arg_list != attr_to_args_.end()) {
      if (attr.type() == "type") {
        std::vector<string> output_sizes;
        const string flattened =
            FlattenInputs(&arg_list->second, &output_sizes);
        string conversion = strings::StrCat("_execute.args_to_matching_eager(",
                                            flattened, ", _ctx");
        if (attr.has_default_value()) {
          strings::StrAppend(
              &conversion, ", ",
              python_op_gen_internal::AttrValueToPython(
                  attr.type(), api_def_attr.default_value(), "_dtypes."));
        }
        strings::StrAppend(&conversion, ")");
        const string var_name = AttrVarName(attr.name(), &attr_expressions_);
        if (output_sizes.size() == 1) {
          // Avoid creating a temporary variable in the case where
          // we can easily assign to the right value directly.
          const string inputs_var =
              param_names_[arg_list->second.front()].GetRenameTo();
          if (output_sizes.front().empty()) {
            strings::StrAppend(&result_, indentation, var_name, ", (",
                               inputs_var, ",) = ", conversion, "\n");
          } else {
            strings::StrAppend(&result_, indentation, var_name, ", ",
                               inputs_var, " = ", conversion, "\n");
          }
        } else {
          const string inputs_var = strings::StrCat("_inputs_", attr.name());
          strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
                             " = ", conversion, "\n");
          // Convert from a flat list of eager tensors back to the
          // parameter variables.
          Unflatten(indentation, output_sizes, inputs_var, &result_);
          std::vector<string> p;
          for (int j : arg_list->second) {
            p.emplace_back(param_names_[j].GetRenameTo());
          }
          strings::StrAppend(&result_, indentation, VectorToTuple(p), " = ",
                             inputs_var, "\n");
        }
      } else if (attr.type() == "list(type)") {
        // NOTE: We ignore default values for these attrs, since it is
        // unclear how you would use it, and the one use case is
        // parse_single_sequence_example which only needs it for
        // backwards compatibility.
        const string var_name = AttrVarName(attr.name(), &attr_expressions_);
        string inputs_var;
        string conversion;
        if (arg_list->second.size() > 1) {
          // If you have more than one list(tensor) argument, their types
          // have to match.
          std::vector<string> lists;
          for (auto iter = arg_list->second.begin();
               iter != arg_list->second.end(); ++iter) {
            lists.push_back(param_names_[*iter].GetRenameTo());
          }
          inputs_var = VectorToTuple(lists);
          conversion = "_execute.args_to_mixed_eager_tensors";
        } else {
          // For one list(tensor) argument, we just convert every
          // element of the list to an eager tensor.
          inputs_var = param_names_[arg_list->second.front()].GetRenameTo();
          conversion = "_execute.convert_to_mixed_eager_tensors";
        }
        strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
                           " = ", conversion, "(", inputs_var, ", _ctx)\n");
      }
    }
  }
}
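// For a hypothetical type attr "T" inferred from a single solo input "x",
// this emits (at indentation "  "):
//   _attr_T, (x,) = _execute.args_to_matching_eager([x], _ctx)
// When several inputs share the attr, the flat result is re-split with
// Unflatten() and reassigned back to the original parameter names.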

void GenEagerPythonOp::AddEagerInputCasts(const string& indentation) {
  // Cast remaining args to eager tensors
  for (int i = 0; i < op_def_.input_arg_size(); ++i) {
    const auto& arg(op_def_.input_arg(i));
    if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue;
    const string& param = param_names_[i].GetRenameTo();
    const string fn = arg.number_attr().empty() ? "" : "n_";
    const string dtype =
        python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
    strings::StrAppend(&result_, indentation, param, " = _ops.convert_", fn,
                       "to_tensor(", param, ", ", dtype, ")\n");
  }
}
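// For a fixed-dtype input this emits a line such as
//   x = _ops.convert_to_tensor(x, _dtypes.int32)
// and for a homogeneous list input (one with a number_attr)
//   xs = _ops.convert_n_to_tensor(xs, _dtypes.int32)
// (parameter names and dtype are illustrative only).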

void GenEagerPythonOp::AddEagerAttrs(const string& indentation) {
  // Compute eager attrs
  if (op_def_.attr_size() > 0) {
    string attr_values;
    for (int i = 0; i < op_def_.attr_size(); ++i) {
      if (i > 0) strings::StrAppend(&attr_values, ", ");
      const auto& attr_name(op_def_.attr(i).name());
      strings::StrAppend(&attr_values, "\"", attr_name, "\", ",
                         attr_expressions_[attr_name]);
    }
    strings::StrAppend(&attr_values, ")");
    strings::StrAppend(
        &result_,
        WordWrap(indentation, strings::StrCat("_attrs = (", attr_values),
                 kRightMargin),
        "\n");
  } else {
    strings::StrAppend(&result_, indentation, "_attrs = None\n");
  }
}

void GenEagerPythonOp::AddEagerExecute(const string& indentation,
                                       const string& num_outputs_expr) {
  const string return_prefix =
      strings::StrCat(indentation, "_result = _execute.execute(");
  const string return_args = strings::StrCat(
      "b\"", op_def_.name(), "\", ", num_outputs_expr,
      ", inputs=_inputs_flat, attrs=_attrs, ctx=_ctx, name=name)");
  strings::StrAppend(&result_,
                     // Wrap the arguments, and indent to the (.
                     WordWrap(return_prefix, return_args, kRightMargin), "\n");
}
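// For a hypothetical single-output op "MyOp" this emits, before WordWrap
// folds it to kRightMargin:
//   _result = _execute.execute(b"MyOp", 1, inputs=_inputs_flat, attrs=_attrs, ctx=_ctx, name=name)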

void GenEagerPythonOp::AddDispatch(const string& prefix) {
  if (api_def_.visibility() != ApiDef::VISIBLE) return;

  strings::StrAppend(&result_, prefix, "except (TypeError, ValueError):\n");
  strings::StrAppend(&result_, prefix, "  result = _dispatch.dispatch(\n");
  AddBodyNoReturn(strings::StrCat(prefix, "        ", function_name_, ", "));
  strings::StrAppend(&result_, prefix,
                     "  if result is not "
                     "_dispatch.OpDispatcher.NOT_SUPPORTED:\n");
  strings::StrAppend(&result_, prefix, "    return result\n");
  strings::StrAppend(&result_, prefix, "  raise\n");
}

void GenEagerPythonOp::AddRawOpExport(const string& parameters) {
  string arguments;
  for (const auto& param_names : param_names_) {
    const string renamed = param_names.GetRenameTo();
    strings::StrAppend(&arguments, arguments.empty() ? "" : ", ", renamed, "=",
                       renamed);
  }
  strings::StrAppend(&arguments, arguments.empty() ? "" : ", ", "name=name");

  const string raw_function_name =
      python_op_gen_internal::AvoidPythonReserved(op_def_.name());

  strings::StrAppend(&result_, "def ", raw_function_name, "(", parameters,
                     "):\n");
  strings::StrAppend(&result_, "  return ", function_name_, "(", arguments,
                     ")\n");

  // Copy the __doc__ from the original op and apply the decorators.
  strings::StrAppend(&result_, raw_function_name, ".__doc__", " = ",
                     function_name_, ".__doc__\n");
  strings::StrAppend(&result_, raw_function_name, " = ",
                     "_doc_controls.do_not_generate_docs(_kwarg_only(",
                     raw_function_name, "))\n");

  // Export.
  strings::StrAppend(&result_, "tf_export(\"raw_ops.", raw_function_name,
                     "\")(", raw_function_name, ")\n");
}
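// For a hypothetical op "MyOp" wrapped by my_op(x, name=None), the raw_ops
// export emitted here reads roughly:
//   def MyOp(x, name=None):
//     return my_op(x=x, name=name)
//   MyOp.__doc__ = my_op.__doc__
//   MyOp = _doc_controls.do_not_generate_docs(_kwarg_only(MyOp))
//   tf_export("raw_ops.MyOp")(MyOp)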

string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
                    const std::vector<string>& hidden_ops, bool require_shapes,
                    const string& source_file_name = "") {
  string result;
  // Header
  // TODO(josh11b): Mention the library for which wrappers are being generated.
  strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops.

This file is MACHINE GENERATED! Do not edit.
)");

  // Mention the original source file so someone tracing back through
  // generated Python code will know where to look next.
  if (!source_file_name.empty()) {
    strings::StrAppend(&result, "Original C++ source file: ");
    strings::StrAppend(&result, source_file_name);
    strings::StrAppend(&result, "\n");
  }

  strings::StrAppend(&result, R"("""

import collections as _collections
import six as _six

from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.eager import execute as _execute
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import errors as _errors
from tensorflow.python.framework import tensor_shape as _tensor_shape

from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
# Needed to trigger the call to _set_call_cpp_shape_fn.
from tensorflow.python.framework import common_shapes as _common_shapes
from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.deprecation import deprecated_endpoints
from tensorflow.python.util import dispatch as _dispatch
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.util.tf_export import kwarg_only as _kwarg_only
from tensorflow.tools.docs import doc_controls as _doc_controls

)");

  // We'll make a copy of ops that filters out descriptions.
  OpList cleaned_ops;
  auto out = cleaned_ops.mutable_op();
  out->Reserve(ops.op_size());
  for (const auto& op_def : ops.op()) {
    const auto* api_def = api_defs.GetApiDef(op_def.name());

    if (api_def->visibility() == ApiDef::SKIP) {
      continue;
    }
    // An op is hidden if either its ApiDef visibility is HIDDEN
    // or it is in the hidden_ops list.
    bool is_hidden = api_def->visibility() == ApiDef::HIDDEN;
    bool hidden_by_api_def = is_hidden;
    if (!is_hidden) {
      for (const string& hidden : hidden_ops) {
        if (op_def.name() == hidden) {
          is_hidden = true;
          break;
        }
      }
    }

    string function_name;
    python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(),
                                                    &function_name);
    bool is_reserved = python_op_gen_internal::IsPythonReserved(function_name);

    // Prefix an op with underscore if the op is listed in hidden_ops or
    // name is reserved or it is of the exceptions in IsOpWithUnderscorePrefix.
    // Do not add underscores to ops set to HIDDEN in ApiDef otherwise.
    // TODO(annarev): don't prefix with underscores even if op is in hidden_ops.
    if (is_hidden) {
      if (!hidden_by_api_def || is_reserved ||
          python_op_gen_internal::IsOpWithUnderscorePrefix(function_name)) {
        function_name = strings::StrCat("_", function_name);
      }
    } else if (is_reserved) {
      // When users create custom python wrappers, they may link in the
      // default op registry by accident, and because they can't
      // enumerate all 'hidden' symbols, this guard is to prevent
      // instantiating a python reserved word in their wrapper.
      continue;
    }

    strings::StrAppend(&result,
                       GetEagerPythonOp(op_def, *api_def, function_name));

    if (!require_shapes) {
      strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(),
                         "\")(None)\n\n");
    }

    auto added = out->Add();
    *added = op_def;
    RemoveNonDeprecationDescriptionsFromOpDef(added);
  }

  result.append(R"(def _InitOpDefLibrary(op_list_proto_bytes):
  op_list = _op_def_pb2.OpList()
  op_list.ParseFromString(op_list_proto_bytes)
  _op_def_registry.register_op_list(op_list)
  op_def_lib = _op_def_library.OpDefLibrary()
  op_def_lib.add_op_list(op_list)
  return op_def_lib
)");

  result.append("# ");
  auto ops_text = ProtoDebugString(cleaned_ops);
  str_util::StripTrailingWhitespace(&ops_text);
  result.append(str_util::StringReplace(ops_text, "\n", "\n# ", true));
  result.append("\n");
  strings::Appendf(&result, "_op_def_lib = _InitOpDefLibrary(b\"%s\")\n",
                   str_util::CEscape(cleaned_ops.SerializeAsString()).c_str());
  return result;
}

}  // namespace

void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
                    const std::vector<string>& hidden_ops, bool require_shapes,
                    const string& source_file_name) {
  printf("%s", GetPythonOps(ops, api_defs, hidden_ops, require_shapes,
                            source_file_name)
                   .c_str());
}

string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
  string op_list_str(op_list_buf, op_list_len);
  OpList ops;
  ops.ParseFromString(op_list_str);

  ApiDefMap api_def_map(ops);
  return GetPythonOps(ops, api_def_map, {}, false);
}

}  // namespace tensorflow