/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstring>
#include <thread>

#include "tensorflow/python/eager/pywrap_tfe.h"

#include "absl/strings/str_cat.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/eager/pywrap_tensor.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
#include "tensorflow/python/util/util.h"

using tensorflow::string;
using tensorflow::strings::Printf;

namespace {

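// Records the index of an op input and whether that input is a list input
// (i.e. its input_arg carries a non-empty number_attr).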
struct InputInfo {
  InputInfo(int i, bool is_list) : i(i), is_list(is_list) {}

  int i;
  bool is_list = false;
};

// Takes in output gradients, returns input gradients.
typedef std::function<PyObject*(PyObject*)> PyBackwardFunction;

using AttrToInputsMap =
    tensorflow::gtl::FlatMap<string,
                             tensorflow::gtl::InlinedVector<InputInfo, 4>>;

tensorflow::mutex all_attr_to_input_maps_lock(tensorflow::LINKER_INITIALIZED);
tensorflow::gtl::FlatMap<string, AttrToInputsMap*>* GetAllAttrToInputsMaps() {
  static auto* all_attr_to_input_maps =
      new tensorflow::gtl::FlatMap<string, AttrToInputsMap*>;
  return all_attr_to_input_maps;
}

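// Returns the attr-name -> inputs map for op_def, building and caching it on
// first use. Cached entries are intentionally never freed; they live for the
// lifetime of the process. For example, for a hypothetical op with signature
//   Foo(x: T, y: N * T)
// the cached map would be {"T": [InputInfo(0, false), InputInfo(1, true)]},
// since input 1 carries a number_attr ("N") and is therefore a list input.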
AttrToInputsMap* GetAttrToInputsMap(const tensorflow::OpDef& op_def) {
  tensorflow::mutex_lock l(all_attr_to_input_maps_lock);
  auto* all_attr_to_input_maps = GetAllAttrToInputsMaps();

  auto* output =
      tensorflow::gtl::FindPtrOrNull(*all_attr_to_input_maps, op_def.name());
  if (output != nullptr) {
    return output;
  }

  std::unique_ptr<AttrToInputsMap> m(new AttrToInputsMap);

  // For each type attr, record which inputs use it (and whether each is a
  // list input).
  for (int i = 0; i < op_def.input_arg_size(); i++) {
    if (!op_def.input_arg(i).type_attr().empty()) {
      auto it = m->find(op_def.input_arg(i).type_attr());
      if (it == m->end()) {
        it = m->insert({op_def.input_arg(i).type_attr(), {}}).first;
      }
      it->second.emplace_back(i, !op_def.input_arg(i).number_attr().empty());
    }
  }

  auto* retval = m.get();
  (*all_attr_to_input_maps)[op_def.name()] = m.release();

  return retval;
}

struct FastPathOpExecInfo {
  TFE_Context* ctx;
  const char* device_name;
  // The op def of the main op being executed.
  const tensorflow::OpDef* op_def;

  bool run_callbacks;
  bool run_post_exec_callbacks;
  bool run_gradient_callback;

  // The op name of the main op being executed.
  PyObject* name;
  // The op type name of the main op being executed.
  PyObject* op_name;
  PyObject* callbacks;

  // All the args passed into the FastPathOpExecInfo.
  PyObject* args;

  // DTypes can come from another input that has the same attr, so we keep a
  // map from attr name to the inputs that share it.
  const AttrToInputsMap* attr_to_inputs_map;
  tensorflow::gtl::FlatMap<string, tensorflow::DataType> cached_dtypes;
};

#define PARSE_VALUE(fn_name, type, check_fn, parse_fn)                       \
  bool fn_name(const string& key, PyObject* py_value, TF_Status* status,     \
               type* value) {                                                \
    if (check_fn(py_value)) {                                                \
      *value = static_cast<type>(parse_fn(py_value));                        \
      return true;                                                           \
    } else {                                                                 \
      TF_SetStatus(status, TF_INVALID_ARGUMENT,                              \
                   tensorflow::strings::StrCat(                              \
                       "Expecting " #type " value for attr ", key, ", got ", \
                       py_value->ob_type->tp_name)                           \
                       .c_str());                                            \
      return false;                                                          \
    }                                                                        \
  }

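// Each PARSE_VALUE instantiation below defines a parser of the same shape;
// e.g. PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)
// defines
//   bool ParseFloatValue(const string& key, PyObject* py_value,
//                        TF_Status* status, float* value);
// which writes the converted value and returns true on success, or sets
// TF_INVALID_ARGUMENT on *status and returns false.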
#if PY_MAJOR_VERSION >= 3
PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLong)
#else
PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
#endif
PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)
#undef PARSE_VALUE

#if PY_MAJOR_VERSION < 3
bool ParseInt64Value(const string& key, PyObject* py_value, TF_Status* status,
                     int64_t* value) {
  if (PyInt_Check(py_value)) {
    *value = static_cast<int64_t>(PyInt_AsLong(py_value));
    return true;
  } else if (PyLong_Check(py_value)) {
    *value = static_cast<int64_t>(PyLong_AsLong(py_value));
    return true;
  }
  TF_SetStatus(
      status, TF_INVALID_ARGUMENT,
      tensorflow::strings::StrCat("Expecting int or long value for attr ", key,
                                  ", got ", py_value->ob_type->tp_name)
          .c_str());
  return false;
}
#endif

Py_ssize_t TensorShapeNumDims(PyObject* value) {
  const auto size = PySequence_Size(value);
  if (size == -1) {
    // TensorShape.__len__ raises an error when the shape is unknown; that
    // error needs to be cleared here.
    // TODO(nareshmodi): ensure that this is actually a TensorShape.
    PyErr_Clear();
  }
  return size;
}

bool IsInteger(PyObject* py_value) {
#if PY_MAJOR_VERSION >= 3
  return PyLong_Check(py_value);
#else
  return PyInt_Check(py_value);
#endif
}

// This function considers a Dimension._value of None to be valid, and sets
// the value to be -1 in that case.
bool ParseDimensionValue(const string& key, PyObject* py_value,
                         TF_Status* status, int64_t* value) {
  if (IsInteger(py_value)) {
    return ParseInt64Value(key, py_value, status, value);
  }

  tensorflow::Safe_PyObjectPtr dimension_value(
      PyObject_GetAttrString(py_value, "_value"));
  if (dimension_value == nullptr) {
    // Clear the AttributeError raised by the failed attribute lookup so it
    // does not leak into later Python calls.
    PyErr_Clear();
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        tensorflow::strings::StrCat("Expecting a Dimension for attr ", key,
                                    ", got ", py_value->ob_type->tp_name)
            .c_str());
    return false;
  }

  if (dimension_value.get() == Py_None) {
    *value = -1;
    return true;
  }

  return ParseInt64Value(key, dimension_value.get(), status, value);
}

bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status,
                      tensorflow::StringPiece* value) {
  if (PyBytes_Check(py_value)) {
    Py_ssize_t size = 0;
    char* buf = nullptr;
    if (PyBytes_AsStringAndSize(py_value, &buf, &size) < 0) return false;
    *value = tensorflow::StringPiece(buf, size);
    return true;
  }
#if PY_MAJOR_VERSION >= 3
  if (PyUnicode_Check(py_value)) {
    Py_ssize_t size = 0;
    const char* buf = PyUnicode_AsUTF8AndSize(py_value, &size);
    if (buf == nullptr) return false;
    *value = tensorflow::StringPiece(buf, size);
    return true;
  }
#endif
  TF_SetStatus(
      status, TF_INVALID_ARGUMENT,
      tensorflow::strings::StrCat("Expecting a string value for attr ", key,
                                  ", got ", py_value->ob_type->tp_name)
          .c_str());
  return false;
}

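// Note: this parser always succeeds; any Python object is accepted and
// coerced to its truth value via PyObject_IsTrue.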
bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status,
                    unsigned char* value) {
  *value = PyObject_IsTrue(py_value);
  return true;
}

// The passed in py_value is expected to be an object of the python type
// dtypes.DType or an int.
bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status,
                    int* value) {
  if (IsInteger(py_value)) {
    return ParseIntValue(key, py_value, status, value);
  }

  tensorflow::Safe_PyObjectPtr py_type_enum(
      PyObject_GetAttrString(py_value, "_type_enum"));
  if (py_type_enum == nullptr) {
    PyErr_Clear();
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        tensorflow::strings::StrCat("Expecting a DType.dtype for attr ", key,
                                    ", got ", py_value->ob_type->tp_name)
            .c_str());
    return false;
  }

  return ParseIntValue(key, py_type_enum.get(), status, value);
}

bool SetOpAttrList(
    TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_list,
    TF_AttrType type,
    tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    TF_Status* status) {
  if (!PySequence_Check(py_list)) {
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        tensorflow::strings::StrCat("Expecting sequence value for attr ", key,
                                    ", got ", py_list->ob_type->tp_name)
            .c_str());
    return false;
  }
  const int num_values = PySequence_Size(py_list);
  if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values;

#define PARSE_LIST(c_type, parse_fn)                                      \
  std::unique_ptr<c_type[]> values(new c_type[num_values]);               \
  for (int i = 0; i < num_values; ++i) {                                  \
    tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));   \
    if (!parse_fn(key, py_value.get(), status, &values[i])) return false; \
  }

  if (type == TF_ATTR_STRING) {
    std::unique_ptr<const void*[]> values(new const void*[num_values]);
    std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
    for (int i = 0; i < num_values; ++i) {
      tensorflow::StringPiece value;
      tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
      if (!ParseStringValue(key, py_value.get(), status, &value)) return false;
      values[i] = value.data();
      lengths[i] = value.size();
    }
    TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
  } else if (type == TF_ATTR_INT) {
    PARSE_LIST(int64_t, ParseInt64Value);
    TFE_OpSetAttrIntList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_FLOAT) {
    PARSE_LIST(float, ParseFloatValue);
    TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_BOOL) {
    PARSE_LIST(unsigned char, ParseBoolValue);
    TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_TYPE) {
    PARSE_LIST(int, ParseTypeValue);
    TFE_OpSetAttrTypeList(op, key,
                          reinterpret_cast<const TF_DataType*>(values.get()),
                          num_values);
  } else if (type == TF_ATTR_SHAPE) {
    // Make one pass through the input counting the total number of
    // dims across all the input lists.
    int total_dims = 0;
    for (int i = 0; i < num_values; ++i) {
      tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
      if (py_value.get() != Py_None) {
        if (!PySequence_Check(py_value.get())) {
          TF_SetStatus(
              status, TF_INVALID_ARGUMENT,
              tensorflow::strings::StrCat(
                  "Expecting None or sequence value for element ", i,
                  " of attr ", key, ", got ", py_value->ob_type->tp_name)
                  .c_str());
          return false;
        }
        const auto size = TensorShapeNumDims(py_value.get());
        if (size >= 0) {
          total_dims += size;
        }
      }
    }
    // Allocate a buffer that can fit all of the dims together.
    std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
    // Copy the input dims into the buffer and set dims to point to
    // the start of each list's dims.
    std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
    std::unique_ptr<int[]> num_dims(new int[num_values]);
    int64_t* offset = buffer.get();
    for (int i = 0; i < num_values; ++i) {
      tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
      if (py_value.get() == Py_None) {
        dims[i] = nullptr;
        num_dims[i] = -1;
      } else {
        const auto size = TensorShapeNumDims(py_value.get());
        if (size == -1) {
          dims[i] = nullptr;
          num_dims[i] = -1;
          continue;
        }
        dims[i] = offset;
        num_dims[i] = size;
        for (int j = 0; j < size; ++j) {
          tensorflow::Safe_PyObjectPtr inner_py_value(
              PySequence_ITEM(py_value.get(), j));
          if (inner_py_value.get() == Py_None) {
            *offset = -1;
          } else if (!ParseDimensionValue(key, inner_py_value.get(), status,
                                          offset)) {
            return false;
          }
          ++offset;
        }
      }
    }
    TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
                           status);
    if (TF_GetCode(status) != TF_OK) return false;
  } else if (type == TF_ATTR_FUNC) {
    std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
    for (int i = 0; i < num_values; ++i) {
      tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
      // Allow:
      // (1) String function name, OR
      // (2) A Python object with a .name attribute
      //     (A crude test for being a
      //     tensorflow.python.framework.function._DefinedFunction)
      //     (which is what the various "defun" or "Defun" decorators do).
      // And in the future also allow an object that can encapsulate
      // the function name and its attribute values.
      tensorflow::StringPiece func_name;
      if (!ParseStringValue(key, py_value.get(), status, &func_name)) {
        PyObject* name_attr = PyObject_GetAttrString(py_value.get(), "name");
        if (name_attr == nullptr ||
            !ParseStringValue(key, name_attr, status, &func_name)) {
          TF_SetStatus(
              status, TF_INVALID_ARGUMENT,
              tensorflow::strings::StrCat(
                  "unable to set function value attribute from a ",
                  py_value.get()->ob_type->tp_name,
                  " object. If you think this is an error, please file an "
                  "issue at "
                  "https://github.com/tensorflow/tensorflow/issues/new")
                  .c_str());
          return false;
        }
      }
      funcs[i] = TFE_NewOp(ctx, func_name.data(), status);
      if (TF_GetCode(status) != TF_OK) return false;
    }
    TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
    if (TF_GetCode(status) != TF_OK) return false;
  } else {
    TF_SetStatus(status, TF_UNIMPLEMENTED,
                 tensorflow::strings::StrCat("Attr ", key,
                                             " has unhandled list type ", type)
                     .c_str());
    return false;
  }
#undef PARSE_LIST
  return true;
}

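// Builds a TFE_Op for the function named in `func` and applies that
// function's attr values to it; used below for function-valued attr defaults.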
TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
                TF_Status* status) {
  TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
  for (const auto& attr : func.attr()) {
    if (TF_GetCode(status) != TF_OK) return nullptr;
    SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }
  return func_op;
}

void SetOpAttrListDefault(
    TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
    const char* key, TF_AttrType type,
    tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    TF_Status* status) {
  if (type == TF_ATTR_STRING) {
    int num_values = attr.default_value().list().s_size();
    std::unique_ptr<const void*[]> values(new const void*[num_values]);
    std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
    (*attr_list_sizes)[key] = num_values;
    for (int i = 0; i < num_values; i++) {
      const string& v = attr.default_value().list().s(i);
      values[i] = v.data();
      lengths[i] = v.size();
    }
    TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
  } else if (type == TF_ATTR_INT) {
    int num_values = attr.default_value().list().i_size();
    std::unique_ptr<int64_t[]> values(new int64_t[num_values]);
    (*attr_list_sizes)[key] = num_values;
    for (int i = 0; i < num_values; i++) {
      values[i] = attr.default_value().list().i(i);
    }
    TFE_OpSetAttrIntList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_FLOAT) {
    int num_values = attr.default_value().list().f_size();
    std::unique_ptr<float[]> values(new float[num_values]);
    (*attr_list_sizes)[key] = num_values;
    for (int i = 0; i < num_values; i++) {
      values[i] = attr.default_value().list().f(i);
    }
    TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_BOOL) {
    int num_values = attr.default_value().list().b_size();
    std::unique_ptr<unsigned char[]> values(new unsigned char[num_values]);
    (*attr_list_sizes)[key] = num_values;
    for (int i = 0; i < num_values; i++) {
      values[i] = attr.default_value().list().b(i);
    }
    TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_TYPE) {
    int num_values = attr.default_value().list().type_size();
    std::unique_ptr<int[]> values(new int[num_values]);
    (*attr_list_sizes)[key] = num_values;
    for (int i = 0; i < num_values; i++) {
      values[i] = attr.default_value().list().type(i);
    }
    TFE_OpSetAttrTypeList(op, key,
                          reinterpret_cast<const TF_DataType*>(values.get()),
                          num_values);
  } else if (type == TF_ATTR_SHAPE) {
    int num_values = attr.default_value().list().shape_size();
    (*attr_list_sizes)[key] = num_values;
    int total_dims = 0;
    for (int i = 0; i < num_values; ++i) {
      if (!attr.default_value().list().shape(i).unknown_rank()) {
        total_dims += attr.default_value().list().shape(i).dim_size();
      }
    }
    // Allocate a buffer that can fit all of the dims together.
    std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
    // Copy the input dims into the buffer and set dims to point to
    // the start of each list's dims.
    std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
    std::unique_ptr<int[]> num_dims(new int[num_values]);
    int64_t* offset = buffer.get();
    for (int i = 0; i < num_values; ++i) {
      const auto& shape = attr.default_value().list().shape(i);
      if (shape.unknown_rank()) {
        dims[i] = nullptr;
        num_dims[i] = -1;
      } else {
        // Point this entry at its slice of the shared buffer before copying
        // the dimension sizes in.
        dims[i] = offset;
        num_dims[i] = shape.dim_size();
        for (int j = 0; j < shape.dim_size(); j++) {
          *offset = shape.dim(j).size();
          ++offset;
        }
      }
    }
    TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
                           status);
  } else if (type == TF_ATTR_FUNC) {
    int num_values = attr.default_value().list().func_size();
    (*attr_list_sizes)[key] = num_values;
    std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
    for (int i = 0; i < num_values; i++) {
      funcs[i] = GetFunc(ctx, attr.default_value().list().func(i), status);
    }
    TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
  } else {
    TF_SetStatus(status, TF_UNIMPLEMENTED,
                 "Lists of tensors are not yet implemented for default valued "
                 "attributes for an operation.");
  }
}

bool SetOpAttrScalar(
    TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_value,
    TF_AttrType type,
    tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    TF_Status* status) {
  if (type == TF_ATTR_STRING) {
    tensorflow::StringPiece value;
    if (!ParseStringValue(key, py_value, status, &value)) return false;
    TFE_OpSetAttrString(op, key, value.data(), value.size());
  } else if (type == TF_ATTR_INT) {
    int64_t value;
    if (!ParseInt64Value(key, py_value, status, &value)) return false;
    TFE_OpSetAttrInt(op, key, value);
    // attr_list_sizes is set for all int attributes, since at this point we
    // do not know whether the attribute will be used to calculate the size of
    // an output list.
    if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = value;
  } else if (type == TF_ATTR_FLOAT) {
    float value;
    if (!ParseFloatValue(key, py_value, status, &value)) return false;
    TFE_OpSetAttrFloat(op, key, value);
  } else if (type == TF_ATTR_BOOL) {
    unsigned char value;
    if (!ParseBoolValue(key, py_value, status, &value)) return false;
    TFE_OpSetAttrBool(op, key, value);
  } else if (type == TF_ATTR_TYPE) {
    int value;
    if (!ParseTypeValue(key, py_value, status, &value)) return false;
    TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value));
  } else if (type == TF_ATTR_SHAPE) {
    if (py_value == Py_None) {
      TFE_OpSetAttrShape(op, key, nullptr, -1, status);
    } else {
      if (!PySequence_Check(py_value)) {
        TF_SetStatus(status, TF_INVALID_ARGUMENT,
                     tensorflow::strings::StrCat(
                         "Expecting None or sequence value for attr ", key,
                         ", got ", py_value->ob_type->tp_name)
                         .c_str());
        return false;
      }
      const auto num_dims = TensorShapeNumDims(py_value);
      if (num_dims == -1) {
        TFE_OpSetAttrShape(op, key, nullptr, -1, status);
        return true;
      }
      std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
      for (int i = 0; i < num_dims; ++i) {
        tensorflow::Safe_PyObjectPtr inner_py_value(
            PySequence_ITEM(py_value, i));
        if (inner_py_value.get() == Py_None) {
          dims[i] = -1;
        } else if (!ParseDimensionValue(key, inner_py_value.get(), status,
                                        &dims[i])) {
          return false;
        }
      }
      TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status);
    }
    if (TF_GetCode(status) != TF_OK) return false;
  } else if (type == TF_ATTR_FUNC) {
    // Allow:
    // (1) String function name, OR
    // (2) A Python object with a .name attribute
    //     (A crude test for being a
    //     tensorflow.python.framework.function._DefinedFunction)
    //     (which is what the various "defun" or "Defun" decorators do).
    // And in the future also allow an object that can encapsulate
    // the function name and its attribute values.
    tensorflow::StringPiece func_name;
    if (!ParseStringValue(key, py_value, status, &func_name)) {
      PyObject* name_attr = PyObject_GetAttrString(py_value, "name");
      if (name_attr == nullptr ||
          !ParseStringValue(key, name_attr, status, &func_name)) {
        TF_SetStatus(
            status, TF_INVALID_ARGUMENT,
            tensorflow::strings::StrCat(
                "unable to set function value attribute from a ",
                py_value->ob_type->tp_name,
                " object. If you think this is an error, please file an issue "
                "at https://github.com/tensorflow/tensorflow/issues/new")
                .c_str());
        return false;
      }
    }
    TF_SetStatus(status, TF_OK, "");
    TFE_OpSetAttrFunctionName(op, key, func_name.data(), func_name.size());
  } else {
    TF_SetStatus(
        status, TF_UNIMPLEMENTED,
        tensorflow::strings::StrCat("Attr ", key, " has unhandled type ", type)
            .c_str());
    return false;
  }
  return true;
}

void SetOpAttrScalarDefault(
    TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value,
    const char* attr_name,
    tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    TF_Status* status) {
  SetOpAttrValueScalar(ctx, op, default_value, attr_name, status);
  if (default_value.value_case() == tensorflow::AttrValue::kI) {
    (*attr_list_sizes)[attr_name] = default_value.i();
  }
}

// start_index is the index at which the Tuple/List attrs will start getting
// processed.
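// For example, with start_index == 0 the attrs tuple for a MatMul might look
// like ("transpose_a", False, "T", 1): alternating attr names and values,
// where 1 is the enum value of the float32 dtype.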
void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index,
                TF_Status* out_status) {
  if (attrs == Py_None) return;
  Py_ssize_t len = PyTuple_GET_SIZE(attrs) - start_index;
  if ((len & 1) != 0) {
    TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                 "Expecting attrs tuple to have even length.");
    return;
  }
  // Parse attrs
  for (Py_ssize_t i = 0; i < len; i += 2) {
    PyObject* py_key = PyTuple_GET_ITEM(attrs, start_index + i);
    PyObject* py_value = PyTuple_GET_ITEM(attrs, start_index + i + 1);
#if PY_MAJOR_VERSION >= 3
    const char* key = PyBytes_Check(py_key) ? PyBytes_AsString(py_key)
                                            : PyUnicode_AsUTF8(py_key);
#else
    const char* key = PyBytes_AsString(py_key);
#endif
    unsigned char is_list = 0;
    const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status);
    if (TF_GetCode(out_status) != TF_OK) return;
    if (is_list != 0) {
      if (!SetOpAttrList(ctx, op, key, py_value, type, nullptr, out_status))
        return;
    } else {
      if (!SetOpAttrScalar(ctx, op, key, py_value, type, nullptr, out_status))
        return;
    }
  }
}

// This function will set the op attrs required. If an attr has the value of
// None, then it will read the AttrDef to get the default value and set that
// instead. Any failure in this function will simply fall back to the slow
// path.
void SetOpAttrWithDefaults(
    TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
    const char* attr_name, PyObject* attr_value,
    tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    TF_Status* status) {
  unsigned char is_list = 0;
  const TF_AttrType type = TFE_OpGetAttrType(op, attr_name, &is_list, status);
  if (TF_GetCode(status) != TF_OK) return;
  if (attr_value == Py_None) {
    if (is_list != 0) {
      SetOpAttrListDefault(ctx, op, attr, attr_name, type, attr_list_sizes,
                           status);
    } else {
      SetOpAttrScalarDefault(ctx, op, attr.default_value(), attr_name,
                             attr_list_sizes, status);
    }
  } else {
    if (is_list != 0) {
      SetOpAttrList(ctx, op, attr_name, attr_value, type, attr_list_sizes,
                    status);
    } else {
      SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes,
                      status);
    }
  }
}

// Python subclass of Exception that is raised when a Status is not OK.
tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr;

// Python subclass of Exception that is raised to signal fallback to the
// slower path.
PyObject* fallback_exception_class = nullptr;

// Python function that returns input gradients given output gradients.
PyObject* gradient_function = nullptr;

PyTypeObject* resource_variable_type = nullptr;

tensorflow::mutex _uid_mutex(tensorflow::LINKER_INITIALIZED);
tensorflow::int64 _uid GUARDED_BY(_uid_mutex) = 0;

}  // namespace

void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
                    const char* op_name, TFE_InputTensorHandles* inputs,
                    PyObject* attrs, TFE_OutputTensorHandles* outputs,
                    TF_Status* out_status) {
  TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
  if (TF_GetCode(out_status) != TF_OK) return;
  TFE_OpSetDevice(op, device_name, out_status);
  if (TF_GetCode(out_status) == TF_OK) {
    for (int i = 0; i < inputs->size() && TF_GetCode(out_status) == TF_OK;
         ++i) {
      TFE_OpAddInput(op, inputs->at(i), out_status);
    }
  }
  if (TF_GetCode(out_status) == TF_OK) {
    SetOpAttrs(ctx, op, attrs, 0, out_status);
  }
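  // Release the GIL while the kernel executes so other Python threads can
  // make progress; TFE_Execute may block on device computation.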
  Py_BEGIN_ALLOW_THREADS;
  if (TF_GetCode(out_status) == TF_OK) {
    int num_outputs = outputs->size();
    TFE_Execute(op, outputs->data(), &num_outputs, out_status);
    outputs->resize(num_outputs);
  }
  if (TF_GetCode(out_status) != TF_OK) {
    TF_SetStatus(out_status, TF_GetCode(out_status),
                 tensorflow::strings::StrCat(TF_Message(out_status),
                                             " [Op:", op_name, "]")
                     .c_str());
  }
  TFE_DeleteOp(op);
  Py_END_ALLOW_THREADS;
}

PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
  tensorflow::mutex_lock l(exception_class_mutex);
  if (exception_class != nullptr) {
    Py_DECREF(exception_class);
  }
  if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
    exception_class = nullptr;
    PyErr_SetString(PyExc_TypeError,
                    "TFE_Py_RegisterExceptionClass: "
                    "Registered class should be subclass of Exception.");
    return nullptr;
  }

  Py_INCREF(e);
  exception_class = e;
  Py_RETURN_NONE;
}

PyObject* TFE_Py_RegisterResourceVariableType(PyObject* e) {
  if (!PyType_Check(e)) {
    PyErr_SetString(
        PyExc_TypeError,
        "TFE_Py_RegisterResourceVariableType: Need to register a type.");
    return nullptr;
  }

  if (resource_variable_type != nullptr) {
    Py_DECREF(resource_variable_type);
  }

  Py_INCREF(e);
  resource_variable_type = reinterpret_cast<PyTypeObject*>(e);
  Py_RETURN_NONE;
}

PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) {
  if (fallback_exception_class != nullptr) {
    Py_DECREF(fallback_exception_class);
  }
  if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
    fallback_exception_class = nullptr;
    PyErr_SetString(PyExc_TypeError,
                    "TFE_Py_RegisterFallbackExceptionClass: "
                    "Registered class should be subclass of Exception.");
    return nullptr;
  } else {
    Py_INCREF(e);
    fallback_exception_class = e;
    Py_RETURN_NONE;
  }
}

PyObject* TFE_Py_RegisterGradientFunction(PyObject* e) {
  if (gradient_function != nullptr) {
    Py_DECREF(gradient_function);
  }
  if (!PyCallable_Check(e)) {
    gradient_function = nullptr;
    PyErr_SetString(PyExc_TypeError,
                    "TFE_Py_RegisterGradientFunction: "
                    "Registered object should be callable.");
    return nullptr;
  } else {
    Py_INCREF(e);
    gradient_function = e;
    Py_RETURN_NONE;
  }
}

void RaiseFallbackException(const char* message) {
  if (fallback_exception_class != nullptr) {
    PyErr_SetString(fallback_exception_class, message);
    return;
  }

  PyErr_SetString(
      PyExc_RuntimeError,
      tensorflow::strings::StrCat(
          "Fallback exception type not set, attempting to fallback due to ",
          message)
          .data());
}

int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
  if (TF_GetCode(status) == TF_OK) return 0;
  const char* msg = TF_Message(status);
  if (exception == nullptr) {
    tensorflow::mutex_lock l(exception_class_mutex);
    if (exception_class != nullptr) {
      tensorflow::Safe_PyObjectPtr val(
          Py_BuildValue("si", msg, TF_GetCode(status)));
      if (PyErr_Occurred()) {
        // NOTE: This hides the actual error (i.e. the reason `status` was not
        // TF_OK), but there is nothing we can do at this point since we can't
        // generate a reasonable error from the status.
        // Consider adding a message explaining this.
        return -1;
      }
      PyErr_SetObject(exception_class, val.get());
      return -1;
    } else {
      exception = PyExc_RuntimeError;
    }
  }
  // Maybe update an exception that has already been set.
  PyErr_SetString(exception, msg);
  return -1;
}

int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
                                  PyObject* exception) {
  if (status.ok()) return 0;
  const char* msg = status.error_message().c_str();
  if (exception == nullptr) {
    tensorflow::mutex_lock l(exception_class_mutex);
    if (exception_class != nullptr) {
      tensorflow::Safe_PyObjectPtr val(Py_BuildValue("si", msg, status.code()));
      PyErr_SetObject(exception_class, val.get());
      return -1;
    } else {
      exception = PyExc_RuntimeError;
    }
  }
  // Maybe update an exception that has already been set.
  PyErr_SetString(exception, msg);
  return -1;
}

const char* TFE_GetPythonString(PyObject* o) {
#if PY_MAJOR_VERSION >= 3
  if (PyBytes_Check(o)) {
    return PyBytes_AsString(o);
  } else {
    return PyUnicode_AsUTF8(o);
  }
#else
  return PyBytes_AsString(o);
#endif
}

int64_t get_uid() {
  tensorflow::mutex_lock l(_uid_mutex);
  return _uid++;
}

PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }

void TFE_DeleteContextCapsule(PyObject* context) {
  TFE_Context* ctx =
      reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr));
  TFE_DeleteContext(ctx);
}

static tensorflow::int64 MakeInt(PyObject* integer) {
#if PY_MAJOR_VERSION >= 3
  return PyLong_AsLong(integer);
#else
  return PyInt_AsLong(integer);
#endif
}

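// Fast path: EagerTensors expose their id directly; other tensor-like objects
// fall back to reading the Python `_id` attribute.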
static tensorflow::int64 FastTensorId(PyObject* tensor) {
  if (EagerTensor_CheckExact(tensor)) {
    return PyEagerTensor_ID(tensor);
  }
  PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
  if (id_field == nullptr) {
    return -1;
  }
  tensorflow::int64 id = MakeInt(id_field);
  Py_DECREF(id_field);
  return id;
}

static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
  if (EagerTensor_CheckExact(tensor)) {
    return PyEagerTensor_Dtype(tensor);
  }
  PyObject* dtype_field = PyObject_GetAttrString(tensor, "dtype");
  if (dtype_field == nullptr) {
    return tensorflow::DT_INVALID;
  }
  PyObject* enum_field = PyObject_GetAttrString(dtype_field, "_type_enum");
  Py_DECREF(dtype_field);
  if (enum_field == nullptr) {
    return tensorflow::DT_INVALID;
  }
  tensorflow::int64 id = MakeInt(enum_field);
  Py_DECREF(enum_field);
  return static_cast<tensorflow::DataType>(id);
}

class PyTapeTensor {
 public:
  PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
               const tensorflow::TensorShape& shape)
      : id_(id), dtype_(dtype), shape_(shape) {}
  PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
               PyObject* shape)
      : id_(id), dtype_(dtype), shape_(shape) {
    Py_INCREF(absl::get<1>(shape_));
  }
  PyTapeTensor(const PyTapeTensor& other) {
    id_ = other.id_;
    dtype_ = other.dtype_;
    shape_ = other.shape_;
    if (shape_.index() == 1) {
      Py_INCREF(absl::get<1>(shape_));
    }
  }

  ~PyTapeTensor() {
    if (shape_.index() == 1) {
      Py_DECREF(absl::get<1>(shape_));
    }
  }
  PyObject* GetShape() const;
  PyObject* GetDType() const { return PyLong_FromLong(dtype_); }
  tensorflow::int64 GetID() const { return id_; }

 private:
  tensorflow::int64 id_;
  tensorflow::DataType dtype_;
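  // Either a fully-known TensorShape (index 0) or, for graph tensors whose
  // shape is only known symbolically, an owned reference to a Python shape
  // object (index 1).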
  absl::variant<tensorflow::TensorShape, PyObject*> shape_;
};

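// Bridges the C++ tape to Python: wraps the Python callbacks registered via
// TFE_Py_RegisterVSpace (num_elements_fn, aggregate_fn, zeros_fn, ones_fn and
// graph_shape_fn) so the tape can build and combine gradient values.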
class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
                                                  PyTapeTensor> {
 public:
  explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
    Py_INCREF(py_vspace_);
  }

  tensorflow::Status Initialize() {
    num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
    if (num_elements_ == nullptr) {
      return tensorflow::errors::InvalidArgument("invalid vspace");
    }
    aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
    if (aggregate_fn_ == nullptr) {
      return tensorflow::errors::InvalidArgument("invalid vspace");
    }
    zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn");
    if (zeros_fn_ == nullptr) {
      return tensorflow::errors::InvalidArgument("invalid vspace");
    }
    ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn");
    if (ones_fn_ == nullptr) {
      return tensorflow::errors::InvalidArgument("invalid vspace");
    }
    graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn");
    if (graph_shape_fn_ == nullptr) {
      return tensorflow::errors::InvalidArgument("invalid vspace");
    }
    return tensorflow::Status::OK();
  }

  ~PyVSpace() override {
    Py_XDECREF(num_elements_);
    Py_XDECREF(aggregate_fn_);
    Py_XDECREF(zeros_fn_);
    Py_XDECREF(ones_fn_);
    Py_XDECREF(graph_shape_fn_);

    Py_DECREF(py_vspace_);
  }

  tensorflow::int64 NumElements(PyObject* tensor) const final {
    if (EagerTensor_CheckExact(tensor)) {
      return PyEagerTensor_NumElements(tensor);
    }
    PyObject* arglist =
        Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
    PyObject* result = PyEval_CallObject(num_elements_, arglist);
    Py_DECREF(arglist);
    if (result == nullptr) {
      // The caller detects whether a python exception has been raised.
      return -1;
    }
    tensorflow::int64 r = MakeInt(result);
    Py_DECREF(result);
    return r;
  }

  PyObject* AggregateGradients(
      tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
    PyObject* list = PyList_New(gradient_tensors.size());
    for (int i = 0; i < gradient_tensors.size(); ++i) {
      // Note: stealing a reference to the gradient tensors.
      CHECK(gradient_tensors[i] != nullptr);
      CHECK(gradient_tensors[i] != Py_None);
      PyList_SET_ITEM(list, i,
                      reinterpret_cast<PyObject*>(gradient_tensors[i]));
    }
    PyObject* arglist = Py_BuildValue("(O)", list);
    CHECK(arglist != nullptr);
    PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
    Py_DECREF(arglist);
    Py_DECREF(list);
    return result;
  }

  void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }

  PyObject* Zeros(const PyTapeTensor& tensor) const final {
    if (PyErr_Occurred()) {
      return nullptr;
    }
    PyObject* py_shape = tensor.GetShape();
    if (PyErr_Occurred()) {
      return nullptr;
    }
    PyObject* py_dtype = tensor.GetDType();
    if (PyErr_Occurred()) {
      Py_DECREF(py_shape);
      return nullptr;
    }
    PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
    PyObject* result = PyEval_CallObject(zeros_fn_, arg_list);
    Py_DECREF(arg_list);
    Py_DECREF(py_dtype);
    Py_DECREF(py_shape);
    return result;
  }

  PyObject* Ones(const PyTapeTensor& tensor) const final {
    if (PyErr_Occurred()) {
      return nullptr;
    }
    PyObject* py_shape = tensor.GetShape();
    PyObject* py_dtype = tensor.GetDType();
    PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
    PyObject* result = PyEval_CallObject(ones_fn_, arg_list);
    Py_DECREF(arg_list);
    Py_DECREF(py_dtype);
    Py_DECREF(py_shape);
    return result;
  }

  PyObject* GraphShape(PyObject* tensor) const {
    PyObject* arg_list = Py_BuildValue("(O)", tensor);
    PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list);
    Py_DECREF(arg_list);
    return result;
  }

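  // Calls backward_function with a tuple of output gradients (None standing
  // in for absent gradients) and unpacks the returned sequence into result,
  // representing Python None entries as nullptr.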
  tensorflow::Status CallBackwardFunction(
      PyBackwardFunction* backward_function,
      tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
      std::vector<PyObject*>* result) const final {
    PyObject* grads = PyTuple_New(output_gradients.size());
    for (int i = 0; i < output_gradients.size(); ++i) {
      if (output_gradients[i] == nullptr) {
        Py_INCREF(Py_None);
        PyTuple_SET_ITEM(grads, i, Py_None);
      } else {
        PyTuple_SET_ITEM(grads, i,
                         reinterpret_cast<PyObject*>(output_gradients[i]));
      }
    }
    PyObject* py_result = (*backward_function)(grads);
    Py_DECREF(grads);
    if (py_result == nullptr) {
      return tensorflow::errors::Internal("gradient function threw an exception");
    }
    result->clear();
    PyObject* seq =
        PySequence_Fast(py_result, "expected a sequence of gradients");
    if (seq == nullptr) {
      return tensorflow::errors::InvalidArgument(
          "gradient function did not return a list");
    }
    int len = PySequence_Fast_GET_SIZE(seq);
    VLOG(1) << "Gradient length is " << len;
    result->reserve(len);
    for (int i = 0; i < len; ++i) {
      PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
      if (item == Py_None) {
        result->push_back(nullptr);
      } else {
        Py_INCREF(item);
        result->push_back(item);
      }
    }
    Py_DECREF(seq);
    Py_DECREF(py_result);
    return tensorflow::Status::OK();
  }

  void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }

 private:
  PyObject* py_vspace_;

  PyObject* num_elements_;
  PyObject* aggregate_fn_;
  PyObject* zeros_fn_;
  PyObject* ones_fn_;
  PyObject* graph_shape_fn_;
};
PyVSpace* py_vspace = nullptr;

PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
  if (py_vspace != nullptr) {
    delete py_vspace;
  }

  py_vspace = new PyVSpace(e);
  auto status = py_vspace->Initialize();
  if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
    delete py_vspace;
    // Null the global so a failed registration cannot leave a dangling
    // pointer behind.
    py_vspace = nullptr;
    return nullptr;
  }

  Py_RETURN_NONE;
}

PyObject* PyTapeTensor::GetShape() const {
  if (shape_.index() == 0) {
    auto& shape = absl::get<0>(shape_);
    PyObject* py_shape = PyTuple_New(shape.dims());
    for (int i = 0; i < shape.dims(); ++i) {
      PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
    }

    return py_shape;
  }

  return py_vspace->GraphShape(absl::get<1>(shape_));
}

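// GradientTape specialization that also tracks watched variables, optionally
// watching every variable accessed while the tape is active.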
class GradientTape
    : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
                                             PyTapeTensor> {
 public:
  explicit GradientTape(bool persistent, bool watch_accessed_variables)
      : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
                                        PyTapeTensor>(persistent),
        watch_accessed_variables_(watch_accessed_variables) {}

  virtual ~GradientTape() {
    for (const IdAndVariable& v : watched_variables_) {
      Py_DECREF(v.variable);
    }
  }

  void VariableAccessed(PyObject* v) {
    if (watch_accessed_variables_) {
      WatchVariable(v);
    }
  }

  void WatchVariable(PyObject* v) {
    tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
    if (handle == nullptr) {
      return;
    }
    tensorflow::int64 id = FastTensorId(handle.get());

    if (!PyErr_Occurred()) {
      this->Watch(id);
    }

    tensorflow::mutex_lock l(watched_variables_mu_);
    auto insert_result = watched_variables_.emplace(id, v);

    if (insert_result.second) {
      // Only increment the reference count if we aren't already watching this
      // variable.
      Py_INCREF(v);
    }
  }

  PyObject* GetVariablesAsPyTuple() {
    tensorflow::mutex_lock l(watched_variables_mu_);
    PyObject* result = PyTuple_New(watched_variables_.size());
    Py_ssize_t pos = 0;
    for (const IdAndVariable& id_and_variable : watched_variables_) {
      PyTuple_SET_ITEM(result, pos++, id_and_variable.variable);
      Py_INCREF(id_and_variable.variable);
    }
    return result;
  }

 private:
  // We store an IdAndVariable in the map since the map needs to be locked
  // during insert, but should not call back into python during insert to avoid
  // deadlocking with the GIL.
  struct IdAndVariable {
    tensorflow::int64 id;
    PyObject* variable;

    IdAndVariable(tensorflow::int64 id, PyObject* variable)
        : id(id), variable(variable) {}
  };
  struct CompareById {
    bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) const {
      return lhs.id < rhs.id;
    }
  };

  bool watch_accessed_variables_;
  tensorflow::mutex watched_variables_mu_;
  std::set<IdAndVariable, CompareById> watched_variables_
      GUARDED_BY(watched_variables_mu_);
};

typedef struct {
  PyObject_HEAD
  /* Type-specific fields go here. */
  GradientTape* tape;
} TFE_Py_Tape;

static void TFE_Py_Tape_Delete(PyObject* tape) {
  delete reinterpret_cast<TFE_Py_Tape*>(tape)->tape;
  Py_TYPE(tape)->tp_free(tape);
}

static PyTypeObject TFE_Py_Tape_Type = {
    PyVarObject_HEAD_INIT(nullptr, 0) "tfe.Tape", /* tp_name */
    sizeof(TFE_Py_Tape),                          /* tp_basicsize */
    0,                                            /* tp_itemsize */
    &TFE_Py_Tape_Delete,                          /* tp_dealloc */
    nullptr,                                      /* tp_print */
    nullptr,                                      /* tp_getattr */
    nullptr,                                      /* tp_setattr */
    nullptr,                                      /* tp_reserved */
    nullptr,                                      /* tp_repr */
    nullptr,                                      /* tp_as_number */
    nullptr,                                      /* tp_as_sequence */
    nullptr,                                      /* tp_as_mapping */
    nullptr,                                      /* tp_hash  */
    nullptr,                                      /* tp_call */
    nullptr,                                      /* tp_str */
    nullptr,                                      /* tp_getattro */
    nullptr,                                      /* tp_setattro */
    nullptr,                                      /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT,                           /* tp_flags */
    "TFE_Py_Tape objects",                        /* tp_doc */
};

// Note: in the current design no mutex is needed here because of the python
// GIL, which is always held when any TFE_Py_* methods are called. We should
// revisit this if/when we decide to not hold the GIL while manipulating the
// tape stack.
tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
  thread_local tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* tape_set{
      nullptr};
  if (tape_set == nullptr) {
    tape_set = new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>;
  }
  return tape_set;
}

// A safe copy of the current tapeset. Does not get affected by other python
// threads changing the set of active tapes.
class SafeTapeSet {
 public:
  SafeTapeSet() : tape_set_(*GetTapeSet()) {
    for (auto* tape : tape_set_) {
      Py_INCREF(tape);
    }
  }

  ~SafeTapeSet() {
    for (auto* tape : tape_set_) {
      Py_DECREF(tape);
    }
  }

  tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>::const_iterator begin() {
    return tape_set_.begin();
  }

  tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>::const_iterator end() {
    return tape_set_.end();
  }

 private:
  tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*> tape_set_;
};

bool* ThreadTapeIsStopped() {
  thread_local bool thread_tape_is_stopped{false};
  return &thread_tape_is_stopped;
}

void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }

void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }

PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
                            PyObject* watch_accessed_variables) {
  TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
  if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
  TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
  tape->tape = new GradientTape(persistent == Py_True,
                                watch_accessed_variables == Py_True);
  Py_INCREF(tape);
  GetTapeSet()->insert(reinterpret_cast<TFE_Py_Tape*>(tape));
  return reinterpret_cast<PyObject*>(tape);
}

TFE_Py_TapeSetAdd(PyObject * tape)1349 void TFE_Py_TapeSetAdd(PyObject* tape) {
1350   Py_INCREF(tape);
1351   if (!GetTapeSet()->insert(reinterpret_cast<TFE_Py_Tape*>(tape)).second) {
1352     // Already exists in the tape set.
1353     Py_DECREF(tape);
1354   }
1355 }
1356 
TFE_Py_TapeSetIsEmpty()1357 PyObject* TFE_Py_TapeSetIsEmpty() {
1358   if (*ThreadTapeIsStopped() || GetTapeSet()->empty()) {
1359     Py_RETURN_TRUE;
1360   }
1361   Py_RETURN_FALSE;
1362 }
1363 
TFE_Py_TapeSetRemove(PyObject * tape)1364 void TFE_Py_TapeSetRemove(PyObject* tape) {
1365   auto* stack = GetTapeSet();
1366   stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
1367   // We kept a reference to the tape in the set to ensure it wouldn't get
1368   // deleted under us; cleaning it up here.
1369   Py_DECREF(tape);
1370 }
1371 
static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
  if (list == Py_None) {
    return {};
  }
  PyObject* seq = PySequence_Fast(list, "expected a sequence");
  if (seq == nullptr) {
    return {};
  }
  int len = PySequence_Size(list);
  std::vector<tensorflow::int64> tensor_ids;
  tensor_ids.reserve(len);
  for (int i = 0; i < len; ++i) {
    PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
#if PY_MAJOR_VERSION >= 3
    if (PyLong_Check(item)) {
#else
    if (PyLong_Check(item) || PyInt_Check(item)) {
#endif
      tensorflow::int64 id = MakeInt(item);
      tensor_ids.push_back(id);
    } else {
      tensor_ids.push_back(-1);
    }
  }
  Py_DECREF(seq);
  return tensor_ids;
}

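// For illustration (an assumed example, not exercised in this file): given
// the Python list [1, 2, None], MakeIntList yields {1, 2, -1}; non-integer
// entries are mapped to -1 rather than raising an error.
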
PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
  if (tensors == Py_None) {
    Py_RETURN_FALSE;
  }
  if (*ThreadTapeIsStopped()) {
    Py_RETURN_FALSE;
  }
  auto* tape_set_ptr = GetTapeSet();
  if (tape_set_ptr->empty()) {
    Py_RETURN_FALSE;
  }
  PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
  if (seq == nullptr) {
    return nullptr;
  }
  int len = PySequence_Fast_GET_SIZE(seq);
  // TODO(apassos) consider not building a list and changing the API to check
  // each tensor individually.
  std::vector<tensorflow::int64> tensor_ids;
  std::vector<tensorflow::DataType> dtypes;
  tensor_ids.reserve(len);
  dtypes.reserve(len);
  for (int i = 0; i < len; ++i) {
    PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
    tensor_ids.push_back(FastTensorId(item));
    dtypes.push_back(FastTensorDtype(item));
  }
  Py_DECREF(seq);
  auto tape_set = *tape_set_ptr;
  for (TFE_Py_Tape* tape : tape_set) {
    if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
      Py_RETURN_TRUE;
    }
  }
  Py_RETURN_FALSE;
}

void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
  if (*ThreadTapeIsStopped()) {
    return;
  }
  tensorflow::int64 tensor_id = FastTensorId(tensor);
  if (PyErr_Occurred()) {
    return;
  }
  reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
}

bool ListContainsNone(PyObject* list) {
  if (list == Py_None) return true;
  tensorflow::Safe_PyObjectPtr seq(
      PySequence_Fast(list, "expected a sequence"));
  if (seq == nullptr) {
    return false;
  }

  int len = PySequence_Size(list);
  for (int i = 0; i < len; ++i) {
    PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
    if (item == Py_None) return true;
  }

  return false;
}

static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
  if (EagerTensor_CheckExact(tensor)) {
    TFE_TensorHandle* t = EagerTensor_Handle(tensor);
    tensorflow::int64 id = PyEagerTensor_ID(tensor);
    tensorflow::TensorShape tensor_shape;
    const tensorflow::Status status = t->handle->Shape(&tensor_shape);

    if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
      return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
                          tensorflow::TensorShape({}));
    } else {
      return PyTapeTensor(id, t->handle->dtype, tensor_shape);
    }
  }
  tensorflow::int64 id = FastTensorId(tensor);
  if (PyErr_Occurred()) {
    return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
                        tensorflow::TensorShape({}));
  }
  PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
  PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
  Py_DECREF(dtype_object);
  tensorflow::DataType dtype =
      static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
  Py_DECREF(dtype_enum);
  if (PyErr_Occurred()) {
    return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
                        tensorflow::TensorShape({}));
  }
  static char _shape_tuple[] = "_shape_tuple";
  PyObject* shape_tuple = PyObject_CallMethod(tensor, _shape_tuple, nullptr);
  if (PyErr_Occurred()) {
    return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
                        tensorflow::TensorShape({}));
  }

  if (ListContainsNone(shape_tuple)) {
    // Release the shape tuple before taking the dynamic-shape path so the
    // reference returned by _shape_tuple does not leak.
    Py_DECREF(shape_tuple);
    return PyTapeTensor(id, dtype, tensor);
  }

  auto l = MakeIntList(shape_tuple);
  Py_DECREF(shape_tuple);
  // Replace -1 (which represents accidental Nones that can occur in graph
  // mode and can cause errors in shape construction) with 0s.
  for (auto& c : l) {
    if (c < 0) {
      c = 0;
    }
  }
  tensorflow::TensorShape shape(l);
  return PyTapeTensor(id, dtype, shape);
}

std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
  PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
  if (seq == nullptr) {
    return {};
  }
  int len = PySequence_Fast_GET_SIZE(seq);
  std::vector<tensorflow::int64> list;
  list.reserve(len);
  for (int i = 0; i < len; ++i) {
    PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i);
    list.push_back(FastTensorId(tensor));
    if (PyErr_Occurred()) {
      Py_DECREF(seq);
      return list;
    }
  }
  Py_DECREF(seq);
  return list;
}

void TFE_Py_TapeVariableAccessed(PyObject* variable) {
  if (*ThreadTapeIsStopped()) {
    return;
  }
  for (TFE_Py_Tape* tape : SafeTapeSet()) {
    tape->tape->VariableAccessed(variable);
  }
}

void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) {
  if (*ThreadTapeIsStopped()) {
    return;
  }
  reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable);
}

PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
  return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
}

namespace {
std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
  PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
  if (seq == nullptr) {
    return {};
  }
  int len = PySequence_Fast_GET_SIZE(seq);
  std::vector<tensorflow::DataType> list;
  list.reserve(len);
  for (int i = 0; i < len; ++i) {
    PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i);
    list.push_back(FastTensorDtype(tensor));
  }
  Py_DECREF(seq);
  return list;
}

void TapeSetRecordOperation(
    PyObject* op_type, PyObject* output_tensors,
    const std::vector<tensorflow::int64>& input_ids,
    const std::vector<tensorflow::DataType>& input_dtypes,
    const std::function<PyBackwardFunction*()>& backward_function_getter,
    const std::function<void(PyBackwardFunction*)>& backward_function_killer) {
  std::vector<PyTapeTensor> output_info;
  PyObject* seq = PySequence_Fast(output_tensors,
                                  "expected a sequence of integer tensor ids");
  int len = PySequence_Size(output_tensors);
  if (PyErr_Occurred()) return;
  output_info.reserve(len);
  for (int i = 0; i < len; ++i) {
    output_info.push_back(
        TapeTensorFromTensor(PySequence_Fast_GET_ITEM(seq, i)));
    if (PyErr_Occurred() != nullptr) {
      Py_DECREF(seq);
      return;
    }
  }
  Py_DECREF(seq);
  string op_type_str;
  if (PyBytes_Check(op_type)) {
    op_type_str = PyBytes_AsString(op_type);
  } else if (PyUnicode_Check(op_type)) {
#if PY_MAJOR_VERSION >= 3
    op_type_str = PyUnicode_AsUTF8(op_type);
#else
    PyObject* py_str = PyUnicode_AsUTF8String(op_type);
    if (py_str == nullptr) return;
    op_type_str = PyBytes_AS_STRING(py_str);
    Py_DECREF(py_str);
#endif
  } else {
    PyErr_SetString(PyExc_RuntimeError, "op_type should be a string.");
    return;
  }

  for (TFE_Py_Tape* tape : SafeTapeSet()) {
    tape->tape->RecordOperation(op_type_str, output_info, input_ids,
                                input_dtypes, backward_function_getter,
                                backward_function_killer);
  }
}
}  // namespace

void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
                                   PyObject* input_tensors,
                                   PyObject* backward_function) {
  if (GetTapeSet()->empty() || *ThreadTapeIsStopped()) {
    return;
  }
  std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
  if (PyErr_Occurred()) return;

  std::vector<tensorflow::DataType> input_dtypes =
      MakeTensorDtypeList(input_tensors);
  if (PyErr_Occurred()) return;

  TapeSetRecordOperation(
      op_type, output_tensors, input_ids, input_dtypes,
      [backward_function]() {
        Py_INCREF(backward_function);
        PyBackwardFunction* function =
            new PyBackwardFunction([backward_function](PyObject* out_grads) {
              return PyObject_CallObject(backward_function, out_grads);
            });
        return function;
      },
      [backward_function](PyBackwardFunction* py_backward_function) {
        Py_DECREF(backward_function);
        delete py_backward_function;
      });
}

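// Lifetime note (editorial sketch): the getter/killer pair above forms a
// simple ownership protocol. Each call to the getter INCREFs
// backward_function and heap-allocates a PyBackwardFunction wrapper; the
// killer releases exactly one reference and deletes the wrapper. Assuming
// the tape invokes the killer once per getter call, the net refcount change
// on backward_function is zero:
//
//   PyBackwardFunction* fn = getter();  // refcount(backward_function) + 1
//   // ... possibly call (*fn)(out_grads) during gradient computation ...
//   killer(fn);                         // refcount(backward_function) - 1
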
void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
  for (TFE_Py_Tape* tape : SafeTapeSet()) {
    tape->tape->DeleteTrace(tensor_id);
  }
}

std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
  PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
  if (seq == nullptr) {
    return {};
  }
  int len = PySequence_Fast_GET_SIZE(seq);
  std::vector<PyObject*> list;
  list.reserve(len);
  for (int i = 0; i < len; ++i) {
    list.push_back(PySequence_Fast_GET_ITEM(seq, i));
  }
  Py_DECREF(seq);
  return list;
}

PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
                              PyObject* sources, PyObject* output_gradients,
                              PyObject* unconnected_gradients,
                              TF_Status* status) {
  TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
  if (!tape_obj->tape->IsPersistent()) {
    auto* tape_set = GetTapeSet();
    if (tape_set->find(tape_obj) != tape_set->end()) {
      PyErr_SetString(PyExc_RuntimeError,
                      "gradient() cannot be invoked within the "
                      "GradientTape context (i.e., while operations are being "
                      "recorded). Either move the call to gradient() to be "
                      "outside the 'with tf.GradientTape' block, or "
                      "use a persistent tape: "
                      "'with tf.GradientTape(persistent=True)'");
      return nullptr;
    }
  }

  std::vector<tensorflow::int64> target_vec = MakeTensorIDList(target);
  if (PyErr_Occurred()) {
    return nullptr;
  }
  std::vector<tensorflow::int64> sources_vec = MakeTensorIDList(sources);
  if (PyErr_Occurred()) {
    return nullptr;
  }
  tensorflow::gtl::FlatSet<tensorflow::int64> sources_set(sources_vec.begin(),
                                                          sources_vec.end());

  tensorflow::Safe_PyObjectPtr seq =
      tensorflow::make_safe(PySequence_Fast(target, "expected a sequence"));
  int len = PySequence_Fast_GET_SIZE(seq.get());
  tensorflow::gtl::FlatMap<tensorflow::int64, PyTapeTensor>
      source_tensors_that_are_targets;
  for (int i = 0; i < len; ++i) {
    tensorflow::int64 target_id = target_vec[i];
    if (sources_set.find(target_id) != sources_set.end()) {
      auto tensor = PySequence_Fast_GET_ITEM(seq.get(), i);
      source_tensors_that_are_targets.insert(
          std::make_pair(target_id, TapeTensorFromTensor(tensor)));
    }
    if (PyErr_Occurred()) {
      return nullptr;
    }
  }
  if (PyErr_Occurred()) {
    return nullptr;
  }

  std::vector<PyObject*> outgrad_vec;
  if (output_gradients != Py_None) {
    outgrad_vec = MakeTensorList(output_gradients);
    if (PyErr_Occurred()) {
      return nullptr;
    }
    for (PyObject* tensor : outgrad_vec) {
      // Calling the backward function will eat a reference to the tensors in
      // outgrad_vec, so we need to increase their reference count.
      Py_INCREF(tensor);
    }
  }
  std::vector<PyObject*> result;
  status->status = tape_obj->tape->ComputeGradient(
      *py_vspace, target_vec, sources_vec, source_tensors_that_are_targets,
      outgrad_vec, &result);
  if (!status->status.ok()) {
    if (PyErr_Occurred()) {
      // Do not propagate the erroneous status as that would swallow the
      // exception which caused the problem.
      status->status = tensorflow::Status::OK();
    }
    return nullptr;
  }

  bool unconnected_gradients_zero =
      strcmp(TFE_GetPythonString(unconnected_gradients), "zero") == 0;
  std::vector<PyObject*> sources_obj;
  if (unconnected_gradients_zero) {
    sources_obj = MakeTensorList(sources);
  }

  if (!result.empty()) {
    PyObject* py_result = PyList_New(result.size());
    tensorflow::gtl::FlatSet<PyObject*> seen_results(result.size());
    for (int i = 0; i < result.size(); ++i) {
      if (result[i] == nullptr) {
        if (unconnected_gradients_zero) {
          // Generate a zeros tensor in the shape of sources[i].
          tensorflow::DataType dtype = FastTensorDtype(sources_obj[i]);
          PyTapeTensor tensor =
              PyTapeTensor(sources_vec[i], dtype, sources_obj[i]);
          result[i] = py_vspace->Zeros(tensor);
        } else {
          Py_INCREF(Py_None);
          result[i] = Py_None;
        }
      } else if (seen_results.find(result[i]) != seen_results.end()) {
        Py_INCREF(result[i]);
      }
      seen_results.insert(result[i]);
      PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]));
    }
    return py_result;
  }
  return PyList_New(0);
}

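// Behavior sketch (editorial, values assumed for illustration): if sources
// has three entries and only the first is connected to the target, the
// result list is [g0, None, None] by default, and
// [g0, zeros_like(s1), zeros_like(s2)] when unconnected_gradients == "zero".
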
namespace {
static const int kFastPathExecuteInputStartIndex = 5;

PyObject* GetPythonObjectFromString(const char* s) {
#if PY_MAJOR_VERSION >= 3
  return PyUnicode_FromString(s);
#else
  return PyBytes_FromString(s);
#endif
}

PyObject* GetPythonObjectFromInt(int num) {
#if PY_MAJOR_VERSION >= 3
  return PyLong_FromLong(num);
#else
  return PyInt_FromLong(num);
#endif
}

bool CheckResourceVariable(PyObject* item) {
  return PyObject_TypeCheck(item, resource_variable_type);
}

bool IsNumberType(PyObject* item) {
#if PY_MAJOR_VERSION >= 3
  return PyFloat_Check(item) || PyLong_Check(item);
#else
  return PyFloat_Check(item) || PyInt_Check(item) || PyLong_Check(item);
#endif
}

bool CheckOneInput(PyObject* item) {
  if (EagerTensor_CheckExact(item) || CheckResourceVariable(item) ||
      PyArray_Check(item) || IsNumberType(item)) {
    return true;
  }

  // Sequences are not properly handled. Sequences with purely python numeric
  // types work, but sequences with mixes of EagerTensors and python numeric
  // types don't work.
  // TODO(nareshmodi): fix
  return false;
}

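// For illustration (assumed examples): CheckOneInput accepts an EagerTensor,
// a ResourceVariable, a numpy array, or a plain Python number such as 3.0;
// any sequence (e.g. [eager_tensor, 3.0]) is rejected, which makes
// CheckInputsOk below fall back to the slow path when a sequence is supplied
// for a non-list input.
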
bool CheckInputsOk(PyObject* seq, int start_index,
                   const tensorflow::OpDef& op_def) {
  for (int i = 0; i < op_def.input_arg_size(); i++) {
    PyObject* item = PyTuple_GET_ITEM(seq, i + start_index);
    if (!op_def.input_arg(i).number_attr().empty() ||
        !op_def.input_arg(i).type_list_attr().empty()) {
      // This item should be a seq input.
      if (!PySequence_Check(item)) {
        VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
                << "\", Input \"" << op_def.input_arg(i).name()
                << "\" since we expected a sequence, but got "
                << item->ob_type->tp_name;
        return false;
      }
      for (Py_ssize_t j = 0; j < PySequence_Fast_GET_SIZE(item); j++) {
        PyObject* inner_item = PySequence_Fast_GET_ITEM(item, j);
        if (!CheckOneInput(inner_item)) {
          VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
                  << "\", Input \"" << op_def.input_arg(i).name()
                  << "\", Index " << j
                  << " since we expected an EagerTensor/ResourceVariable, "
                     "but got "
                  << inner_item->ob_type->tp_name;
          return false;
        }
      }
    } else if (!CheckOneInput(item)) {
      VLOG(1)
          << "Falling back to slow path for Op \"" << op_def.name()
          << "\", Input \"" << op_def.input_arg(i).name()
          << "\" since we expected an EagerTensor/ResourceVariable, but got "
          << item->ob_type->tp_name;
      return false;
    }
  }

  return true;
}

PyObject* MaybeGetDType(PyObject* item) {
  if (EagerTensor_CheckExact(item)) {
    tensorflow::Safe_PyObjectPtr py_dtype(
        PyObject_GetAttrString(item, "dtype"));
    return PyObject_GetAttrString(py_dtype.get(), "_type_enum");
  }

  if (CheckResourceVariable(item)) {
    tensorflow::Safe_PyObjectPtr py_dtype(
        PyObject_GetAttrString(item, "_dtype"));
    return PyObject_GetAttrString(py_dtype.get(), "_type_enum");
  }

  return nullptr;
}

PyObject* MaybeGetDTypeForAttr(const string& attr,
                               FastPathOpExecInfo* op_exec_info) {
  auto cached_it = op_exec_info->cached_dtypes.find(attr);
  if (cached_it != op_exec_info->cached_dtypes.end()) {
    return GetPythonObjectFromInt(cached_it->second);
  }

  auto it = op_exec_info->attr_to_inputs_map->find(attr);
  if (it == op_exec_info->attr_to_inputs_map->end()) {
    // No other inputs - this should never happen.
    Py_RETURN_NONE;
  }

  for (const auto& input_info : it->second) {
    PyObject* item = PyTuple_GET_ITEM(
        op_exec_info->args, kFastPathExecuteInputStartIndex + input_info.i);
    if (input_info.is_list) {
      for (int i = 0; i < PySequence_Fast_GET_SIZE(item); i++) {
        auto* dtype = MaybeGetDType(PySequence_Fast_GET_ITEM(item, i));
        if (dtype != nullptr) return dtype;
      }
    } else {
      auto* dtype = MaybeGetDType(item);
      if (dtype != nullptr) return dtype;
    }
  }

  Py_RETURN_NONE;
}

// TODO(agarwal): use an automatic mechanism for handling None arguments to
// gradient functions.

// Looks up whether `op_name` has a registered entry describing the outputs
// its gradient function does not need. On a hit, returns true and points
// *output at a pair whose first value indicates whether or not all outputs
// are unused. If the first value is false, the second value is a set that
// identifies which of the output indices are unused.
bool OpGradientDoesntRequireOutputIndices(
    const string& op_name,
    std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) {
  static tensorflow::gtl::FlatMap<
      string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m =
      new tensorflow::gtl::FlatMap<
          string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({
          // Ops that don't require any outputs.
          {"Identity", {true, {}}},
          {"MatMul", {true, {}}},
          {"Conv2DBackpropInput", {true, {}}},
          {"Conv2DBackpropFilter", {true, {}}},
          {"Conv3D", {true, {}}},
          {"Conv3DBackpropInputV2", {true, {}}},
          {"AvgPool3D", {true, {}}},
          {"AvgPool3DGrad", {true, {}}},
          {"MaxPool3D", {false, {}}},
          {"MaxPool3DGrad", {true, {}}},
          {"MaxPool3DGradGrad", {true, {}}},
          {"BiasAdd", {true, {}}},
          {"BiasAddV1", {true, {}}},
          {"BiasAddGrad", {true, {}}},
          {"Softplus", {true, {}}},
          {"SoftplusGrad", {true, {}}},
          {"Softsign", {true, {}}},
          {"ReluGrad", {true, {}}},
          {"LeakyRelu", {true, {}}},
          {"LeakyReluGrad", {true, {}}},
          {"Conv2D", {true, {}}},
          {"DepthwiseConv2dNative", {true, {}}},
          {"Dilation2D", {true, {}}},
          {"AvgPool", {true, {}}},
          {"AvgPoolGrad", {true, {}}},
          {"BatchNormWithGlobalNormalization", {true, {}}},
          {"L2Loss", {true, {}}},
          {"Sum", {true, {}}},
          {"Prod", {true, {}}},
          {"SegmentSum", {true, {}}},
          {"SegmentMean", {true, {}}},
          {"SparseSegmentSum", {true, {}}},
          {"SparseSegmentMean", {true, {}}},
          {"SparseSegmentSqrtN", {true, {}}},
          {"UnsortedSegmentSum", {true, {}}},
          {"UnsortedSegmentMax", {true, {}}},
          {"Abs", {true, {}}},
          {"Neg", {true, {}}},
          {"ReciprocalGrad", {true, {}}},
          {"Square", {true, {}}},
          {"Expm1", {true, {}}},
          {"Log", {true, {}}},
          {"Log1p", {true, {}}},
          {"TanhGrad", {true, {}}},
          {"SigmoidGrad", {true, {}}},
          {"Sign", {true, {}}},
          {"Sin", {true, {}}},
          {"Cos", {true, {}}},
          {"Tan", {true, {}}},
          {"Add", {true, {}}},
          {"Sub", {true, {}}},
          {"Mul", {true, {}}},
          {"Div", {true, {}}},
          {"RealDiv", {true, {}}},
          {"Maximum", {true, {}}},
          {"Minimum", {true, {}}},
          {"SquaredDifference", {true, {}}},
          {"Select", {true, {}}},
          {"SparseMatMul", {true, {}}},
          {"BatchMatMul", {true, {}}},
          {"Complex", {true, {}}},
          {"Real", {true, {}}},
          {"Imag", {true, {}}},
          {"Angle", {true, {}}},
          {"Conj", {true, {}}},
          {"Cast", {true, {}}},
          {"Cross", {true, {}}},
          {"Cumsum", {true, {}}},
          {"Cumprod", {true, {}}},
          {"ReadVariableOp", {true, {}}},
          {"VarHandleOp", {true, {}}},
          {"Shape", {true, {}}},
          {"StridedSlice", {true, {}}},
          {"Fill", {true, {}}},

          // Ops that don't require a subset of outputs.
          {"FusedBatchNorm", {false, {0, 1, 2}}},
      });

  auto it = m->find(op_name);

  if (it == m->end()) return false;

  *output = &it->second;
  return true;
}

// Looks up whether `op_name` has a registered entry describing the inputs
// its gradient function does not need. On a hit, returns true and points
// *output at a pair whose first value indicates whether or not all inputs
// are unused. If the first value is false, the second value is a set that
// identifies which of the input indices are unused.
bool OpGradientDoesntRequireInputIndices(
    const string& op_name,
    std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) {
  static tensorflow::gtl::FlatMap<
      string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m =
      new tensorflow::gtl::FlatMap<
          string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({
          // Ops that don't require any inputs.
          {"Identity", {true, {}}},
          {"Softmax", {true, {}}},
          {"LogSoftmax", {true, {}}},
          {"BiasAdd", {true, {}}},
          {"Relu", {true, {}}},
          {"Relu6", {true, {}}},
          {"Elu", {true, {}}},
          {"Selu", {true, {}}},
          {"SparseSoftmaxCrossEntropyWithLogits", {true, {}}},
          {"Neg", {true, {}}},
          {"Inv", {true, {}}},
          {"Reciprocal", {true, {}}},
          {"Sqrt", {true, {}}},
          {"Exp", {true, {}}},
          {"Tanh", {true, {}}},
          {"Sigmoid", {true, {}}},
          {"Real", {true, {}}},
          {"Imag", {true, {}}},
          {"Conj", {true, {}}},
          {"ReadVariableOp", {true, {}}},
          {"VarHandleOp", {true, {}}},
          {"Shape", {true, {}}},
          {"Fill", {true, {}}},

          // Ops that don't require a subset of inputs.
          {"FusedBatchNorm", {false, {2}}},
      });

  auto it = m->find(op_name);

  if (it == m->end()) return false;

  *output = &it->second;
  return true;
}

PyObject* CopySequenceSettingIndicesToNull(
    PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) {
  tensorflow::Safe_PyObjectPtr fast_seq(
      PySequence_Fast(seq, "unable to allocate"));
  PyObject* result = PyTuple_New(PySequence_Fast_GET_SIZE(fast_seq.get()));
  for (int i = 0; i < PySequence_Fast_GET_SIZE(fast_seq.get()); i++) {
    PyObject* item;
    if (indices.find(i) != indices.end()) {
      item = Py_None;
    } else {
      item = PySequence_Fast_GET_ITEM(fast_seq.get(), i);
    }
    Py_INCREF(item);
    PyTuple_SET_ITEM(result, i, item);
  }
  return result;
}

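// For illustration (assumed example): with seq == (a, b, c) and
// indices == {1}, the result is the new tuple (a, None, c); every element,
// including Py_None, gets its own reference.
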
PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
                         PyObject* results, PyObject* name) {
  std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
  if (PyErr_Occurred()) return nullptr;
  std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
  if (PyErr_Occurred()) return nullptr;

  bool should_record = false;
  for (TFE_Py_Tape* tape : SafeTapeSet()) {
    if (tape->tape->ShouldRecord(input_ids, input_dtypes)) {
      should_record = true;
      break;
    }
  }
  if (!should_record) Py_RETURN_NONE;

  string c_op_name = TFE_GetPythonString(op_name);

  PyObject* op_outputs;
  bool op_outputs_tuple_created = false;
  std::pair<bool, tensorflow::gtl::FlatSet<int>>* outputs_not_required;

  if (OpGradientDoesntRequireOutputIndices(c_op_name, &outputs_not_required)) {
    if (outputs_not_required->first) {
      op_outputs = Py_None;
    } else {
      op_outputs_tuple_created = true;
      op_outputs = CopySequenceSettingIndicesToNull(
          results, outputs_not_required->second);
    }
  } else {
    op_outputs = results;
  }

  PyObject* op_inputs;
  bool op_inputs_tuple_created = false;
  std::pair<bool, tensorflow::gtl::FlatSet<int>>* inputs_not_required;

  if (OpGradientDoesntRequireInputIndices(c_op_name, &inputs_not_required)) {
    if (inputs_not_required->first) {
      op_inputs = Py_None;
    } else {
      op_inputs_tuple_created = true;
      op_inputs =
          CopySequenceSettingIndicesToNull(inputs, inputs_not_required->second);
    }
  } else {
    op_inputs = inputs;
  }

  PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));

  TapeSetRecordOperation(
      op_name, results, input_ids, input_dtypes,
      [op_name, attrs, num_inputs, op_inputs, op_outputs]() {
        Py_INCREF(op_name);
        Py_INCREF(attrs);
        Py_INCREF(num_inputs);
        Py_INCREF(op_inputs);
        Py_INCREF(op_outputs);
        PyBackwardFunction* function =
            new PyBackwardFunction([op_name, attrs, num_inputs, op_inputs,
                                    op_outputs](PyObject* output_grads) {
              if (PyErr_Occurred()) {
                return static_cast<PyObject*>(nullptr);
              }
              tensorflow::Safe_PyObjectPtr callback_args(
                  Py_BuildValue("OOOOOO", op_name, attrs, num_inputs, op_inputs,
                                op_outputs, output_grads));

              tensorflow::Safe_PyObjectPtr result(
                  PyObject_CallObject(gradient_function, callback_args.get()));

              if (PyErr_Occurred()) return static_cast<PyObject*>(nullptr);

              return tensorflow::swig::Flatten(result.get());
            });
        return function;
      },
      [op_name, attrs, num_inputs, op_inputs,
       op_outputs](PyBackwardFunction* backward_function) {
        Py_DECREF(op_name);
        Py_DECREF(attrs);
        Py_DECREF(num_inputs);
        Py_DECREF(op_inputs);
        Py_DECREF(op_outputs);

        delete backward_function;
      });

  Py_DECREF(num_inputs);
  if (op_outputs_tuple_created) Py_DECREF(op_outputs);
  if (op_inputs_tuple_created) Py_DECREF(op_inputs);

  Py_RETURN_NONE;
}

void MaybeNotifyVariableAccessed(PyObject* input) {
  DCHECK(CheckResourceVariable(input));
  DCHECK(PyObject_HasAttrString(input, "_trainable"));

  tensorflow::Safe_PyObjectPtr trainable(
      PyObject_GetAttrString(input, "_trainable"));
  if (trainable.get() == Py_False) return;
  TFE_Py_TapeVariableAccessed(input);
}

bool CastTensor(const FastPathOpExecInfo& op_exec_info,
                const TF_DataType& desired_dtype,
                tensorflow::Safe_TFE_TensorHandlePtr* handle,
                TF_Status* status) {
  TF_DataType input_dtype = TFE_TensorHandleDataType(handle->get());
  TF_DataType output_dtype = input_dtype;

  if (desired_dtype >= 0 && desired_dtype != input_dtype) {
    *handle = tensorflow::make_safe(
        tensorflow::EagerCast(op_exec_info.ctx, handle->get(), input_dtype,
                              static_cast<TF_DataType>(desired_dtype), status));
    if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
      return false;
    }
    output_dtype = desired_dtype;
  }

  if (output_dtype != TF_INT32) {
    // Note that this is a shallow copy and will share the underlying buffer
    // if copying to the same device.
    *handle = tensorflow::make_safe(TFE_TensorHandleCopyToDevice(
        handle->get(), op_exec_info.ctx, op_exec_info.device_name, status));
    if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
      return false;
    }
  }
  return true;
}

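// Device note (editorial reading of the branch above): handles are copied to
// op_exec_info.device_name except when the output dtype is TF_INT32, which
// is left where it is; this appears to mirror eager mode's convention of
// keeping int32 tensors on the host for shape-like computations.
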
bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
                    PyObject* input, tensorflow::Safe_PyObjectPtr* output,
                    TF_Status* status) {
  MaybeNotifyVariableAccessed(input);

  TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status);
  auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); });
  if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;

  // Set dtype
  DCHECK(PyObject_HasAttrString(input, "_dtype"));
  tensorflow::Safe_PyObjectPtr dtype(PyObject_GetAttrString(input, "_dtype"));
  int value;
  if (!ParseTypeValue("_dtype", dtype.get(), status, &value)) {
    return false;
  }
  TFE_OpSetAttrType(op, "dtype", static_cast<TF_DataType>(value));

  TFE_OpSetDevice(op, parent_op_exec_info.device_name, status);
  if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;

  // Get handle
  tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(input, "_handle"));
  if (!EagerTensor_CheckExact(handle.get())) return false;
  TFE_OpAddInput(op, EagerTensor_Handle(handle.get()), status);
  if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;

  int num_retvals = 1;
  TFE_TensorHandle* output_handle;
  TFE_Execute(op, &output_handle, &num_retvals, status);
  if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;

  // Always create the py object (and correctly DECREF it) from the returned
  // value, else the data will leak.
  output->reset(EagerTensorFromHandle(output_handle));

  // TODO(nareshmodi): Should we run post exec callbacks here?
  if (parent_op_exec_info.run_gradient_callback) {
    tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(1));
    PyTuple_SET_ITEM(inputs.get(), 0, handle.release());

    tensorflow::Safe_PyObjectPtr outputs(PyTuple_New(1));
    Py_INCREF(output->get());  // stay alive after since tuple steals.
    PyTuple_SET_ITEM(outputs.get(), 0, output->get());

    tensorflow::Safe_PyObjectPtr op_string(
        GetPythonObjectFromString("ReadVariableOp"));
    if (!RecordGradient(op_string.get(), inputs.get(), Py_None, outputs.get(),
                        Py_None)) {
      return false;
    }
  }

  return true;
}

// Supports 3 cases at the moment:
//  i) input is an EagerTensor.
//  ii) input is a ResourceVariable - in this case it is read into a tensor
//  via ReadVariableOp.
//  iii) input is an arbitrary python list/tuple (note, this handling doesn't
//  support packing).
//
//  NOTE: dtype_hint_getter must *always* return a PyObject that can be
//  decref'd. So if no hint is found, Py_RETURN_NONE (which correctly
//  increfs Py_None).
//
//  NOTE: This function sets a python error directly, and returns false.
//  TF_Status is only passed since we don't want to have to reallocate it.
bool ConvertToTensor(
    const FastPathOpExecInfo& op_exec_info, PyObject* input,
    tensorflow::Safe_PyObjectPtr* output_handle,
    // This gets a hint for this particular input.
    const std::function<PyObject*()>& dtype_hint_getter,
    // This sets the dtype after conversion is complete.
    const std::function<void(const TF_DataType& dtype)>& dtype_setter,
    TF_Status* status) {
  if (EagerTensor_CheckExact(input)) {
    Py_INCREF(input);
    output_handle->reset(input);
    return true;
  } else if (CheckResourceVariable(input)) {
    return ReadVariableOp(op_exec_info, input, output_handle, status);
  }

  // The hint comes from a supposedly similarly typed tensor.
  tensorflow::Safe_PyObjectPtr dtype_hint(dtype_hint_getter());
  if (PyErr_Occurred()) {
    return false;
  }

  tensorflow::Safe_TFE_TensorHandlePtr handle =
      tensorflow::make_safe(static_cast<TFE_TensorHandle*>(
          tensorflow::ConvertToEagerTensor(input, dtype_hint.get())));
  if (handle == nullptr) {
    return MaybeRaiseExceptionFromTFStatus(status, nullptr);
  }

  int desired_dtype = -1;
  if (dtype_hint.get() != Py_None) {
    if (!ParseTypeValue("", dtype_hint.get(), status, &desired_dtype)) {
      PyErr_SetString(PyExc_TypeError,
                      tensorflow::strings::StrCat(
                          "Expecting a DataType value for dtype. Got ",
                          Py_TYPE(dtype_hint.get())->tp_name)
                          .c_str());
      return false;
    }
  }

  // Maybe cast to the desired type. This is intended to match python
  // convert_to_tensor behavior.
  TF_DataType output_dtype = TFE_TensorHandleDataType(handle.get());
  if (desired_dtype >= 0 && desired_dtype != output_dtype) {
    if (tensorflow::IsCompatible(desired_dtype, output_dtype)) {
      if (!CastTensor(op_exec_info, static_cast<TF_DataType>(desired_dtype),
                      &handle, status)) {
        return false;
      }
      output_dtype = TFE_TensorHandleDataType(handle.get());
    } else {
      tensorflow::Safe_PyObjectPtr input_str(PyObject_Str(input));
      PyErr_SetString(
          PyExc_TypeError,
          tensorflow::strings::StrCat(
              "Cannot convert provided value to EagerTensor. Provided value: ",
              TFE_GetPythonString(input_str.get()), " Requested dtype: ",
              tensorflow::DataTypeString(
                  static_cast<tensorflow::DataType>(desired_dtype)))
              .c_str());
      return false;
    }
  }

  output_handle->reset(EagerTensorFromHandle(handle.release()));
  dtype_setter(output_dtype);

  return true;
}

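// Contract sketch (editorial): dtype_hint_getter must return an owned
// reference on every path, which is why a "no hint" result is Py_None with
// its refcount bumped rather than nullptr. A minimal conforming getter, as
// used for heterogeneous list inputs further down, is:
//
//   []() { Py_RETURN_NONE; }
//
// and a matching no-op dtype_setter is [](const TF_DataType& dtype) {}.
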
// Adds input and type attr to the op, and to the list of flattened
// inputs/attrs.
bool AddInputToOp(FastPathOpExecInfo* op_exec_info, PyObject* input,
                  const bool add_type_attr,
                  const tensorflow::OpDef::ArgDef& input_arg,
                  std::vector<tensorflow::Safe_PyObjectPtr>* flattened_attrs,
                  std::vector<tensorflow::Safe_PyObjectPtr>* flattened_inputs,
                  TFE_Op* op, TF_Status* status) {
  // py_eager_tensor's ownership is transferred to flattened_inputs if it is
  // required, else the object is destroyed and DECREF'd when the object goes
  // out of scope in this function.
  tensorflow::Safe_PyObjectPtr py_eager_tensor = nullptr;

  if (!ConvertToTensor(
          *op_exec_info, input, &py_eager_tensor,
          [&]() {
            if (input_arg.type() != tensorflow::DataType::DT_INVALID) {
              return GetPythonObjectFromInt(input_arg.type());
            }
            return MaybeGetDTypeForAttr(input_arg.type_attr(), op_exec_info);
          },
          [&](const TF_DataType dtype) {
            op_exec_info->cached_dtypes[input_arg.type_attr()] =
                static_cast<tensorflow::DataType>(dtype);
          },
          status)) {
    return false;
  }

  TFE_TensorHandle* input_handle = EagerTensor_Handle(py_eager_tensor.get());

  if (add_type_attr && !input_arg.type_attr().empty()) {
    auto dtype = TFE_TensorHandleDataType(input_handle);
    TFE_OpSetAttrType(op, input_arg.type_attr().data(), dtype);
    if (flattened_attrs != nullptr) {
      flattened_attrs->emplace_back(
          GetPythonObjectFromString(input_arg.type_attr().data()));
      flattened_attrs->emplace_back(PyLong_FromLong(dtype));
    }
  }

  if (flattened_inputs != nullptr) {
    flattened_inputs->emplace_back(std::move(py_eager_tensor));
  }

  TFE_OpAddInput(op, input_handle, status);
  if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
    return false;
  }

  return true;
}

const tensorflow::OpDef* GetOpDef(PyObject* py_op_name) {
  const char* op_name = TFE_GetPythonString(py_op_name);
  if (op_name == nullptr) {
    PyErr_SetString(PyExc_TypeError,
                    Printf("expected a string for op_name, got %s instead",
                           py_op_name->ob_type->tp_name)
                        .c_str());
    return nullptr;
  }

  const tensorflow::OpRegistrationData* op_reg_data = nullptr;
  const tensorflow::Status lookup_status =
      tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data);
  if (MaybeRaiseExceptionFromStatus(lookup_status, nullptr)) {
    return nullptr;
  }
  return &op_reg_data->op_def;
}

const char* GetDeviceName(PyObject* py_device_name) {
  if (py_device_name != Py_None) {
    return TFE_GetPythonString(py_device_name);
  }
  return nullptr;
}

bool RaiseIfNotPySequence(PyObject* seq, const string& attr_name) {
  if (!PySequence_Check(seq)) {
    PyErr_SetString(PyExc_TypeError,
                    Printf("expected a sequence for attr %s, got %s instead",
                           attr_name.data(), seq->ob_type->tp_name)
                        .data());

    return false;
  }
  return true;
}

bool RunCallbacks(
    const FastPathOpExecInfo& op_exec_info, PyObject* args,
    const std::vector<tensorflow::Safe_PyObjectPtr>* const flattened_inputs,
    const std::vector<tensorflow::Safe_PyObjectPtr>* const flattened_attrs,
    PyObject* flattened_result) {
  if (!op_exec_info.run_callbacks) return true;

  tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(flattened_inputs->size()));
  for (int i = 0; i < flattened_inputs->size(); i++) {
    PyObject* input = (*flattened_inputs)[i].get();
    Py_INCREF(input);
    PyTuple_SET_ITEM(inputs.get(), i, input);
  }

  int num_non_inferred_attrs = PyTuple_GET_SIZE(args) -
                               op_exec_info.op_def->input_arg_size() -
                               kFastPathExecuteInputStartIndex;
  int num_attrs = flattened_attrs->size() + num_non_inferred_attrs;
  tensorflow::Safe_PyObjectPtr attrs(PyTuple_New(num_attrs));

  for (int i = 0; i < num_non_inferred_attrs; i++) {
    auto* attr =
        PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex +
                                   op_exec_info.op_def->input_arg_size() + i);
    Py_INCREF(attr);
    PyTuple_SET_ITEM(attrs.get(), i, attr);
  }
  for (int i = num_non_inferred_attrs; i < num_attrs; i++) {
    PyObject* attr_or_name =
        flattened_attrs->at(i - num_non_inferred_attrs).get();
    Py_INCREF(attr_or_name);
    PyTuple_SET_ITEM(attrs.get(), i, attr_or_name);
  }

  if (op_exec_info.run_gradient_callback) {
    if (!RecordGradient(op_exec_info.op_name, inputs.get(), attrs.get(),
                        flattened_result, op_exec_info.name)) {
      return false;
    }
  }

  if (op_exec_info.run_post_exec_callbacks) {
    tensorflow::Safe_PyObjectPtr callback_args(
        Py_BuildValue("OOOOO", op_exec_info.op_name, inputs.get(), attrs.get(),
                      flattened_result, op_exec_info.name));
    for (Py_ssize_t i = 0; i < PyList_Size(op_exec_info.callbacks); i++) {
      PyObject* callback_fn = PyList_GET_ITEM(op_exec_info.callbacks, i);
      if (!PyCallable_Check(callback_fn)) {
        PyErr_SetString(
            PyExc_TypeError,
            Printf("expected a function for "
                   "post execution callback at index %ld, got %s instead",
                   i, callback_fn->ob_type->tp_name)
                .c_str());
        return false;
      }
      PyObject* callback_result =
          PyObject_CallObject(callback_fn, callback_args.get());
      if (!callback_result) {
        return false;
      }
      Py_DECREF(callback_result);
    }
  }

  return true;
}

}  // namespace

PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
  Py_ssize_t args_size = PyTuple_GET_SIZE(args);
  if (args_size < kFastPathExecuteInputStartIndex) {
    PyErr_SetString(
        PyExc_ValueError,
        Printf("There must be at least %d items in the input tuple.",
               kFastPathExecuteInputStartIndex)
            .c_str());
    return nullptr;
  }

  FastPathOpExecInfo op_exec_info;

  op_exec_info.ctx = reinterpret_cast<TFE_Context*>(
      PyCapsule_GetPointer(PyTuple_GET_ITEM(args, 0), nullptr));
  op_exec_info.args = args;

  if (op_exec_info.ctx == nullptr) {
    // The context hasn't been initialized. It will be in the slow path.
    RaiseFallbackException(
        "This function does not handle the case where not all inputs are "
        "already EagerTensors.");
    return nullptr;
  }

  op_exec_info.device_name = GetDeviceName(PyTuple_GET_ITEM(args, 1));
  op_exec_info.op_name = PyTuple_GET_ITEM(args, 2);
  op_exec_info.op_def = GetOpDef(op_exec_info.op_name);
  if (op_exec_info.op_def == nullptr) return nullptr;
  op_exec_info.name = PyTuple_GET_ITEM(args, 3);
  op_exec_info.callbacks = PyTuple_GET_ITEM(args, 4);

  const tensorflow::OpDef* op_def = op_exec_info.op_def;

  // TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks
  // (similar to benchmark_tf_gradient_function_*). Also consider using an
  // InlinedVector for flattened_attrs and flattened_inputs if the benchmarks
  // point out problems with heap allocs.
  op_exec_info.run_gradient_callback =
      !*ThreadTapeIsStopped() && !GetTapeSet()->empty();
  op_exec_info.run_post_exec_callbacks =
      op_exec_info.callbacks != Py_None &&
      PyList_Size(op_exec_info.callbacks) > 0;
  op_exec_info.run_callbacks = op_exec_info.run_gradient_callback ||
                               op_exec_info.run_post_exec_callbacks;

  if (args_size < kFastPathExecuteInputStartIndex + op_def->input_arg_size()) {
    PyErr_SetString(
        PyExc_ValueError,
        Printf("Tuple size smaller than intended. Expected to be at least %d, "
               "was %ld",
               kFastPathExecuteInputStartIndex + op_def->input_arg_size(),
               args_size)
            .c_str());
    return nullptr;
  }

  if (!CheckInputsOk(args, kFastPathExecuteInputStartIndex, *op_def)) {
    RaiseFallbackException(
        "This function does not handle the case where not all inputs are "
        "already EagerTensors.");
    return nullptr;
  }

  op_exec_info.attr_to_inputs_map = GetAttrToInputsMap(*op_def);

  TF_Status* status = TF_NewStatus();
  TFE_Op* op = TFE_NewOp(op_exec_info.ctx, op_def->name().c_str(), status);
  auto cleaner = tensorflow::gtl::MakeCleanup([status, op] {
    TF_DeleteStatus(status);
    TFE_DeleteOp(op);
  });
  if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
    return nullptr;
  }

  // Mapping of attr name to size - used to calculate the number of values
  // to be expected by the TFE_Execute run.
  tensorflow::gtl::FlatMap<string, tensorflow::int64> attr_list_sizes;

  // Set non-inferred attrs, including setting defaults if the attr is passed in
  // as None.
  for (int i = kFastPathExecuteInputStartIndex + op_def->input_arg_size();
       i < args_size; i += 2) {
    PyObject* py_attr_name = PyTuple_GET_ITEM(args, i);
    const tensorflow::StringPiece attr_name(TFE_GetPythonString(py_attr_name));
    PyObject* py_attr_value = PyTuple_GET_ITEM(args, i + 1);

    // Not creating an index since most of the time there are not more than a
    // few attrs.
    // TODO(nareshmodi): Maybe include the index as part of the
    // OpRegistrationData.
    for (const auto& attr : op_def->attr()) {
      if (attr_name == attr.name()) {
        SetOpAttrWithDefaults(op_exec_info.ctx, op, attr, attr_name.data(),
                              py_attr_value, &attr_list_sizes, status);

        if (TF_GetCode(status) != TF_OK) {
          VLOG(1) << "Falling back to slow path for Op \"" << op_def->name()
                  << "\" since we are unable to set the value for attr \""
                  << attr.name() << "\" due to: " << TF_Message(status);
          RaiseFallbackException(TF_Message(status));
          return nullptr;
        }

        break;
      }
    }
  }

  TFE_OpSetDevice(op, op_exec_info.device_name, status);
  if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
    return nullptr;
  }

  // Flat attrs and inputs as required by the record_gradient call. The attrs
  // here only contain inferred attrs (non-inferred attrs are added directly
  // from the input args).
  // All items in flattened_attrs and flattened_inputs contain
  // Safe_PyObjectPtr - any time something steals a reference to this, it must
  // INCREF.
  // TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work
  // directly.
  std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_attrs =
      nullptr;
  std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_inputs =
      nullptr;

  // TODO(nareshmodi): Encapsulate callbacks information into a struct.
  if (op_exec_info.run_callbacks) {
    flattened_attrs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
    flattened_inputs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
  }

  // Add inferred attrs and inputs.
  // The following code might set duplicate type attrs. This will result in
  // the CacheKey for the generated AttrBuilder possibly differing from
  // those where the type attrs are correctly set. Inconsistent CacheKeys
  // for ops means that there might be unnecessarily duplicated kernels.
  // TODO(nareshmodi): Fix this.
  for (int i = 0; i < op_def->input_arg_size(); i++) {
    const auto& input_arg = op_def->input_arg(i);

    PyObject* input =
        PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex + i);
    if (!input_arg.number_attr().empty()) {
      // The item is a homogeneous list.
      if (!RaiseIfNotPySequence(input, input_arg.number_attr())) return nullptr;
      tensorflow::Safe_PyObjectPtr fast_input(
          PySequence_Fast(input, "Could not parse sequence."));
      if (fast_input.get() == nullptr) {
        return nullptr;
      }
      Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());

      TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len);
      if (op_exec_info.run_callbacks) {
        flattened_attrs->emplace_back(
            GetPythonObjectFromString(input_arg.number_attr().data()));
        flattened_attrs->emplace_back(PyLong_FromLong(len));
      }
      attr_list_sizes[input_arg.number_attr()] = len;

      if (len > 0) {
        // First item adds the type attr.
        if (!AddInputToOp(&op_exec_info,
                          PySequence_Fast_GET_ITEM(fast_input.get(), 0), true,
                          input_arg, flattened_attrs.get(),
                          flattened_inputs.get(), op, status)) {
          return nullptr;
        }

        for (Py_ssize_t j = 1; j < len; j++) {
          // Since the list is homogeneous, we don't need to re-add the attr.
          if (!AddInputToOp(&op_exec_info,
                            PySequence_Fast_GET_ITEM(fast_input.get(), j),
                            false, input_arg, nullptr /* flattened_attrs */,
                            flattened_inputs.get(), op, status)) {
            return nullptr;
          }
        }
      }
    } else if (!input_arg.type_list_attr().empty()) {
      // The item is a heterogeneous list.
      if (!RaiseIfNotPySequence(input, input_arg.type_list_attr())) {
        return nullptr;
      }
      const string& attr_name = input_arg.type_list_attr();
      Py_ssize_t len = PySequence_Fast_GET_SIZE(input);
      tensorflow::gtl::InlinedVector<TF_DataType, 4> attr_value(len);
      PyObject* py_attr_value = nullptr;
      if (op_exec_info.run_callbacks) {
        py_attr_value = PyTuple_New(len);
      }
      for (Py_ssize_t j = 0; j < len; j++) {
        PyObject* py_input = PySequence_Fast_GET_ITEM(input, j);
        tensorflow::Safe_PyObjectPtr py_eager_tensor;
        if (!ConvertToTensor(
                op_exec_info, py_input, &py_eager_tensor,
                []() { Py_RETURN_NONE; }, [](const TF_DataType& dtype) {},
                status)) {
          return nullptr;
        }

        TFE_TensorHandle* input_handle =
            EagerTensor_Handle(py_eager_tensor.get());

        attr_value[j] = TFE_TensorHandleDataType(input_handle);

        TFE_OpAddInput(op, input_handle, status);
        if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
          return nullptr;
        }

        if (op_exec_info.run_callbacks) {
          flattened_inputs->emplace_back(std::move(py_eager_tensor));

          PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j]));
        }
      }
      if (op_exec_info.run_callbacks) {
        flattened_attrs->emplace_back(
            GetPythonObjectFromString(attr_name.data()));
        flattened_attrs->emplace_back(py_attr_value);
      }
      TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(),
                            attr_value.size());
      attr_list_sizes[attr_name] = len;
    } else {
      // The item is a single item.
      if (!AddInputToOp(&op_exec_info, input, true, input_arg,
                        flattened_attrs.get(), flattened_inputs.get(), op,
                        status)) {
        return nullptr;
      }
    }
  }

2749   int num_retvals = 0;
2750   for (int i = 0; i < op_def->output_arg_size(); i++) {
2751     const auto& output_arg = op_def->output_arg(i);
2752     int delta = 1;
2753     if (!output_arg.number_attr().empty()) {
2754       delta = attr_list_sizes[output_arg.number_attr()];
2755     } else if (!output_arg.type_list_attr().empty()) {
2756       delta = attr_list_sizes[output_arg.type_list_attr()];
2757     }
2758     if (delta < 0) {
2759       RaiseFallbackException(
2760           "Attributes suggest that the size of an output list is less than 0");
2761       return nullptr;
2762     }
2763     num_retvals += delta;
2764   }
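  // For example, an op whose signature declares an output "out: N * T" with
  // N == 3 contributes three return values here.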

  tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);

  Py_BEGIN_ALLOW_THREADS;
  TFE_Execute(op, retvals.data(), &num_retvals, status);
  Py_END_ALLOW_THREADS;

  if (TF_GetCode(status) != TF_OK) {
    // Augment the status with the op name for easier debugging, similar to
    // TFE_Py_Execute.
    TF_SetStatus(status, TF_GetCode(status),
                 tensorflow::strings::StrCat(
                     TF_Message(status),
                     " [Op:", TFE_GetPythonString(op_exec_info.op_name), "]")
                     .c_str());

    MaybeRaiseExceptionFromTFStatus(status, nullptr);
    return nullptr;
  }

  tensorflow::Safe_PyObjectPtr flat_result(PyList_New(num_retvals));
  for (int i = 0; i < num_retvals; ++i) {
    PyList_SET_ITEM(flat_result.get(), i, EagerTensorFromHandle(retvals[i]));
  }

  if (!RunCallbacks(op_exec_info, args, flattened_inputs.get(),
                    flattened_attrs.get(), flat_result.get())) {
    return nullptr;
  }

  // Unflatten results.
  if (op_def->output_arg_size() == 0) {
    Py_RETURN_NONE;
  }

  if (op_def->output_arg_size() == 1) {
    if (!op_def->output_arg(0).number_attr().empty() ||
        !op_def->output_arg(0).type_list_attr().empty()) {
      return flat_result.release();
    } else {
      auto* result = PyList_GET_ITEM(flat_result.get(), 0);
      Py_INCREF(result);
      return result;
    }
  }

  // Group the flat results back into one entry per output arg; the Python
  // caller packs this structure into a namedtuple.
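  // For example, outputs (a: T, b: N * T) with N == 2 turn the flat results
  // [r0, r1, r2] into [r0, [r1, r2]].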
  PyObject* result = PyList_New(op_def->output_arg_size());
  int flat_result_index = 0;
  for (int i = 0; i < op_def->output_arg_size(); i++) {
    if (!op_def->output_arg(i).number_attr().empty()) {
      int list_length = attr_list_sizes[op_def->output_arg(i).number_attr()];
      PyObject* inner_list = PyList_New(list_length);
      for (int j = 0; j < list_length; j++) {
        PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
        Py_INCREF(obj);
        PyList_SET_ITEM(inner_list, j, obj);
      }
      PyList_SET_ITEM(result, i, inner_list);
    } else if (!op_def->output_arg(i).type_list_attr().empty()) {
      int list_length = attr_list_sizes[op_def->output_arg(i).type_list_attr()];
      PyObject* inner_list = PyList_New(list_length);
      for (int j = 0; j < list_length; j++) {
        PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
        Py_INCREF(obj);
        PyList_SET_ITEM(inner_list, j, obj);
      }
      PyList_SET_ITEM(result, i, inner_list);
    } else {
      PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
      Py_INCREF(obj);
      PyList_SET_ITEM(result, i, obj);
    }
  }
  return result;
}

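// Records `results = op_name(inputs)` on all active tapes so gradients can
// later flow through the op; a no-op (returning None) when taping is stopped
// or no tapes are active.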
PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
                                PyObject* attrs, PyObject* results,
                                PyObject* name) {
  if (*ThreadTapeIsStopped() || GetTapeSet()->empty()) {
    Py_RETURN_NONE;
  }

  return RecordGradient(op_name, inputs, attrs, results, name);
}

namespace {
const char kTensor[] = "T";
const char kIndexedSlices[] = "I";
const char kList[] = "L";
const char kListEnd[] = "l";
const char kTuple[] = "U";
const char kTupleEnd[] = "u";
const char kDict[] = "D";
const char kRaw[] = "R";
const char kShape[] = "s";
const char kShapeDelim[] = "-";
const char kDType[] = "d";
const char kNone[] = "n";
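
// These single-character markers are concatenated into the cache-key string
// built below. For example, a float32 tensor of shape [2, 3] encodes as
// "Td1s2-3-": kTensor, then kDType plus the DT_FLOAT enum value (1), then
// kShape plus the dims separated by kShapeDelim.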

struct EncodeResult {
  string str;
  std::vector<PyObject*> objects;

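  // Packs the encoding into a Python pair: (encoded string, tuple of the raw
  // objects collected along the way, or None if there were none).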
  PyObject* ToPyTuple() {
    PyObject* result = PyTuple_New(2);

    PyTuple_SET_ITEM(result, 0, GetPythonObjectFromString(str.c_str()));

    if (objects.empty()) {
      Py_INCREF(Py_None);
      PyTuple_SET_ITEM(result, 1, Py_None);
    } else {
      PyObject* objects_tuple = PyTuple_New(objects.size());

      for (int i = 0; i < objects.size(); i++) {
        PyTuple_SET_ITEM(objects_tuple, i, objects[i]);
      }

      PyTuple_SET_ITEM(result, 1, objects_tuple);
    }

    return result;
  }
};

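// Appends the dtype and shape (or only the rank, when
// include_tensor_ranks_only is set) of `arg` to result->str. EagerTensors are
// inspected through the C API; graph Tensors through their Python attributes.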
tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg,
                                       bool include_tensor_ranks_only,
                                       EncodeResult* result) {
  if (EagerTensor_CheckExact(arg)) {
    TFE_TensorHandle* t = EagerTensor_Handle(arg);
    tensorflow::TensorShape tensor_shape;
    TF_RETURN_IF_ERROR(t->handle->Shape(&tensor_shape));

    absl::StrAppend(&result->str, kDType, t->handle->dtype);

    absl::StrAppend(&result->str, kShape);
    if (include_tensor_ranks_only) {
      absl::StrAppend(&result->str, tensor_shape.dim_sizes().size());
    } else {
      for (tensorflow::int64 dim_size : tensor_shape.dim_sizes()) {
        absl::StrAppend(&result->str, dim_size, kShapeDelim);
      }
    }
    return tensorflow::Status::OK();
  }

  tensorflow::Safe_PyObjectPtr dtype_object(
      PyObject_GetAttrString(arg, "dtype"));

  if (dtype_object == nullptr) {
    return tensorflow::errors::InvalidArgument(
        "ops.Tensor object doesn't have a dtype attr.");
  }

  tensorflow::Safe_PyObjectPtr dtype_enum(
      PyObject_GetAttrString(dtype_object.get(), "_type_enum"));

  if (dtype_enum == nullptr) {
    return tensorflow::errors::InvalidArgument(
        "ops.Tensor's dtype object doesn't have a _type_enum attr.");
  }

  tensorflow::DataType dtype =
      static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get()));

  absl::StrAppend(&result->str, kDType, dtype);

  static char _shape_tuple[] = "_shape_tuple";
  tensorflow::Safe_PyObjectPtr shape_tuple(
      PyObject_CallMethod(arg, _shape_tuple, nullptr));

  if (shape_tuple == nullptr) {
    return tensorflow::errors::InvalidArgument(
        "ops.Tensor object doesn't have a _shape_tuple() method.");
  }

  if (shape_tuple.get() == Py_None) {
    // Unknown shape, encode that directly.
    absl::StrAppend(&result->str, kNone);
    return tensorflow::Status::OK();
  }

  absl::StrAppend(&result->str, kShape);
  tensorflow::Safe_PyObjectPtr shape_seq(PySequence_Fast(
      shape_tuple.get(), "shape_tuple didn't return a sequence"));

  int len = PySequence_Fast_GET_SIZE(shape_seq.get());

  if (include_tensor_ranks_only) {
    absl::StrAppend(&result->str, len);
  } else {
    for (int i = 0; i < len; ++i) {
      PyObject* item = PySequence_Fast_GET_ITEM(shape_seq.get(), i);
      if (item == Py_None) {
        absl::StrAppend(&result->str, kNone);
      } else {
        absl::StrAppend(&result->str, MakeInt(item));
      }
    }
  }
  return tensorflow::Status::OK();
}

tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
                                          bool include_tensor_ranks_only,
                                          EncodeResult* result);

// Encodes a list or tuple, bracketing the encoded items between the `type`
// and `end_type` markers (e.g. kList ... kListEnd); None items encode as
// kNone.
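// For example, [t, None] with t a float32 tensor of shape [2, 3] encodes as
// "LTd1s2-3-nl".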
tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type,
                                         const char* end_type,
                                         bool include_tensor_ranks_only,
                                         EncodeResult* result) {
  tensorflow::Safe_PyObjectPtr arg_seq(
      PySequence_Fast(arg, "unable to create seq from list/tuple"));

  absl::StrAppend(&result->str, type);
  int len = PySequence_Fast_GET_SIZE(arg_seq.get());
  for (int i = 0; i < len; ++i) {
    PyObject* item = PySequence_Fast_GET_ITEM(arg_seq.get(), i);
    if (item == Py_None) {
      absl::StrAppend(&result->str, kNone);
    } else {
      TF_RETURN_IF_ERROR(
          TFE_Py_EncodeArgHelper(item, include_tensor_ranks_only, result));
    }
  }
  absl::StrAppend(&result->str, end_type);

  return tensorflow::Status::OK();
}

tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
                                          bool include_tensor_ranks_only,
                                          EncodeResult* result) {
  if (tensorflow::swig::IsTensor(arg)) {
    absl::StrAppend(&result->str, kTensor);
    TF_RETURN_IF_ERROR(
        TFE_Py_EncodeTensor(arg, include_tensor_ranks_only, result));
  } else if (tensorflow::swig::IsIndexedSlices(arg)) {
    absl::StrAppend(&result->str, kIndexedSlices);
    tensorflow::Safe_PyObjectPtr values(PyObject_GetAttrString(arg, "values"));
    if (values == nullptr) {
      PyErr_Clear();
      return tensorflow::errors::InvalidArgument(
          "IndexedSlices does not have a values attr");
    }
    TF_RETURN_IF_ERROR(
        TFE_Py_EncodeTensor(values.get(), include_tensor_ranks_only, result));

    tensorflow::Safe_PyObjectPtr indices(
        PyObject_GetAttrString(arg, "indices"));
    if (indices == nullptr) {
      PyErr_Clear();
      return tensorflow::errors::InvalidArgument(
          "IndexedSlices does not have an indices attr");
    }
    TF_RETURN_IF_ERROR(
        TFE_Py_EncodeTensor(indices.get(), include_tensor_ranks_only, result));

    tensorflow::Safe_PyObjectPtr dense_shape(
        PyObject_GetAttrString(arg, "dense_shape"));
    if (dense_shape == nullptr) {
      PyErr_Clear();
      return tensorflow::errors::InvalidArgument(
          "IndexedSlices does not have a dense_shape attr");
    }
    if (dense_shape.get() != Py_None) {
      TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(
          dense_shape.get(), include_tensor_ranks_only, result));
    }
  } else if (PyList_Check(arg)) {
    TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(
        arg, kList, kListEnd, include_tensor_ranks_only, result));
  } else if (PyTuple_Check(arg)) {
    TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(
        arg, kTuple, kTupleEnd, include_tensor_ranks_only, result));
  } else if (PyDict_Check(arg)) {
    tensorflow::Safe_PyObjectPtr keys(PyDict_Keys(arg));
    if (PyList_Sort(keys.get()) == -1) {
      return tensorflow::errors::Internal("Unable to sort keys");
    }

    absl::StrAppend(&result->str, kDict);
    int len = PyList_Size(keys.get());

    for (int i = 0; i < len; i++) {
      PyObject* key = PyList_GetItem(keys.get(), i);
      TF_RETURN_IF_ERROR(
          TFE_Py_EncodeArgHelper(key, include_tensor_ranks_only, result));
      PyObject* value = PyDict_GetItem(arg, key);
      TF_RETURN_IF_ERROR(
          TFE_Py_EncodeArgHelper(value, include_tensor_ranks_only, result));
    }
  } else {
    PyObject* object = PyWeakref_NewRef(arg, nullptr);

    if (object == nullptr) {
      PyErr_Clear();

      object = arg;
      Py_INCREF(object);
    }

    absl::StrAppend(&result->str, kRaw);
    result->objects.push_back(object);
  }

  return tensorflow::Status::OK();
}

}  // namespace

// `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes
// are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes
// are used both for performance, since much TensorFlow code specializes on
// known shapes to produce slimmer graphs, and for correctness, since some
// high-level APIs require shapes to be fully known.
//
// `include_tensor_ranks_only` allows caching on arguments excluding shape
// info, so that a slow path using relaxed shapes can rely on a cache key that
// excludes shapes.
//
// TODO(nareshmodi): Add support for sparse tensors.
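//
// For example (illustrative), encoding a single float32 EagerTensor of shape
// [2, 3] yields the pair ("Td1s2-3-", None), while an unrecognized Python
// object yields ("R", (<weakref to the object>,)).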
PyObject* TFE_Py_EncodeArg(PyObject* arg, bool include_tensor_ranks_only) {
  EncodeResult result;
  const auto status =
      TFE_Py_EncodeArgHelper(arg, include_tensor_ranks_only, &result);
  if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
    return nullptr;
  }

  return result.ToPyTuple();
}

// Prints incoming messages directly to Python's stdout using Python's C API.
// This is necessary in Jupyter notebooks and Colab, where messages written to
// the C stdout don't go to the notebook cell outputs, but calls to Python's
// stdout do.
void PrintToPythonStdout(const char* msg) {
  if (Py_IsInitialized()) {
    PyGILState_STATE py_threadstate;
    py_threadstate = PyGILState_Ensure();

    string string_msg = msg;
    // PySys_WriteStdout truncates strings over 1000 bytes, so
    // we write the message in chunks small enough to not be truncated.
    const size_t kChunkSize = 900;
    const size_t len = string_msg.length();
    for (size_t i = 0; i < len; i += kChunkSize) {
      PySys_WriteStdout("%s", string_msg.substr(i, kChunkSize).c_str());
    }
    PySys_WriteStdout("\n");

    PyGILState_Release(py_threadstate);
  }
}

// Register PrintToPythonStdout as a log listener, to allow
// printing in Colab and Jupyter notebooks to work.
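// Assumed to be called from Python under the GIL, so the one-time guard flag
// below needs no extra synchronization.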
void TFE_Py_EnableInteractivePythonLogging() {
  static bool enabled_interactive_logging = false;
  if (!enabled_interactive_logging) {
    enabled_interactive_logging = true;
    TF_RegisterLogListener(PrintToPythonStdout);
  }
}