1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/python/framework/python_op_gen.h"

#include <stdio.h>

#include <algorithm>
#include <sstream>
#include <unordered_map>

#include "absl/strings/escaping.h"
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/framework/python_op_gen_internal.h"
40
41 namespace tensorflow {
42 namespace {
43
// Column at which generated Python lines are wrapped (passed to WordWrap when
// emitting long statements such as the `_attrs = (...)` tuple).
constexpr int kRightMargin = 78;

// Suffix naming the generated eager fallback function. Its use site is outside
// this view; presumably appended to the op's function name — TODO confirm.
constexpr char kEagerFallbackSuffix[] = "_eager_fallback";
47
// Maps the Python dtype expression produced by
// python_op_gen_internal::DataTypeToPython(dtype, "_dtypes.") (for example
// "_dtypes.float32") to the name of the matching `_dtypes` DType class used in
// the generated TypeVar annotations (for example "_dtypes.Float32").
// Both "_dtypes.float16" and "_dtypes.half" spellings are present.
const std::unordered_map<std::string, std::string> dtype_type{
    {"_dtypes.float16", "_dtypes.Float16"},
    {"_dtypes.half", "_dtypes.Half"},
    {"_dtypes.float32", "_dtypes.Float32"},
    {"_dtypes.float64", "_dtypes.Float64"},
    {"_dtypes.bfloat16", "_dtypes.BFloat16"},
    {"_dtypes.complex64", "_dtypes.Complex64"},
    {"_dtypes.complex128", "_dtypes.Complex128"},
    {"_dtypes.int8", "_dtypes.Int8"},
    {"_dtypes.uint8", "_dtypes.UInt8"},
    {"_dtypes.uint16", "_dtypes.UInt16"},
    {"_dtypes.uint32", "_dtypes.UInt32"},
    {"_dtypes.uint64", "_dtypes.UInt64"},
    {"_dtypes.int16", "_dtypes.Int16"},
    {"_dtypes.int32", "_dtypes.Int32"},
    {"_dtypes.int64", "_dtypes.Int64"},
    {"_dtypes.bool", "_dtypes.Bool"},
    {"_dtypes.string", "_dtypes.String"},
    {"_dtypes.qint8", "_dtypes.QInt8"},
    {"_dtypes.quint8", "_dtypes.QUInt8"},
    {"_dtypes.qint16", "_dtypes.QInt16"},
    {"_dtypes.quint16", "_dtypes.QUInt16"},
    {"_dtypes.qint32", "_dtypes.QInt32"},
    {"_dtypes.resource", "_dtypes.Resource"},
    {"_dtypes.variant", "_dtypes.Variant"}};
74
// Returns the generated-Python variable name "_attr_<attr_name>" used to hold
// the value of an inferred attr. When `attr_expressions` is non-null, the name
// is also stored as the expression for `attr_name` (replacing any previous
// entry).
string AttrVarName(const string& attr_name,
                   std::unordered_map<string, string>* attr_expressions) {
  string var_name = strings::StrCat("_attr_", attr_name);
  if (attr_expressions == nullptr) return var_name;
  (*attr_expressions)[attr_name] = var_name;
  return var_name;
}
81
AddInferredAttr(const string & indentation,const string & attr_name,const string & value_expression,string * result,std::unordered_map<string,string> * attr_expressions)82 void AddInferredAttr(const string& indentation, const string& attr_name,
83 const string& value_expression, string* result,
84 std::unordered_map<string, string>* attr_expressions) {
85 strings::StrAppend(result, indentation,
86 AttrVarName(attr_name, attr_expressions), " = ",
87 value_expression, "\n");
88 }
89
VectorToTuple(const std::vector<string> & l)90 string VectorToTuple(const std::vector<string>& l) {
91 if (l.size() == 1) return strings::StrCat("(", l.front(), ",)");
92 string ret = "(";
93 for (int i = 0, end = l.size(); i < end; ++i) {
94 if (i > 0) {
95 strings::StrAppend(&ret, ", ");
96 }
97 strings::StrAppend(&ret, l[i]);
98 }
99 strings::StrAppend(&ret, ")");
100 return ret;
101 }
102
Unflatten(const string & prefix,const std::vector<string> & output_sizes,const string & var,string * result)103 void Unflatten(const string& prefix, const std::vector<string>& output_sizes,
104 const string& var, string* result) {
105 for (int i = 0, end = output_sizes.size(); i < end; ++i) {
106 if (!output_sizes[i].empty()) {
107 strings::StrAppend(result, prefix, var, " = ");
108 if (i > 0) strings::StrAppend(result, var, "[:", i, "] + ");
109 if (i + 1 < end) {
110 // Special case i == 0 to avoid "0 +" in the generated code.
111 if (i == 0) {
112 strings::StrAppend(result, "[", var, "[:", output_sizes[i], "]] + ",
113 var, "[", output_sizes[i], ":]");
114 } else {
115 strings::StrAppend(result, "[", var, "[", i, ":", i, " + ",
116 output_sizes[i], "]] + ", var, "[", i, " + ",
117 output_sizes[i], ":]");
118 }
119 } else {
120 strings::StrAppend(result, "[", var, "[", i, ":]]");
121 }
122 strings::StrAppend(result, "\n");
123 }
124 }
125 }
126
TensorPBString(const TensorProto & pb)127 string TensorPBString(const TensorProto& pb) {
128 // Note: This gets used in the argument list, and so must survive naive
129 // word wrapping.
130 return strings::StrCat("\"\"\"", pb.ShortDebugString(), "\"\"\"");
131 }
132
133 class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
134 public:
GenEagerPythonOp(const OpDef & op_def,const ApiDef & api_def,const string & function_name,bool add_type_annotations)135 GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
136 const string& function_name, bool add_type_annotations)
137 : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name,
138 add_type_annotations) {
139 op_name_ = function_name_;
140 absl::ConsumePrefix(&op_name_, "_");
141 }
~GenEagerPythonOp()142 ~GenEagerPythonOp() override {}
143
144 string Code() override;
145
146 protected:
147 void HandleGraphMode(const string& function_setup,
148 const std::vector<string>& output_sizes);
149
150 string GetEagerNotAllowedError();
151 void ExpectListArg(const string& indentation, const string& arg_name,
152 string* output);
153 bool GetEagerFunctionSetup(const string& indentation, string* function_setup);
154 void GetOutputSizesAndNumOutputsExpr(std::vector<string>* output_sizes,
155 string* num_outputs_expr);
156
157 void AddEagerFunctionTeardown(const string& indentation,
158 const std::vector<string>& output_sizes,
159 bool execute_record_gradient);
160
161 bool AddEagerFastPathAndGraphCode(
162 const string& parameters, const std::vector<string>& output_sizes,
163 const string& eager_not_allowed_error,
164 const std::unordered_map<string, string>& type_annotations);
165 bool AddEagerFallbackCode(
166 const string& parameters, const std::vector<string>& output_sizes,
167 const string& num_outputs_expr, const string& eager_not_allowed_error,
168 const std::unordered_map<string, string>& type_annotations);
169 void AddEagerFastPathExecute();
170
171 void AddEagerInferredAttrs(const string& indentation);
172 void AddEagerInputCasts(const string& indentation);
173 void AddEagerAttrs(const string& indentation);
174 void AddEagerExecute(const string& indentation,
175 const string& num_outputs_expr);
176 void AddDispatch(const string& prefix);
177
178 void AddRawOpExport(const string& parameters);
179
180 std::unordered_map<string, string> GetTypeAnnotations();
181
182 void GenerateTypeVars(
183 const std::unordered_map<string, string>& type_annotations);
184
185 void AddReturnTypeAnnotation(
186 const std::unordered_map<string, string>& type_annotations);
187
AddAttrForArg(const string & attr,int arg_index)188 void AddAttrForArg(const string& attr, int arg_index) {
189 gtl::InsertIfNotPresent(&inferred_attrs_, attr,
190 op_def_.input_arg(arg_index).name());
191 auto iter = attr_to_args_.find(attr);
192 if (iter == attr_to_args_.end()) {
193 attr_to_args_.insert(AttrToArgMap::value_type(attr, {arg_index}));
194 } else {
195 iter->second.push_back(arg_index);
196 }
197 }
198
199 // Returns a string expression representing a flattened list of all
200 // the inputs given by `*input_indices` (or all inputs if
201 // `input_indices` is nullptr). `*output_sizes` can be used to unflatten.
202 string FlattenInputs(const std::vector<int>* input_indices,
203 std::vector<string>* output_sizes) const;
204
205 StringPiece op_name_;
206 typedef std::unordered_map<string, std::vector<int>> AttrToArgMap;
207 AttrToArgMap attr_to_args_;
208 std::unordered_map<string, string> attr_expressions_;
209 // This has all the input args followed by those attrs that don't have
210 // defaults.
211 std::vector<python_op_gen_internal::ParamNames> params_no_default_;
212 // The parameters with defaults (these have to be listed after those without).
213 // No input args are included, just attrs.
214 std::vector<std::pair<python_op_gen_internal::ParamNames, string>>
215 params_with_default_;
216 };
217
// Returns the generated Python wrapper source for one op; a thin entry point
// over GenEagerPythonOp::Code().
string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
                        const string& function_name,
                        bool add_type_annotations) {
  GenEagerPythonOp generator(op_def, api_def, function_name,
                             add_type_annotations);
  return generator.Code();
}
224
FlattenInputs(const std::vector<int> * input_indices,std::vector<string> * output_sizes) const225 string GenEagerPythonOp::FlattenInputs(
226 const std::vector<int>* input_indices,
227 std::vector<string>* output_sizes) const {
228 string inputs;
229 enum { STARTING, WAS_LIST_INPUT, WAS_SOLO_INPUT } inputs_state = STARTING;
230 const int n = input_indices != nullptr ? input_indices->size()
231 : op_def_.input_arg_size();
232 for (int j = 0; j < n; ++j) {
233 const int i = input_indices ? (*input_indices)[j] : j;
234 const auto& arg(op_def_.input_arg(i));
235 const bool is_list =
236 !arg.type_list_attr().empty() || !arg.number_attr().empty();
237 if (is_list) {
238 if (inputs_state == WAS_SOLO_INPUT) {
239 strings::StrAppend(&inputs, "] + ");
240 } else if (inputs_state == WAS_LIST_INPUT) {
241 strings::StrAppend(&inputs, " + ");
242 }
243 strings::StrAppend(&inputs, "list(", param_names_[i].GetRenameTo(), ")");
244 inputs_state = WAS_LIST_INPUT;
245 if (output_sizes != nullptr) {
246 if (!arg.number_attr().empty()) {
247 output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr));
248 } else {
249 output_sizes->emplace_back(
250 strings::StrCat("len(", param_names_[i].GetRenameTo(), ")"));
251 }
252 }
253 } else {
254 if (inputs_state == WAS_SOLO_INPUT) {
255 strings::StrAppend(&inputs, ", ");
256 } else if (inputs_state == WAS_LIST_INPUT) {
257 strings::StrAppend(&inputs, " + [");
258 } else {
259 strings::StrAppend(&inputs, "[");
260 }
261 strings::StrAppend(&inputs, param_names_[i].GetRenameTo());
262 inputs_state = WAS_SOLO_INPUT;
263 if (output_sizes != nullptr) output_sizes->emplace_back();
264 }
265 }
266 if (inputs_state == STARTING) return "[]";
267 if (inputs_state == WAS_SOLO_INPUT) {
268 strings::StrAppend(&inputs, "]");
269 }
270 return inputs;
271 }
272
Code()273 string GenEagerPythonOp::Code() {
274 if (api_def_.visibility() == ApiDef::SKIP) {
275 return "";
276 }
277
278 for (int i = 0; i < api_def_.arg_order_size(); ++i) {
279 const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
280 const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
281 params_no_default_.emplace_back(api_def_arg.name(),
282 api_def_arg.rename_to());
283 if (!arg.type_attr().empty()) {
284 AddAttrForArg(arg.type_attr(), i);
285 } else if (!arg.type_list_attr().empty()) {
286 AddAttrForArg(arg.type_list_attr(), i);
287 }
288 if (!arg.number_attr().empty()) {
289 AddAttrForArg(arg.number_attr(), i);
290 }
291 }
292 for (int i = 0; i < op_def_.attr_size(); ++i) {
293 const auto& attr(op_def_.attr(i));
294 const auto& api_def_attr(api_def_.attr(i));
295 // Do not add inferred attrs to the Python function signature.
296 if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
297 if (api_def_attr.has_default_value()) {
298 if (attr.type() == "tensor") {
299 params_with_default_.emplace_back(
300 python_op_gen_internal::ParamNames(api_def_attr.name(),
301 api_def_attr.rename_to()),
302 strings::StrCat(
303 "_execute.make_tensor(",
304 TensorPBString(api_def_attr.default_value().tensor()), ", \"",
305 api_def_attr.rename_to(), "\")"));
306 } else if (attr.type() == "list(tensor)") {
307 std::vector<string> pbtxt;
308 for (const auto& pb : api_def_attr.default_value().list().tensor()) {
309 pbtxt.emplace_back(TensorPBString(pb));
310 }
311 params_with_default_.emplace_back(
312 python_op_gen_internal::ParamNames(api_def_attr.name(),
313 api_def_attr.rename_to()),
314 strings::StrCat("[_execute.make_tensor(_pb, \"",
315 api_def_attr.rename_to(), "\") for _pb in ",
316 VectorToTuple(pbtxt), "]"));
317 } else {
318 params_with_default_.emplace_back(
319 python_op_gen_internal::ParamNames(api_def_attr.name(),
320 api_def_attr.rename_to()),
321 python_op_gen_internal::AttrValueToPython(
322 attr.type(), api_def_attr.default_value(), "_dtypes."));
323 }
324 } else {
325 params_no_default_.emplace_back(api_def_attr.name(),
326 api_def_attr.rename_to());
327 }
328 }
329 }
330
331 // Save the list of attr parameters (attrs that won't be inferred),
332 // those with defaults go at the end.
333 // Get the attrs in the order we want by taking the attrs without defaults
334 // from the end of params_no_default_, and adding params_no_default_.
335 attrs_.reserve(params_no_default_.size() - op_def_.input_arg_size() +
336 params_with_default_.size());
337 for (int i = op_def_.input_arg_size(), end = params_no_default_.size();
338 i < end; ++i) {
339 attrs_.push_back(params_no_default_[i].GetName());
340 }
341 for (const auto& p : params_with_default_) {
342 attrs_.push_back(p.first.GetName());
343 }
344
345 // TODO(slebedev): call AvoidPythonReserved on each param?
346 param_names_.reserve(params_no_default_.size() + params_with_default_.size());
347 param_names_.insert(param_names_.begin(), params_no_default_.begin(),
348 params_no_default_.end());
349 for (const auto& param_and_default : params_with_default_) {
350 param_names_.push_back(param_and_default.first);
351 }
352
353 std::unordered_map<string, string> type_annotations;
354 // Only populate map for allowlisted ops
355 if (add_type_annotations_) {
356 type_annotations = GetTypeAnnotations();
357 }
358
359 string parameters;
360 // Param can be an input or an attr
361 for (const auto& param : params_no_default_) {
362 if (!parameters.empty()) strings::StrAppend(¶meters, ", ");
363 strings::StrAppend(¶meters, param.GetRenameTo());
364
365 if (type_annotations.find(param.GetName()) != type_annotations.end()) {
366 strings::StrAppend(¶meters, ": ",
367 type_annotations.at(param.GetName()));
368 }
369 }
370
371 string parameters_with_defaults = parameters;
372 for (const auto& param_and_default : params_with_default_) {
373 if (!parameters.empty()) strings::StrAppend(¶meters, ", ");
374 if (!parameters_with_defaults.empty())
375 strings::StrAppend(¶meters_with_defaults, ", ");
376
377 strings::StrAppend(¶meters, param_and_default.first.GetRenameTo());
378 strings::StrAppend(¶meters_with_defaults,
379 param_and_default.first.GetRenameTo());
380 if (type_annotations.find(param_and_default.first.GetName()) !=
381 type_annotations.end()) {
382 const string param_type =
383 type_annotations.at(param_and_default.first.GetName());
384 // Append to parameters and parameters_with_defaults because multiple
385 // functions are generated by AddEagerFastPathAndGraphCode() and
386 // AddEagerFallbackCode()
387 strings::StrAppend(¶meters, ": ", param_type);
388 strings::StrAppend(¶meters_with_defaults, ":", param_type);
389 }
390
391 strings::StrAppend(¶meters_with_defaults, "=",
392 param_and_default.second);
393 }
394
395 strings::StrAppend(¶meters, parameters.empty() ? "" : ", ", "name");
396 strings::StrAppend(¶meters_with_defaults,
397 parameters_with_defaults.empty() ? "" : ", ", "name=None");
398
399 // Add attr_expressions_ for attrs that are params.
400 for (int i = 0, end = attrs_.size(); i < end; ++i) {
401 const string& attr_name = attrs_[i];
402 const string& attr_api_name =
403 param_names_[i + op_def_.input_arg_size()].GetRenameTo();
404 attr_expressions_[attr_name] = attr_api_name;
405 }
406 // Add attr_expressions_ for attrs that are inferred.
407 for (int i = 0; i < op_def_.attr_size(); ++i) {
408 const auto& attr(op_def_.attr(i));
409 if (attr.type() == "int") {
410 auto arg_list = attr_to_args_.find(attr.name());
411 if (arg_list != attr_to_args_.end()) {
412 AttrVarName(attr.name(), &attr_expressions_);
413 }
414 }
415 }
416
417 string num_outputs_expr;
418 std::vector<string> output_sizes(num_outs_);
419 GetOutputSizesAndNumOutputsExpr(&output_sizes, &num_outputs_expr);
420
421 string eager_not_allowed_error = GetEagerNotAllowedError();
422
423 if (!AddEagerFastPathAndGraphCode(parameters_with_defaults, output_sizes,
424 eager_not_allowed_error,
425 type_annotations)) {
426 return result_;
427 }
428
429 if (!AddEagerFallbackCode(parameters, output_sizes, num_outputs_expr,
430 eager_not_allowed_error, type_annotations)) {
431 return result_;
432 }
433
434 return prelude_ + result_;
435 }
436
GetTypeAnnotations()437 std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotations() {
438 std::unordered_map<string, string> type_annotations;
439 // Map attrs to TypeVars
440 for (const auto& attr : op_def_.attr()) {
441 if (attr.type() == "type") {
442 const string type_var_name = "TV_" + op_def_.name() + "_" + attr.name();
443 type_annotations[attr.name()] = type_var_name;
444 } else if (attr.type() == "bool" || attr.type() == "float" ||
445 attr.type() == "int" || attr.type() == "bytes") {
446 type_annotations[attr.name()] = attr.type();
447 } else if (attr.type() == "string") {
448 type_annotations[attr.name()] = "str";
449 }
450 }
451
452 // Map input Tensors to their types
453 for (const auto& arg : op_def_.input_arg()) {
454 // TODO(rahulkamat): Add type annotations to args that accept a sequence of
455 // Tensors
456 if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) continue;
457 type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
458 }
459
460 // TODO(rahulkamat): Add type annotations to handle return types of a sequence
461 // of Tensors. Map output Tensor to its type
462 if (op_def_.output_arg_size() == 1) {
463 const auto& arg = op_def_.output_arg(0);
464 if (arg.number_attr().empty() && arg.type_list_attr().empty())
465 type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
466 }
467
468 return type_annotations;
469 }
470
471 // Generate TypeVars using attrs
GenerateTypeVars(const std::unordered_map<string,string> & type_annotations)472 void GenEagerPythonOp::GenerateTypeVars(
473 const std::unordered_map<string, string>& type_annotations) {
474 bool added_typevar = false;
475 for (const auto& attr : op_def_.attr()) {
476 if (attr.type() == "type") {
477 std::vector<string> allowed_types;
478 for (int t : attr.allowed_values().list().type()) {
479 DataType dtype = static_cast<DataType>(t);
480 const string py_dtype =
481 python_op_gen_internal::DataTypeToPython(dtype, "_dtypes.");
482 allowed_types.emplace_back(dtype_type.at(py_dtype));
483 }
484
485 // When a Tensor does not have any dtypes specified, all dtypes are
486 // allowed
487 if (allowed_types.empty()) {
488 for (std::pair<string, string> map_dtype : dtype_type) {
489 allowed_types.emplace_back(map_dtype.second);
490 }
491 }
492
493 std::sort(allowed_types.begin(), allowed_types.end());
494
495 string typevar_dtypes;
496 for (std::vector<string>::iterator it = allowed_types.begin();
497 it != allowed_types.end(); ++it) {
498 if (!typevar_dtypes.empty()) strings::StrAppend(&typevar_dtypes, ", ");
499 strings::StrAppend(&typevar_dtypes, *it);
500 }
501
502 const string type_var_name = type_annotations.at(attr.name());
503 strings::StrAppend(&result_, type_var_name, " = TypeVar(\"",
504 type_var_name, "\", ", typevar_dtypes, ")\n");
505 added_typevar = true;
506 }
507 }
508
509 if (added_typevar) strings::StrAppend(&result_, "\n");
510 }
511
AddReturnTypeAnnotation(const std::unordered_map<string,string> & type_annotations)512 void GenEagerPythonOp::AddReturnTypeAnnotation(
513 const std::unordered_map<string, string>& type_annotations) {
514 if (op_def_.output_arg_size() == 1) {
515 const auto& arg = op_def_.output_arg(0);
516 if (arg.number_attr().empty() && arg.type_list_attr().empty()) {
517 const string return_type = type_annotations.at(arg.name());
518 // TODO(rahulkamat): Modify AddDefLine() to add return type annotation to
519 // avoid erasing ":\n" from the end of the def line
520 result_.erase(result_.length() - 2);
521 strings::StrAppend(&result_, " -> ", return_type, ":\n");
522 }
523 }
524 }
525
// Appends the graph-mode part of the generated Python function to result_:
//   * `function_setup` (argument validation/normalization code);
//   * a call to _op_def_library._apply_op_helper that adds the op to the
//     graph, preceded by a `try:` line for ApiDef::VISIBLE ops, followed by
//     AddDispatch(" ") (its emitted code is not visible here);
//   * if the op has outputs: `_result = _outputs[:]`, an early `return _op`
//     for a stateful op whose single list output is empty, gradient recording
//     (attrs gathered via typed _op._get_attr_* / _op.get_attr calls) guarded
//     by _execute.must_record_gradient(), and finally the result returned as a
//     list, a single Tensor, or the `_<Op>Output` named tuple;
//   * otherwise `return _op`.
// `output_sizes[i]` is non-empty when output i is a list (see Unflatten).
void GenEagerPythonOp::HandleGraphMode(
    const string& function_setup, const std::vector<string>& output_sizes) {
  strings::StrAppend(&result_, " # Add nodes to the TensorFlow graph.\n");
  strings::StrAppend(&result_, function_setup);
  // Only VISIBLE ops wrap graph construction in `try:`; the literal's trailing
  // spaces indent the next emitted line under it.
  if (api_def_.visibility() == ApiDef::VISIBLE) {
    strings::StrAppend(&result_, " try:\n ");
  }
  strings::StrAppend(
      &result_, " _, _, _op, _outputs = _op_def_library._apply_op_helper(\n");
  AddBodyNoReturn(strings::StrCat(" \"", op_def_.name(), "\", "));
  AddDispatch(" ");

  if (num_outs_ > 0) {
    strings::StrAppend(&result_, " _result = _outputs[:]\n");
    // Special case handling for stateful op with single list output
    // that might be empty.
    if (num_outs_ == 1 && op_def_.is_stateful() &&
        (!op_def_.output_arg(0).number_attr().empty() ||
         !op_def_.output_arg(0).type_list_attr().empty())) {
      // TODO(josh11b): Can skip this if the number_attr/type_list_attr has
      // a constraint indicating that this can never be empty.
      strings::StrAppend(&result_,
                         " if not _result:\n"
                         " return _op\n");
    }

    // Compute graph-mode attrs when we need to record a gradient.
    strings::StrAppend(&result_, " if _execute.must_record_gradient():\n");
    if (op_def_.attr_size() > 0) {
      // Build a flat ("name", value, "name", value, ...) tuple body, reading
      // each attr back from the created op with the accessor for its type.
      string attr_values;
      for (int i = 0; i < op_def_.attr_size(); ++i) {
        if (i > 0) strings::StrAppend(&attr_values, ", ");
        const auto& attr_name(op_def_.attr(i).name());
        if (op_def_.attr(i).type() == "type") {
          strings::StrAppend(&attr_values, "\"", attr_name,
                             "\", _op._get_attr_type(\"", attr_name, "\")");
        } else if (op_def_.attr(i).type() == "bool") {
          strings::StrAppend(&attr_values, "\"", attr_name,
                             "\", _op._get_attr_bool(\"", attr_name, "\")");
        } else if (op_def_.attr(i).type() == "int") {
          strings::StrAppend(&attr_values, "\"", attr_name,
                             "\", _op._get_attr_int(\"", attr_name, "\")");
        } else {
          strings::StrAppend(&attr_values, "\"", attr_name,
                             "\", _op.get_attr(\"", attr_name, "\")");
        }
      }
      strings::StrAppend(&attr_values, ")");
      // Wrapped at kRightMargin so long attr lists stay readable.
      strings::StrAppend(&result_,
                         WordWrap(" _attrs = (", attr_values, kRightMargin),
                         "\n");

    } else {
      strings::StrAppend(&result_, " _attrs = ()\n");
    }

    strings::StrAppend(&result_, " _inputs_flat = _op.inputs\n");
    strings::StrAppend(&result_, " _execute.record_gradient(\n",
                       " \"", op_def_.name(),
                       "\", _inputs_flat, _attrs, _result)\n");

    if (num_outs_ == 1 && !output_sizes[0].empty()) {
      // Single list result.
    } else if (num_outs_ == 1) {
      // Execute returns a single-element list which we need to destructure.
      strings::StrAppend(&result_, " ", "_result, = _result\n");
    } else {
      // Have multiple outputs, so we will need to reformat the return
      // value of execute() to be a list with one entry per op output
      // (that entry will be a list of tensors if that output is of list
      // type).
      // For list outputs, convert the right subrange of _result into a list.
      Unflatten(" ", output_sizes, "_result", &result_);
      // Convert to a named tuple.
      strings::StrAppend(
          &result_, " _result = _",
          python_op_gen_internal::AvoidPythonReserved(op_def_.name()),
          "Output._make(_result)\n");
    }
    strings::StrAppend(&result_, " return _result\n\n");
  } else {
    strings::StrAppend(&result_, " return _op\n");
  }
}
610
GetEagerNotAllowedError()611 string GenEagerPythonOp::GetEagerNotAllowedError() {
612 bool eager_allowed = true;
613 string ref_arg;
614 for (int i = 0; i < op_def_.input_arg_size(); ++i) {
615 const auto& arg = op_def_.input_arg(i);
616 if (arg.is_ref()) {
617 eager_allowed = false;
618 DCHECK_EQ(op_def_.input_arg(i).name(), api_def_.in_arg(i).name());
619 ref_arg = api_def_.in_arg(i).rename_to();
620 }
621 }
622 for (int i = 0; i < op_def_.output_arg_size(); ++i) {
623 const auto& arg = op_def_.output_arg(i);
624 if (arg.is_ref()) {
625 eager_allowed = false;
626 DCHECK_EQ(op_def_.output_arg(i).name(), api_def_.out_arg(i).name());
627 ref_arg = api_def_.out_arg(i).rename_to();
628 }
629 }
630
631 if (eager_allowed) return "";
632
633 return strings::StrCat("raise RuntimeError(\"", op_name_,
634 " op does not support eager execution. ", "Arg '",
635 ref_arg, "' is a ref.\")\n");
636 }
637
ExpectListArg(const string & indentation,const string & arg_name,string * output)638 void GenEagerPythonOp::ExpectListArg(const string& indentation,
639 const string& arg_name, string* output) {
640 strings::StrAppend(output, indentation, "if not isinstance(", arg_name,
641 ", (list, tuple)):\n", indentation, " raise TypeError(\n",
642 indentation, " \"Expected list for '", arg_name,
643 "' argument to \"\n", indentation, " \"'", op_name_,
644 "' Op, not %r.\" % ", arg_name, ")\n");
645 }
646
GetEagerFunctionSetup(const string & indentation,string * function_setup)647 bool GenEagerPythonOp::GetEagerFunctionSetup(const string& indentation,
648 string* function_setup) {
649 // Validate list inputs, infer length attrs.
650 for (int i = 0; i < op_def_.attr_size(); ++i) {
651 const auto& attr(op_def_.attr(i));
652 if (attr.type() == "int") {
653 auto arg_list = attr_to_args_.find(attr.name());
654 if (arg_list != attr_to_args_.end()) {
655 // Inferred int attrs are the lengths of inputs. Validate those
656 // inputs are lists and have the same length.
657 for (auto iter = arg_list->second.begin();
658 iter != arg_list->second.end(); ++iter) {
659 const string& arg_api_name = param_names_[*iter].GetRenameTo();
660 ExpectListArg(indentation, arg_api_name, function_setup);
661 if (iter == arg_list->second.begin()) {
662 AddInferredAttr(indentation, attr.name(),
663 strings::StrCat("len(", arg_api_name, ")"),
664 function_setup, &attr_expressions_);
665 } else {
666 const auto& attr_var = attr_expressions_[attr.name()];
667 strings::StrAppend(
668 function_setup, indentation, "if len(", arg_api_name,
669 ") != ", attr_var, ":\n", indentation, " raise ValueError(\n",
670 indentation, " \"List argument '", arg_api_name, "' to '",
671 op_name_, "' Op with length %d \"\n", indentation,
672 " \"must match length %d of argument '",
673 inferred_attrs_[attr.name()], "'.\" %\n", indentation,
674 " (len(", arg_api_name, "), ", attr_var, "))\n");
675 }
676 }
677 }
678 }
679 }
680
681 for (int i = 0, end = attrs_.size(); i < end; ++i) {
682 const string& attr_name = attrs_[i];
683 const auto& param = param_names_[i + op_def_.input_arg_size()];
684 const auto& attr = *FindAttr(attr_name, op_def_);
685 const string& attr_api_name = param.GetRenameTo();
686 StringPiece attr_type = attr.type();
687 attr_expressions_[attr_name] = attr_api_name;
688 const int default_index = i - (attrs_.size() - params_with_default_.size());
689 if (default_index >= 0) {
690 const string& default_value = params_with_default_[default_index].second;
691 strings::StrAppend(function_setup, indentation, "if ", attr_api_name,
692 " is None:\n");
693 strings::StrAppend(function_setup, indentation, " ", attr_api_name,
694 " = ", default_value, "\n");
695 }
696 if (absl::StartsWith(attr_type, "list(")) {
697 ExpectListArg(indentation, attr_api_name, function_setup);
698 }
699
700 if (attr_type == "string") {
701 strings::StrAppend(function_setup, indentation, attr_api_name,
702 " = _execute.make_str(", attr_api_name, ", \"",
703 attr_api_name, "\")\n");
704 } else if (attr_type == "list(string)") {
705 strings::StrAppend(function_setup, indentation, attr_api_name,
706 " = [_execute.make_str(_s, \"", attr_api_name,
707 "\") for _s in ", attr_api_name, "]\n");
708 } else if (attr_type == "int") {
709 strings::StrAppend(function_setup, indentation, attr_api_name,
710 " = _execute.make_int(", attr_api_name, ", \"",
711 attr_api_name, "\")\n");
712 } else if (attr_type == "list(int)") {
713 strings::StrAppend(function_setup, indentation, attr_api_name,
714 " = [_execute.make_int(_i, \"", attr_api_name,
715 "\") for _i in ", attr_api_name, "]\n");
716 } else if (attr_type == "float") {
717 strings::StrAppend(function_setup, indentation, attr_api_name,
718 " = _execute.make_float(", attr_api_name, ", \"",
719 attr_api_name, "\")\n");
720 } else if (attr_type == "list(float)") {
721 strings::StrAppend(function_setup, indentation, attr_api_name,
722 " = [_execute.make_float(_f, \"", attr_api_name,
723 "\") for _f in ", attr_api_name, "]\n");
724 } else if (attr_type == "bool") {
725 strings::StrAppend(function_setup, indentation, attr_api_name,
726 " = _execute.make_bool(", attr_api_name, ", \"",
727 attr_api_name, "\")\n");
728 } else if (attr_type == "list(bool)") {
729 strings::StrAppend(function_setup, indentation, attr_api_name,
730 " = [_execute.make_bool(_b, \"", attr_api_name,
731 "\") for _b in ", attr_api_name, "]\n");
732 } else if (attr_type == "type") {
733 strings::StrAppend(function_setup, indentation, attr_api_name,
734 " = _execute.make_type(", attr_api_name, ", \"",
735 attr_api_name, "\")\n");
736 } else if (attr_type == "list(type)") {
737 strings::StrAppend(function_setup, indentation, attr_api_name,
738 " = [_execute.make_type(_t, \"", attr_api_name,
739 "\") for _t in ", attr_api_name, "]\n");
740 } else if (attr_type == "shape") {
741 strings::StrAppend(function_setup, indentation, attr_api_name,
742 " = _execute.make_shape(", attr_api_name, ", \"",
743 attr_api_name, "\")\n");
744 } else if (attr_type == "list(shape)") {
745 strings::StrAppend(function_setup, indentation, attr_api_name,
746 " = [_execute.make_shape(_s, \"", attr_api_name,
747 "\") for _s in ", attr_api_name, "]\n");
748 } else if (attr_type == "tensor") {
749 strings::StrAppend(function_setup, indentation, attr_api_name,
750 " = _execute.make_tensor(", attr_api_name, ", \"",
751 attr_api_name, "\")\n");
752 } else if (attr_type == "list(tensor)") {
753 strings::StrAppend(function_setup, indentation, attr_api_name,
754 " = [_execute.make_tensor(_t, \"", attr_api_name,
755 "\") for _t in ", attr_api_name, "]\n");
756 } else if (attr_type != "func" && attr_type != "list(func)") {
757 *function_setup =
758 strings::StrCat("# No definition for ", function_name_,
759 " since we don't support attrs with type\n"
760 "# '",
761 attr_type, "' right now.\n\n");
762 return false;
763 }
764 }
765 return true;
766 }
767
768 // If output i is list output, output_sizes[i] will be set to a
769 // string with the python expression that will evaluate to its
770 // length. output_sizes[i] is empty for non-list outputs.
GetOutputSizesAndNumOutputsExpr(std::vector<string> * output_sizes,string * num_outputs_expr)771 void GenEagerPythonOp::GetOutputSizesAndNumOutputsExpr(
772 std::vector<string>* output_sizes, string* num_outputs_expr) {
773 // Expression representing the number of outputs.
774 int num_fixed_outputs = 0;
775 for (int i = 0; i < num_outs_; ++i) {
776 const auto& arg(op_def_.output_arg(i));
777 if (!arg.number_attr().empty()) {
778 if (!num_outputs_expr->empty()) {
779 strings::StrAppend(num_outputs_expr, " + ");
780 }
781 (*output_sizes)[i] = attr_expressions_[arg.number_attr()];
782 strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
783 } else if (!arg.type_list_attr().empty()) {
784 if (!num_outputs_expr->empty()) {
785 strings::StrAppend(num_outputs_expr, " + ");
786 }
787 // Have to be careful to use an expression that works in both
788 // graph and eager paths here.
789 const auto iter = inferred_attrs_.find(arg.type_list_attr());
790 if (iter == inferred_attrs_.end()) {
791 (*output_sizes)[i] = strings::StrCat(
792 "len(", attr_expressions_[arg.type_list_attr()], ")");
793 } else {
794 (*output_sizes)[i] = strings::StrCat("len(", iter->second, ")");
795 }
796 strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
797 } else {
798 ++num_fixed_outputs;
799 }
800 }
801 if (num_fixed_outputs > 0) {
802 if (!num_outputs_expr->empty()) {
803 strings::StrAppend(num_outputs_expr, " + ");
804 }
805 strings::StrAppend(num_outputs_expr, num_fixed_outputs);
806 } else if (num_outputs_expr->empty()) {
807 *num_outputs_expr = "0";
808 }
809 }
810
AddEagerFunctionTeardown(const string & indentation,const std::vector<string> & output_sizes,bool execute_record_gradient)811 void GenEagerPythonOp::AddEagerFunctionTeardown(
812 const string& indentation, const std::vector<string>& output_sizes,
813 bool execute_record_gradient) {
814 if (num_outs_ > 0) {
815 if (execute_record_gradient) {
816 strings::StrAppend(&result_, indentation,
817 "if _execute.must_record_gradient():\n");
818 strings::StrAppend(&result_, indentation, " _execute.record_gradient(\n",
819 " \"", op_def_.name(),
820 "\", _inputs_flat, _attrs, _result)\n");
821 }
822 if (num_outs_ == 1 && !output_sizes[0].empty()) {
823 // Single list result.
824 } else if (num_outs_ == 1) {
825 // Execute returns a single-element list which we need to destructure.
826 strings::StrAppend(&result_, indentation, "_result, = _result\n");
827 } else {
828 // Have multiple outputs, so we will need to reformat the return
829 // value of execute() to be a list with one entry per op output
830 // (that entry will be a list of tensors if that output is of list
831 // type).
832 // For list outputs, convert the right subrange of _result into a list.
833 Unflatten(indentation, output_sizes, "_result", &result_);
834 // Convert to a named tuple.
835 strings::StrAppend(
836 &result_, indentation, "_result = _",
837 python_op_gen_internal::AvoidPythonReserved(op_def_.name()),
838 "Output._make(_result)\n");
839 }
840 } else {
841 strings::StrAppend(&result_, indentation, "_result = None\n");
842 }
843 strings::StrAppend(&result_, indentation, "return _result\n\n");
844 }
845
AddEagerFastPathAndGraphCode(const string & parameters,const std::vector<string> & output_sizes,const string & eager_not_allowed_error,const std::unordered_map<string,string> & type_annotations)846 bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
847 const string& parameters, const std::vector<string>& output_sizes,
848 const string& eager_not_allowed_error,
849 const std::unordered_map<string, string>& type_annotations) {
850 if (add_type_annotations_) {
851 GenerateTypeVars(type_annotations);
852 }
853 if (api_def_.visibility() == ApiDef::VISIBLE) {
854 strings::StrAppend(&result_, "@_dispatch.add_dispatch_list\n");
855 }
856
857 AddExport();
858 AddDefLine(function_name_, parameters);
859 if (add_type_annotations_) {
860 AddReturnTypeAnnotation(type_annotations);
861 }
862 AddDocStringDescription();
863 AddDocStringArgs();
864 AddDocStringInputs();
865 AddDocStringAttrs();
866 AddDocStringNameArg();
867 AddOutputGlobals(); // Added to prelude_
868 AddDocStringOutputs();
869 strings::StrAppend(&result_, " \"\"\"\n");
870
871 strings::StrAppend(&result_,
872 " _ctx = _context._context or _context.context()\n"
873 " tld = _ctx._thread_local_data\n",
874 " if tld.is_eager:", "\n");
875 if (eager_not_allowed_error.empty()) {
876 AddEagerFastPathExecute();
877 } else {
878 strings::StrAppend(&result_, " ", eager_not_allowed_error);
879 }
880
881 // Handle graph-mode case
882 string function_setup;
883 if (!GetEagerFunctionSetup(" ", &function_setup)) {
884 result_ = function_setup;
885 return false;
886 }
887 HandleGraphMode(function_setup, output_sizes);
888
889 AddRawOpExport(parameters);
890 strings::StrAppend(&result_, "\n\n");
891 return true;
892 }
893
AddEagerFallbackCode(const string & parameters,const std::vector<string> & output_sizes,const string & num_outputs_expr,const string & eager_not_allowed_error,const std::unordered_map<string,string> & type_annotations)894 bool GenEagerPythonOp::AddEagerFallbackCode(
895 const string& parameters, const std::vector<string>& output_sizes,
896 const string& num_outputs_expr, const string& eager_not_allowed_error,
897 const std::unordered_map<string, string>& type_annotations) {
898 AddDefLine(
899 strings::StrCat(function_name_, kEagerFallbackSuffix),
900 strings::StrCat(parameters, parameters.empty() ? "" : ", ", "ctx"));
901 if (add_type_annotations_) {
902 AddReturnTypeAnnotation(type_annotations);
903 }
904 if (!eager_not_allowed_error.empty()) {
905 strings::StrAppend(&result_, " ", eager_not_allowed_error);
906 return true;
907 }
908
909 string function_setup;
910 if (!GetEagerFunctionSetup(" ", &function_setup)) {
911 result_ = function_setup;
912 return false;
913 }
914 strings::StrAppend(&result_, function_setup);
915
916 AddEagerInferredAttrs(" ");
917 AddEagerInputCasts(" ");
918 strings::StrAppend(
919 &result_, " _inputs_flat = ", FlattenInputs(nullptr, nullptr), "\n");
920 AddEagerAttrs(" ");
921 AddEagerExecute(" ", num_outputs_expr);
922
923 AddEagerFunctionTeardown(" ", output_sizes,
924 true /* execute_record_gradient */);
925
926 return true;
927 }
928
// Emits the body of the `if tld.is_eager:` branch of the public wrapper.
//
// The generated Python does the following:
//   * It first tries pywrap_tfe.TFE_Py_FastPathExecute. The arguments are
//     `_ctx`, the op name and `name`, then every input, then each
//     non-inferred attr as a ("<attr name>", <python var>) pair.
//   * For ops with more than one output, the fast-path result is wrapped in
//     the `_<Op>Output` named tuple.
//   * A _NotOkStatusException is re-raised through
//     _ops.raise_from_not_ok_status.
//   * On _FallbackException it calls `<function_name_>_eager_fallback(...)`.
//     Inputs are passed positionally, non-inferred attrs as keyword
//     arguments, then `name=name, ctx=_ctx`.
//   * A _SymbolicException from that call falls through to graph
//     construction.
// AddDispatch then appends the dispatch `except` handler. It emits nothing
// for non-VISIBLE ops.
void GenEagerPythonOp::AddEagerFastPathExecute() {
  string fastpath_execute_params =
      strings::StrCat("_ctx, \"", op_def_.name(), "\", ", "name");
  string fallback_params;

  // Inputs: positional in both the fast-path and the fallback call.
  for (int i = 0; i < api_def_.in_arg_size(); i++) {
    const string param_name = param_names_[i].GetRenameTo();
    strings::StrAppend(&fastpath_execute_params, ", ", param_name);
    if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
    strings::StrAppend(&fallback_params, param_name);
  }

  // Attrs that are not inferred from the inputs: name/value pairs for the
  // fast path, keyword arguments for the fallback.
  for (const auto& attr : api_def_.attr()) {
    if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
      strings::StrAppend(&fastpath_execute_params, ", \"", attr.name(), "\", ",
                         attr.rename_to());

      if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
      strings::StrAppend(&fallback_params, attr.rename_to(), "=",
                         attr.rename_to());
    }
  }

  if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
  strings::StrAppend(&fallback_params, "name=name");

  strings::StrAppend(&result_, " try:\n");
  strings::StrAppend(
      &result_, " ", "_result = pywrap_tfe.TFE_Py_FastPathExecute(\n",
      WordWrap(strings::StrCat(" "),
               strings::StrCat(fastpath_execute_params, ")"), kRightMargin),
      "\n");

  // Multi-output ops return their outputs as the op's named tuple.
  if (op_def_.output_arg_size() > 1) {
    const string output_tuple_name = strings::StrCat(
        "_", python_op_gen_internal::AvoidPythonReserved(op_def_.name()),
        "Output");
    strings::StrAppend(&result_, " ", "_result = ", output_tuple_name,
                       "._make(_result)\n");
  }
  strings::StrAppend(&result_, " ", "return _result\n");

  // Handle fallback.
  // fallback_params always ends with "name=name" here, so this always adds
  // the separator before ctx.
  if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
  strings::StrAppend(&fallback_params, "ctx=_ctx");

  // Any errors thrown from execute need to be unwrapped from
  // _NotOkStatusException.
  strings::StrAppend(&result_, " ",
                     "except _core._NotOkStatusException as e:\n");
  strings::StrAppend(&result_, " ",
                     "_ops.raise_from_not_ok_status(e, name)\n");

  // A _FallbackException falls through to the `_eager_fallback` call below.
  strings::StrAppend(&result_, " ", "except _core._FallbackException:\n");
  strings::StrAppend(&result_, " pass\n");
  strings::StrAppend(&result_, " try:\n");
  strings::StrAppend(
      &result_, " ", "return ", function_name_, kEagerFallbackSuffix,
      "(\n",
      WordWrap(strings::StrCat(" "),
               strings::StrCat(fallback_params, ")"), kRightMargin),
      "\n");
  strings::StrAppend(&result_, " except _core._SymbolicException:\n");
  strings::StrAppend(&result_,
                     " pass # Add nodes to the TensorFlow graph.\n");
  AddDispatch(" ");
}
996
// Emits Python, at `indentation`, for every attr that has input arguments
// associated with it in attr_to_args_. The emitted code computes the attr's
// value and converts those inputs to eager tensors:
//   * "type" attrs: the inputs are flattened and passed to
//     _execute.args_to_matching_eager(<flat inputs>, ctx, [<allowed dtypes>]
//     [, <default dtype>]). The returned dtype goes into the attr variable,
//     and the converted tensors are assigned back to the parameter variables.
//   * "list(type)" attrs: _execute.args_to_mixed_eager_tensors is used when
//     several list arguments share the attr; with a single argument,
//     _execute.convert_to_mixed_eager_tensors is used instead.
// The attr's Python variable name comes from AttrVarName, which is also
// given &attr_expressions_. NOTE(review): presumably it records the
// expression there for later use (e.g. AddEagerAttrs) -- confirm.
void GenEagerPythonOp::AddEagerInferredAttrs(const string& indentation) {
  // Figure out values for inferred attrs, and cast to eager tensors.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    // NOTE(review): indexes api_def_'s attrs with op_def_'s attr index, which
    // assumes both lists are in the same order -- confirm.
    const auto& api_def_attr(api_def_.attr(i));
    auto arg_list = attr_to_args_.find(attr.name());
    if (arg_list != attr_to_args_.end()) {
      if (attr.type() == "type") {
        std::vector<string> output_sizes;
        const string flattened =
            FlattenInputs(&arg_list->second, &output_sizes);
        string conversion = strings::StrCat("_execute.args_to_matching_eager(",
                                            flattened, ", ctx");

        // Allowed dtypes for the attr (emits "[]" when none are declared).
        strings::StrAppend(&conversion, ", [");
        for (int t : attr.allowed_values().list().type()) {
          DataType dtype = static_cast<DataType>(t);
          const string py_dtype =
              python_op_gen_internal::DataTypeToPython(dtype, "_dtypes.");
          strings::StrAppend(&conversion, py_dtype, ", ");
        }
        strings::StrAppend(&conversion, "]");

        // The default dtype, taken from the ApiDef's copy of the attr.
        if (attr.has_default_value()) {
          strings::StrAppend(
              &conversion, ", ",
              python_op_gen_internal::AttrValueToPython(
                  attr.type(), api_def_attr.default_value(), "_dtypes."));
        }
        strings::StrAppend(&conversion, ")");
        const string var_name = AttrVarName(attr.name(), &attr_expressions_);
        if (output_sizes.size() == 1) {
          // Avoid creating a temporary variable in the case where
          // we can easily assign to the right value directly.
          const string inputs_var =
              param_names_[arg_list->second.front()].GetRenameTo();
          if (output_sizes.front().empty()) {
            // Empty size: presumably a single (non-list) tensor, so unpack
            // the one-element converted list.
            strings::StrAppend(&result_, indentation, var_name, ", (",
                               inputs_var, ",) = ", conversion, "\n");
          } else {
            strings::StrAppend(&result_, indentation, var_name, ", ",
                               inputs_var, " = ", conversion, "\n");
          }
        } else {
          const string inputs_var = strings::StrCat("_inputs_", attr.name());
          strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
                             " = ", conversion, "\n");
          // Convert from a flat list of eager tensors back to the
          // parameter variables.
          Unflatten(indentation, output_sizes, inputs_var, &result_);
          std::vector<string> p;
          for (int j : arg_list->second) {
            p.emplace_back(param_names_[j].GetRenameTo());
          }
          strings::StrAppend(&result_, indentation, VectorToTuple(p), " = ",
                             inputs_var, "\n");
        }
      } else if (attr.type() == "list(type)") {
        // NOTE: We ignore default values for these attrs, since it is
        // unclear how you would use it, and the one use case is
        // parse_single_sequence_example which only needs it for
        // backwards compatibility.
        const string var_name = AttrVarName(attr.name(), &attr_expressions_);
        string inputs_var;
        string conversion;
        if (arg_list->second.size() > 1) {
          // If you have more than one list(tensor) argument, their types
          // have to match.
          std::vector<string> lists;
          for (auto iter = arg_list->second.begin();
               iter != arg_list->second.end(); ++iter) {
            lists.push_back(param_names_[*iter].GetRenameTo());
          }
          inputs_var = VectorToTuple(lists);
          conversion = "_execute.args_to_mixed_eager_tensors";
        } else {
          // For one list(tensor) argument, we just convert every
          // element of the list to an eager tensor.
          inputs_var = param_names_[arg_list->second.front()].GetRenameTo();
          conversion = "_execute.convert_to_mixed_eager_tensors";
        }
        strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
                           " = ", conversion, "(", inputs_var, ", ctx)\n");
      }
    }
  }
}
1084
AddEagerInputCasts(const string & indentation)1085 void GenEagerPythonOp::AddEagerInputCasts(const string& indentation) {
1086 // Cast remaining args to eager tensors
1087 for (int i = 0; i < op_def_.input_arg_size(); ++i) {
1088 const auto& arg(op_def_.input_arg(i));
1089 if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue;
1090 const string& param = param_names_[i].GetRenameTo();
1091 const string fn = arg.number_attr().empty() ? "" : "n_";
1092 const string dtype =
1093 python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
1094 strings::StrAppend(&result_, indentation, param, " = _ops.convert_", fn,
1095 "to_tensor(", param, ", ", dtype, ")\n");
1096 }
1097 }
1098
AddEagerAttrs(const string & indentation)1099 void GenEagerPythonOp::AddEagerAttrs(const string& indentation) {
1100 // Compute eager attrs
1101 if (op_def_.attr_size() > 0) {
1102 string attr_values;
1103 for (int i = 0; i < op_def_.attr_size(); ++i) {
1104 if (i > 0) strings::StrAppend(&attr_values, ", ");
1105 const auto& attr_name(op_def_.attr(i).name());
1106 strings::StrAppend(&attr_values, "\"", attr_name, "\", ",
1107 attr_expressions_[attr_name]);
1108 }
1109 strings::StrAppend(&attr_values, ")");
1110 strings::StrAppend(
1111 &result_,
1112 WordWrap(indentation, strings::StrCat("_attrs = (", attr_values),
1113 kRightMargin),
1114 "\n");
1115 } else {
1116 strings::StrAppend(&result_, indentation, "_attrs = None\n");
1117 }
1118 }
1119
AddEagerExecute(const string & indentation,const string & num_outputs_expr)1120 void GenEagerPythonOp::AddEagerExecute(const string& indentation,
1121 const string& num_outputs_expr) {
1122 const string return_prefix =
1123 strings::StrCat(indentation, "_result = _execute.execute(");
1124 const string return_args = strings::StrCat(
1125 "b\"", op_def_.name(), "\", ", num_outputs_expr,
1126 ", inputs=_inputs_flat, attrs=_attrs, ctx=ctx, name=name)");
1127 strings::StrAppend(&result_,
1128 // Wrap the arguments, and indent to the (.
1129 WordWrap(return_prefix, return_args, kRightMargin), "\n");
1130 }
1131
AddDispatch(const string & prefix)1132 void GenEagerPythonOp::AddDispatch(const string& prefix) {
1133 if (api_def_.visibility() != ApiDef::VISIBLE) return;
1134
1135 strings::StrAppend(&result_, prefix, "except (TypeError, ValueError):\n");
1136 strings::StrAppend(&result_, prefix, " result = _dispatch.dispatch(\n");
1137 AddBodyNoReturn(strings::StrCat(prefix, " ", function_name_,
1138 ", "
1139 "(), dict("));
1140 strings::StrAppend(&result_, prefix, " )\n");
1141 strings::StrAppend(&result_, prefix,
1142 " if result is not "
1143 "_dispatch.OpDispatcher.NOT_SUPPORTED:\n");
1144 strings::StrAppend(&result_, prefix, " return result\n");
1145 strings::StrAppend(&result_, prefix, " raise\n");
1146 }
1147
AddRawOpExport(const string & parameters)1148 void GenEagerPythonOp::AddRawOpExport(const string& parameters) {
1149 // Example:
1150 //
1151 // Identity = tf_export("raw_ops.Identity")(_ops._to_raw_op(identity))
1152 const string raw_function_name =
1153 python_op_gen_internal::AvoidPythonReserved(op_def_.name());
1154 strings::StrAppend(&result_, raw_function_name, " = tf_export(\"raw_ops.",
1155 raw_function_name, "\")", "(_ops.to_raw_op(",
1156 function_name_, "))\n");
1157 }
1158
// Returns the full source of a generated Python module that wraps `ops`.
//
// The module starts with a docstring, which names `source_file_name` when it
// is non-empty, followed by a fixed import block. After that comes one
// wrapper per op:
//   * ops whose ApiDef visibility is SKIP are omitted;
//   * an op is hidden if its ApiDef visibility is HIDDEN or it is listed in
//     `hidden_ops`;
//   * a hidden op gets a leading underscore unless it is hidden only through
//     its ApiDef. In that case the underscore is added only for reserved
//     names or names accepted by IsOpWithUnderscorePrefix;
//   * a non-hidden op whose lower-case name is a Python reserved word is
//     omitted;
//   * ops named in `type_annotate_ops` get type-annotated wrappers.
//
// Every op in `ops` must have an ApiDef in `api_defs`; otherwise this CHECK-
// fails instead of dereferencing a null pointer.
string GetPythonOpsImpl(
    const OpList& ops, const ApiDefMap& api_defs,
    const std::vector<string>& hidden_ops, const string& source_file_name = "",
    const std::unordered_set<string>& type_annotate_ops = {}) {
  string result;
  // Header
  // TODO(josh11b): Mention the library for which wrappers are being generated.
  strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops.

This file is MACHINE GENERATED! Do not edit.
)");

  // Mention the original source file so someone tracing back through
  // generated Python code will know where to look next.
  if (!source_file_name.empty()) {
    strings::StrAppend(&result, "Original C++ source file: ", source_file_name,
                       "\n");
  }

  strings::StrAppend(&result, R"("""

import collections

from tensorflow.python import pywrap_tfe as pywrap_tfe
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.eager import execute as _execute
from tensorflow.python.framework import dtypes as _dtypes

from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.deprecation import deprecated_endpoints
from tensorflow.python.util import dispatch as _dispatch
from tensorflow.python.util.tf_export import tf_export

from typing import TypeVar
)");

  for (const auto& op_def : ops.op()) {
    const auto* api_def = api_defs.GetApiDef(op_def.name());
    // GetApiDef may return nullptr (presumably when `api_defs` has no entry
    // for the op); fail with a clear message rather than dereferencing it.
    CHECK(api_def != nullptr) << "No ApiDef found for op " << op_def.name();

    if (api_def->visibility() == ApiDef::SKIP) {
      continue;
    }
    // An op is hidden if either its ApiDef visibility is HIDDEN
    // or it is in the hidden_ops list.
    const bool hidden_by_api_def = api_def->visibility() == ApiDef::HIDDEN;
    const bool is_hidden =
        hidden_by_api_def ||
        std::find(hidden_ops.begin(), hidden_ops.end(), op_def.name()) !=
            hidden_ops.end();

    string function_name;
    python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(),
                                                    &function_name);
    const bool is_reserved =
        python_op_gen_internal::IsPythonReserved(function_name);

    // Prefix an op with underscore if the op is listed in hidden_ops or
    // name is reserved or it is of the exceptions in IsOpWithUnderscorePrefix.
    // Do not add underscores to ops set to HIDDEN in ApiDef otherwise.
    // TODO(annarev): don't prefix with underscores even if op is in hidden_ops.
    if (is_hidden) {
      if (!hidden_by_api_def || is_reserved ||
          python_op_gen_internal::IsOpWithUnderscorePrefix(function_name)) {
        function_name = strings::StrCat("_", function_name);
      }
    } else if (is_reserved) {
      // When users create custom python wrappers, they may link in the
      // default op registry by accident, and because they can't
      // enumerate all 'hidden' symbols, this guard is to prevent
      // instantiating a python reserved word in their wrapper.
      continue;
    }

    const bool add_type_annotations =
        type_annotate_ops.count(op_def.name()) > 0;

    strings::StrAppend(&result,
                       GetEagerPythonOp(op_def, *api_def, function_name,
                                        add_type_annotations));
  }

  return result;
}
1250
1251 } // namespace
1252
// Returns the generated Python wrapper module for `ops` as a string. This is
// the public entry point; the work is done by GetPythonOpsImpl.
string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
                    const std::vector<string>& hidden_ops,
                    const string& source_file_name,
                    const std::unordered_set<string> type_annotate_ops) {
  string generated = GetPythonOpsImpl(ops, api_defs, hidden_ops,
                                      source_file_name, type_annotate_ops);
  return generated;
}
1260
PrintPythonOps(const OpList & ops,const ApiDefMap & api_defs,const std::vector<string> & hidden_ops,const string & source_file_name,const std::unordered_set<string> type_annotate_ops)1261 void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
1262 const std::vector<string>& hidden_ops,
1263 const string& source_file_name,
1264 const std::unordered_set<string> type_annotate_ops) {
1265 printf("%s", GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name,
1266 type_annotate_ops)
1267 .c_str());
1268 }
1269
GetPythonWrappers(const char * op_list_buf,size_t op_list_len)1270 string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
1271 OpList ops;
1272 ops.ParseFromArray(op_list_buf, op_list_len);
1273
1274 ApiDefMap api_def_map(ops);
1275 return GetPythonOpsImpl(ops, api_def_map, {});
1276 }
1277
GetArgAnnotation(const OpDef::ArgDef & arg,const std::unordered_map<string,string> & type_annotations)1278 string GetArgAnnotation(
1279 const OpDef::ArgDef& arg,
1280 const std::unordered_map<string, string>& type_annotations) {
1281 if (!arg.type_attr().empty()) {
1282 // Get the correct TypeVar if arg maps to an attr
1283 return "_ops.Tensor[" + type_annotations.at(arg.type_attr()) + "]";
1284 } else {
1285 // Get the dtype of the Tensor
1286 const string py_dtype =
1287 python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
1288 return "_ops.Tensor[" + dtype_type.at(py_dtype) + "]";
1289 }
1290
1291 return "Any";
1292 }
1293
1294 } // namespace tensorflow
1295