1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <atomic>
17 #include <cstring>
18 #include <unordered_map>
19 
20 #include "absl/debugging/leak_check.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/types/variant.h"
23 #include "tensorflow/c/c_api.h"
24 #include "tensorflow/c/c_api_internal.h"
25 #include "tensorflow/c/eager/c_api.h"
26 #include "tensorflow/c/eager/c_api_internal.h"
27 #include "tensorflow/c/eager/tape.h"
28 #include "tensorflow/c/eager/tfe_context_internal.h"
29 #include "tensorflow/c/eager/tfe_op_internal.h"
30 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
31 #include "tensorflow/c/tf_status.h"
32 #include "tensorflow/core/framework/types.pb.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/gtl/cleanup.h"
35 #include "tensorflow/core/lib/gtl/compactptrset.h"
36 #include "tensorflow/core/lib/gtl/flatmap.h"
37 #include "tensorflow/core/lib/gtl/flatset.h"
38 #include "tensorflow/core/lib/strings/strcat.h"
39 #include "tensorflow/core/lib/strings/stringprintf.h"
40 #include "tensorflow/core/platform/casts.h"
41 #include "tensorflow/core/platform/mutex.h"
42 #include "tensorflow/core/platform/protobuf.h"
43 #include "tensorflow/core/platform/status.h"
44 #include "tensorflow/core/platform/types.h"
45 #include "tensorflow/core/profiler/lib/traceme.h"
46 #include "tensorflow/core/util/managed_stack_trace.h"
47 #include "tensorflow/python/eager/pywrap_gradient_exclusions.h"
48 #include "tensorflow/python/eager/pywrap_tensor.h"
49 #include "tensorflow/python/eager/pywrap_tfe.h"
50 #include "tensorflow/python/lib/core/py_util.h"
51 #include "tensorflow/python/lib/core/safe_ptr.h"
52 #include "tensorflow/python/util/stack_trace.h"
53 #include "tensorflow/python/util/util.h"
54 
55 using tensorflow::Status;
56 using tensorflow::string;
57 using tensorflow::strings::Printf;
58 
59 namespace {
60 // NOTE: Items are retrieved from and returned to these unique_ptrs, and they
61 // act as arenas. This is important if the same thread requests 2 items without
62 // releasing one.
63 // The following sequence of events on the same thread will still succeed:
64 // - GetOp <- Returns existing.
65 // - GetOp <- Allocates and returns a new pointer.
66 // - ReleaseOp <- Sets the item in the unique_ptr.
67 // - ReleaseOp <- Sets the item in the unique_ptr, deleting the old one.
68 // This occurs when a PyFunc kernel is run. This behavior makes it safe in that
69 // case, as well as in the case where Python reuses the same underlying C++
70 // thread for two Python threads.
71 struct OpDeleter {
72   void operator()(TFE_Op* op) const { TFE_DeleteOp(op); }
73 };
74 thread_local std::unordered_map<TFE_Context*,
75                                 std::unique_ptr<TFE_Op, OpDeleter>>
76     thread_local_eager_operation_map;                             // NOLINT
77 thread_local std::unique_ptr<TF_Status> thread_local_tf_status =  // NOLINT
78     nullptr;
79 
80 std::unique_ptr<TFE_Op, OpDeleter> ReleaseThreadLocalOp(TFE_Context* ctx) {
81   auto it = thread_local_eager_operation_map.find(ctx);
82   if (it == thread_local_eager_operation_map.end()) {
83     return nullptr;
84   }
85   return std::move(it->second);
86 }
87 
88 TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
89               const char* raw_device_name, TF_Status* status) {
90   auto op = ReleaseThreadLocalOp(ctx);
91   if (!op) {
92     op.reset(tensorflow::wrap(tensorflow::unwrap(ctx)->CreateOperation()));
93   }
94   status->status =
95       tensorflow::unwrap(op.get())->Reset(op_or_function_name, raw_device_name);
96   if (!status->status.ok()) {
97     op.reset();
98   }
99   return op.release();
100 }
101 
102 void ReturnOp(TFE_Context* ctx, TFE_Op* op) {
103   if (op) {
104     tensorflow::unwrap(op)->Clear();
105     thread_local_eager_operation_map[ctx].reset(op);
106   }
107 }
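// A minimal usage sketch of the checkout/return cycle above (not from the
// original file; assumes a valid TFE_Context* ctx, and "MatMul" is just an
// arbitrary example op name):
//
//   TF_Status* status = TF_NewStatus();
//   TFE_Op* op = GetOp(ctx, "MatMul", /*raw_device_name=*/"", status);
//   if (TF_GetCode(status) == TF_OK) {
//     // ... add inputs, set attrs, call TFE_Execute ...
//     ReturnOp(ctx, op);  // clears the op and caches it for this thread
//   }
//   TF_DeleteStatus(status);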
108 
109 TF_Status* ReleaseThreadLocalStatus() {
110   if (thread_local_tf_status == nullptr) {
111     return nullptr;
112   }
113   return thread_local_tf_status.release();
114 }
115 
116 struct InputInfo {
117   InputInfo(int i, bool is_list) : i(i), is_list(is_list) {}
118 
119   int i;
120   bool is_list = false;
121 };
122 
123 // Takes in output gradients, returns input gradients.
124 typedef std::function<PyObject*(PyObject*,
125                                 const std::vector<tensorflow::int64>&)>
126     PyBackwardFunction;
127 
128 using AttrToInputsMap =
129     tensorflow::gtl::FlatMap<string,
130                              tensorflow::gtl::InlinedVector<InputInfo, 4>>;
131 
132 tensorflow::gtl::FlatMap<string, AttrToInputsMap*>* GetAllAttrToInputsMaps() {
133   static auto* all_attr_to_input_maps =
134       new tensorflow::gtl::FlatMap<string, AttrToInputsMap*>;
135   return all_attr_to_input_maps;
136 }
137 
138 // This function doesn't use a lock, since we depend on the GIL directly.
139 AttrToInputsMap* GetAttrToInputsMapHoldingGIL(const tensorflow::OpDef& op_def) {
140 #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4
141   DCHECK(PyGILState_Check())
142       << "This function needs to hold the GIL when called.";
143 #endif
144   auto* all_attr_to_input_maps = GetAllAttrToInputsMaps();
145   auto* output =
146       tensorflow::gtl::FindPtrOrNull(*all_attr_to_input_maps, op_def.name());
147   if (output != nullptr) {
148     return output;
149   }
150 
151   std::unique_ptr<AttrToInputsMap> m(new AttrToInputsMap);
152 
153   // Store a map of type attr name -> list of the corresponding inputs.
154   for (int i = 0; i < op_def.input_arg_size(); i++) {
155     if (!op_def.input_arg(i).type_attr().empty()) {
156       auto it = m->find(op_def.input_arg(i).type_attr());
157       if (it == m->end()) {
158         it = m->insert({op_def.input_arg(i).type_attr(), {}}).first;
159       }
160       it->second.emplace_back(i, !op_def.input_arg(i).number_attr().empty());
161     }
162   }
163 
164   auto* retval = m.get();
165   (*all_attr_to_input_maps)[op_def.name()] = m.release();
166 
167   return retval;
168 }
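// Worked example of the map built above, assuming an OpDef like MatMul with
// inputs "a: T" and "b: T": the resulting AttrToInputsMap holds a single
// entry, "T" -> {InputInfo(0, false), InputInfo(1, false)}, so the dtype for
// attr "T" can later be inferred from either input. is_list is true only for
// inputs declared with a number_attr (variadic inputs).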
169 
170 // This function doesn't use a lock, since we depend on the GIL directly.
171 tensorflow::gtl::FlatMap<
172     string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>*
173 GetAllAttrToDefaultsMaps() {
174   static auto* all_attr_to_defaults_maps = new tensorflow::gtl::FlatMap<
175       string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>;
176   return all_attr_to_defaults_maps;
177 }
178 
179 tensorflow::gtl::FlatMap<string, tensorflow::DataType>*
180 GetAttrToDefaultsMapHoldingGIL(const tensorflow::OpDef& op_def) {
181 #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4
182   DCHECK(PyGILState_Check())
183       << "This function needs to hold the GIL when called.";
184 #endif
185   auto* all_attr_to_defaults_maps = GetAllAttrToDefaultsMaps();
186   auto* output =
187       tensorflow::gtl::FindPtrOrNull(*all_attr_to_defaults_maps, op_def.name());
188   if (output != nullptr) {
189     return output;
190   }
191 
192   auto* new_map = new tensorflow::gtl::FlatMap<string, tensorflow::DataType>;
193 
194   for (const auto& attr : op_def.attr()) {
195     if (attr.type() == "type" && attr.has_default_value()) {
196       new_map->insert({attr.name(), attr.default_value().type()});
197     }
198   }
199 
200   (*all_attr_to_defaults_maps)[op_def.name()] = new_map;
201 
202   return new_map;
203 }
204 
205 struct FastPathOpExecInfo {
206   TFE_Context* ctx;
207   const char* device_name;
208 
209   bool run_callbacks;
210   bool run_post_exec_callbacks;
211   bool run_gradient_callback;
212 
213   // The op name of the main op being executed.
214   PyObject* name;
215   // The op type name of the main op being executed.
216   PyObject* op_name;
217   PyObject* callbacks;
218 
219   // All the args passed into the FastPathOpExecInfo.
220   PyObject* args;
221 
222   // DTypes can come from another input that has the same attr, so build that
223   // map.
224   const AttrToInputsMap* attr_to_inputs_map;
225   const tensorflow::gtl::FlatMap<string, tensorflow::DataType>* default_dtypes;
226   tensorflow::gtl::FlatMap<string, tensorflow::DataType> cached_dtypes;
227 };
228 
229 #define PARSE_VALUE(fn_name, type, check_fn, parse_fn)                       \
230   bool fn_name(const string& key, PyObject* py_value, TF_Status* status,     \
231                type* value) {                                                \
232     if (check_fn(py_value)) {                                                \
233       *value = static_cast<type>(parse_fn(py_value));                        \
234       return true;                                                           \
235     } else {                                                                 \
236       TF_SetStatus(status, TF_INVALID_ARGUMENT,                              \
237                    tensorflow::strings::StrCat(                              \
238                        "Expecting " #type " value for attr ", key, ", got ", \
239                        py_value->ob_type->tp_name)                           \
240                        .c_str());                                            \
241       return false;                                                          \
242     }                                                                        \
243   }
244 
245 #if PY_MAJOR_VERSION >= 3
246 PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
247 PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLongLong)
248 #else
249 PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
250 #endif
251 PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)
252 #undef PARSE_VALUE
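// For reference, on Python 3 the first invocation above expands roughly to
// the following (a sketch, not a literal preprocessor dump):
//
//   bool ParseIntValue(const string& key, PyObject* py_value,
//                      TF_Status* status, int* value) {
//     if (PyLong_Check(py_value)) {
//       *value = static_cast<int>(PyLong_AsLong(py_value));
//       return true;
//     }
//     // else: set TF_INVALID_ARGUMENT on status and return false
//   }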
253 
254 #if PY_MAJOR_VERSION < 3
255 bool ParseInt64Value(const string& key, PyObject* py_value, TF_Status* status,
256                      int64_t* value) {
257   if (PyInt_Check(py_value)) {
258     *value = static_cast<int64_t>(PyInt_AsLong(py_value));
259     return true;
260   } else if (PyLong_Check(py_value)) {
261     *value = static_cast<int64_t>(PyLong_AsLong(py_value));
262     return true;
263   }
264   TF_SetStatus(
265       status, TF_INVALID_ARGUMENT,
266       tensorflow::strings::StrCat("Expecting int or long value for attr ", key,
267                                   ", got ", py_value->ob_type->tp_name)
268           .c_str());
269   return false;
270 }
271 #endif
272 
273 Py_ssize_t TensorShapeNumDims(PyObject* value) {
274   const auto size = PySequence_Size(value);
275   if (size == -1) {
276     // TensorShape.__len__ raises an error when the shape is unknown; the
277     // resulting Python error needs to be cleared.
278     // TODO(nareshmodi): ensure that this is actually a TensorShape.
279     PyErr_Clear();
280   }
281   return size;
282 }
283 
284 bool IsInteger(PyObject* py_value) {
285 #if PY_MAJOR_VERSION >= 3
286   return PyLong_Check(py_value);
287 #else
288   return PyInt_Check(py_value) || PyLong_Check(py_value);
289 #endif
290 }
291 
292 // This function considers a Dimension._value of None to be valid, and sets the
293 // value to be -1 in that case.
294 bool ParseDimensionValue(const string& key, PyObject* py_value,
295                          TF_Status* status, int64_t* value) {
296   if (IsInteger(py_value)) {
297     return ParseInt64Value(key, py_value, status, value);
298   }
299 
300   tensorflow::Safe_PyObjectPtr dimension_value(
301       PyObject_GetAttrString(py_value, "_value"));
302   if (dimension_value == nullptr) {
303     PyErr_Clear();
304     TF_SetStatus(
305         status, TF_INVALID_ARGUMENT,
306         tensorflow::strings::StrCat("Expecting a Dimension for attr ", key,
307                                     ", got ", py_value->ob_type->tp_name)
308             .c_str());
309     return false;
310   }
311 
312   if (dimension_value.get() == Py_None) {
313     *value = -1;
314     return true;
315   }
316 
317   return ParseInt64Value(key, dimension_value.get(), status, value);
318 }
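// Behavior summary (illustrative): a plain Python int parses directly; a
// Dimension whose _value is an int parses to that int; Dimension(None)
// parses to -1 (unknown); any other object sets TF_INVALID_ARGUMENT.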
319 
320 bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status,
321                       tensorflow::StringPiece* value) {
322   if (PyBytes_Check(py_value)) {
323     Py_ssize_t size = 0;
324     char* buf = nullptr;
325     if (PyBytes_AsStringAndSize(py_value, &buf, &size) < 0) return false;
326     *value = tensorflow::StringPiece(buf, size);
327     return true;
328   }
329 #if PY_MAJOR_VERSION >= 3
330   if (PyUnicode_Check(py_value)) {
331     Py_ssize_t size = 0;
332     const char* buf = PyUnicode_AsUTF8AndSize(py_value, &size);
333     if (buf == nullptr) return false;
334     *value = tensorflow::StringPiece(buf, size);
335     return true;
336   }
337 #endif
338   TF_SetStatus(
339       status, TF_INVALID_ARGUMENT,
340       tensorflow::strings::StrCat("Expecting a string value for attr ", key,
341                                   ", got ", py_value->ob_type->tp_name)
342           .c_str());
343   return false;
344 }
345 
346 bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status,
347                     unsigned char* value) {
348   *value = PyObject_IsTrue(py_value);
349   return true;
350 }
351 
352 // The passed-in py_value is expected to be an object of the Python type
353 // dtypes.DType or an int.
354 bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status,
355                     int* value) {
356   if (IsInteger(py_value)) {
357     return ParseIntValue(key, py_value, status, value);
358   }
359 
360   tensorflow::Safe_PyObjectPtr py_type_enum(
361       PyObject_GetAttrString(py_value, "_type_enum"));
362   if (py_type_enum == nullptr) {
363     PyErr_Clear();
364     TF_SetStatus(
365         status, TF_INVALID_ARGUMENT,
366         tensorflow::strings::StrCat("Expecting a DType.dtype for attr ", key,
367                                     ", got ", py_value->ob_type->tp_name)
368             .c_str());
369     return false;
370   }
371 
372   return ParseIntValue(key, py_type_enum.get(), status, value);
373 }
374 
375 bool SetOpAttrList(
376     TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_list,
377     TF_AttrType type,
378     tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
379     TF_Status* status) {
380   if (!PySequence_Check(py_list)) {
381     TF_SetStatus(
382         status, TF_INVALID_ARGUMENT,
383         tensorflow::strings::StrCat("Expecting sequence value for attr ", key,
384                                     ", got ", py_list->ob_type->tp_name)
385             .c_str());
386     return false;
387   }
388   const int num_values = PySequence_Size(py_list);
389   if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values;
390 
391 #define PARSE_LIST(c_type, parse_fn)                                      \
392   std::unique_ptr<c_type[]> values(new c_type[num_values]);               \
393   for (int i = 0; i < num_values; ++i) {                                  \
394     tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));   \
395     if (!parse_fn(key, py_value.get(), status, &values[i])) return false; \
396   }
397 
398   if (type == TF_ATTR_STRING) {
399     std::unique_ptr<const void*[]> values(new const void*[num_values]);
400     std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
401     for (int i = 0; i < num_values; ++i) {
402       tensorflow::StringPiece value;
403       tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
404       if (!ParseStringValue(key, py_value.get(), status, &value)) return false;
405       values[i] = value.data();
406       lengths[i] = value.size();
407     }
408     TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
409   } else if (type == TF_ATTR_INT) {
410     PARSE_LIST(int64_t, ParseInt64Value);
411     TFE_OpSetAttrIntList(op, key, values.get(), num_values);
412   } else if (type == TF_ATTR_FLOAT) {
413     PARSE_LIST(float, ParseFloatValue);
414     TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
415   } else if (type == TF_ATTR_BOOL) {
416     PARSE_LIST(unsigned char, ParseBoolValue);
417     TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
418   } else if (type == TF_ATTR_TYPE) {
419     PARSE_LIST(int, ParseTypeValue);
420     TFE_OpSetAttrTypeList(op, key,
421                           reinterpret_cast<const TF_DataType*>(values.get()),
422                           num_values);
423   } else if (type == TF_ATTR_SHAPE) {
424     // Make one pass through the input counting the total number of
425     // dims across all the input lists.
426     int total_dims = 0;
427     for (int i = 0; i < num_values; ++i) {
428       tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
429       if (py_value.get() != Py_None) {
430         if (!PySequence_Check(py_value.get())) {
431           TF_SetStatus(
432               status, TF_INVALID_ARGUMENT,
433               tensorflow::strings::StrCat(
434                   "Expecting None or sequence value for element", i,
435                   " of attr ", key, ", got ", py_value->ob_type->tp_name)
436                   .c_str());
437           return false;
438         }
439         const auto size = TensorShapeNumDims(py_value.get());
440         if (size >= 0) {
441           total_dims += size;
442         }
443       }
444     }
445     // Allocate a buffer that can fit all of the dims together.
446     std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
447     // Copy the input dims into the buffer and set dims to point to
448     // the start of each list's dims.
449     std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
450     std::unique_ptr<int[]> num_dims(new int[num_values]);
451     int64_t* offset = buffer.get();
452     for (int i = 0; i < num_values; ++i) {
453       tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
454       if (py_value.get() == Py_None) {
455         dims[i] = nullptr;
456         num_dims[i] = -1;
457       } else {
458         const auto size = TensorShapeNumDims(py_value.get());
459         if (size == -1) {
460           dims[i] = nullptr;
461           num_dims[i] = -1;
462           continue;
463         }
464         dims[i] = offset;
465         num_dims[i] = size;
466         for (int j = 0; j < size; ++j) {
467           tensorflow::Safe_PyObjectPtr inner_py_value(
468               PySequence_ITEM(py_value.get(), j));
469           if (inner_py_value.get() == Py_None) {
470             *offset = -1;
471           } else if (!ParseDimensionValue(key, inner_py_value.get(), status,
472                                           offset)) {
473             return false;
474           }
475           ++offset;
476         }
477       }
478     }
479     TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
480                            status);
481     if (!status->status.ok()) return false;
482   } else if (type == TF_ATTR_FUNC) {
483     std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
484     for (int i = 0; i < num_values; ++i) {
485       tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
486       // Allow:
487       // (1) String function name, OR
488       // (2) A Python object with a .name attribute
489       //     (A crude test for being a
490       //     tensorflow.python.framework.function._DefinedFunction)
491       //     (which is what the various "defun" or "Defun" decorators do).
492       // And in the future also allow an object that can encapsulate
493       // the function name and its attribute values.
494       tensorflow::StringPiece func_name;
495       if (!ParseStringValue(key, py_value.get(), status, &func_name)) {
496         PyObject* name_attr = PyObject_GetAttrString(py_value.get(), "name");
497         if (name_attr == nullptr ||
498             !ParseStringValue(key, name_attr, status, &func_name)) {
499           TF_SetStatus(
500               status, TF_INVALID_ARGUMENT,
501               tensorflow::strings::StrCat(
502                   "unable to set function value attribute from a ",
503                   py_value.get()->ob_type->tp_name,
504                   " object. If you think this is an error, please file an "
505                   "issue at "
506                   "https://github.com/tensorflow/tensorflow/issues/new")
507                   .c_str());
508           return false;
509         }
510       }
511       funcs[i] = TFE_NewOp(ctx, func_name.data(), status);
512       if (!status->status.ok()) return false;
513     }
514     TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
515     if (!status->status.ok()) return false;
516   } else {
517     TF_SetStatus(status, TF_UNIMPLEMENTED,
518                  tensorflow::strings::StrCat("Attr ", key,
519                                              " has unhandled list type ", type)
520                      .c_str());
521     return false;
522   }
523 #undef PARSE_LIST
524   return true;
525 }
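// Illustrative input for the TF_ATTR_SHAPE branch above: a Python attr value
// of [[2, None], None, [3]] (hypothetical) produces three entries:
//   dims[0] = {2, -1}, num_dims[0] = 2   (a None dim becomes -1)
//   dims[1] = nullptr, num_dims[1] = -1  (unknown rank)
//   dims[2] = {3},     num_dims[2] = 1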
526 
527 TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
528                 TF_Status* status) {
529   TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
530   for (const auto& attr : func.attr()) {
531     if (!status->status.ok()) return nullptr;
532     SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
533     if (!status->status.ok()) return nullptr;
534   }
535   return func_op;
536 }
537 
538 void SetOpAttrListDefault(
539     TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
540     const char* key, TF_AttrType type,
541     tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
542     TF_Status* status) {
543   if (type == TF_ATTR_STRING) {
544     int num_values = attr.default_value().list().s_size();
545     std::unique_ptr<const void*[]> values(new const void*[num_values]);
546     std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
547     (*attr_list_sizes)[key] = num_values;
548     for (int i = 0; i < num_values; i++) {
549       const string& v = attr.default_value().list().s(i);
550       values[i] = v.data();
551       lengths[i] = v.size();
552     }
553     TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
554   } else if (type == TF_ATTR_INT) {
555     int num_values = attr.default_value().list().i_size();
556     std::unique_ptr<int64_t[]> values(new int64_t[num_values]);
557     (*attr_list_sizes)[key] = num_values;
558     for (int i = 0; i < num_values; i++) {
559       values[i] = attr.default_value().list().i(i);
560     }
561     TFE_OpSetAttrIntList(op, key, values.get(), num_values);
562   } else if (type == TF_ATTR_FLOAT) {
563     int num_values = attr.default_value().list().f_size();
564     std::unique_ptr<float[]> values(new float[num_values]);
565     (*attr_list_sizes)[key] = num_values;
566     for (int i = 0; i < num_values; i++) {
567       values[i] = attr.default_value().list().f(i);
568     }
569     TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
570   } else if (type == TF_ATTR_BOOL) {
571     int num_values = attr.default_value().list().b_size();
572     std::unique_ptr<unsigned char[]> values(new unsigned char[num_values]);
573     (*attr_list_sizes)[key] = num_values;
574     for (int i = 0; i < num_values; i++) {
575       values[i] = attr.default_value().list().b(i);
576     }
577     TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
578   } else if (type == TF_ATTR_TYPE) {
579     int num_values = attr.default_value().list().type_size();
580     std::unique_ptr<int[]> values(new int[num_values]);
581     (*attr_list_sizes)[key] = num_values;
582     for (int i = 0; i < num_values; i++) {
583       values[i] = attr.default_value().list().type(i);
584     }
585     TFE_OpSetAttrTypeList(op, key,
586                           reinterpret_cast<const TF_DataType*>(values.get()),
587                           attr.default_value().list().type_size());
588   } else if (type == TF_ATTR_SHAPE) {
589     int num_values = attr.default_value().list().shape_size();
590     (*attr_list_sizes)[key] = num_values;
591     int total_dims = 0;
592     for (int i = 0; i < num_values; ++i) {
593       if (!attr.default_value().list().shape(i).unknown_rank()) {
594         total_dims += attr.default_value().list().shape(i).dim_size();
595       }
596     }
597     // Allocate a buffer that can fit all of the dims together.
598     std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
599     // Copy the input dims into the buffer and set dims to point to
600     // the start of each list's dims.
601     std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
602     std::unique_ptr<int[]> num_dims(new int[num_values]);
603     int64_t* offset = buffer.get();
604     for (int i = 0; i < num_values; ++i) {
605       const auto& shape = attr.default_value().list().shape(i);
606       if (shape.unknown_rank()) {
607         dims[i] = nullptr;
608         num_dims[i] = -1;
609       } else {
610         for (int j = 0; j < shape.dim_size(); j++) {
611           *offset = shape.dim(j).size();
612           ++offset;
613         }
614       }
615     }
616     TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
617                            status);
618   } else if (type == TF_ATTR_FUNC) {
619     int num_values = attr.default_value().list().func_size();
620     (*attr_list_sizes)[key] = num_values;
621     std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
622     for (int i = 0; i < num_values; i++) {
623       funcs[i] = GetFunc(ctx, attr.default_value().list().func(i), status);
624     }
625     TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
626   } else {
627     TF_SetStatus(status, TF_UNIMPLEMENTED,
628                  "Lists of tensors are not yet implemented for default valued "
629                  "attributes for an operation.");
630   }
631 }
632 
633 bool SetOpAttrScalar(
634     TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_value,
635     TF_AttrType type,
636     tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
637     TF_Status* status) {
638   if (type == TF_ATTR_STRING) {
639     tensorflow::StringPiece value;
640     if (!ParseStringValue(key, py_value, status, &value)) return false;
641     TFE_OpSetAttrString(op, key, value.data(), value.size());
642   } else if (type == TF_ATTR_INT) {
643     int64_t value;
644     if (!ParseInt64Value(key, py_value, status, &value)) return false;
645     TFE_OpSetAttrInt(op, key, value);
646     // attr_list_sizes is set for all int attributes (at this point we do not
647     // yet know whether the attribute will be used to calculate the size of an
648     // output list).
649     if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = value;
650   } else if (type == TF_ATTR_FLOAT) {
651     float value;
652     if (!ParseFloatValue(key, py_value, status, &value)) return false;
653     TFE_OpSetAttrFloat(op, key, value);
654   } else if (type == TF_ATTR_BOOL) {
655     unsigned char value;
656     if (!ParseBoolValue(key, py_value, status, &value)) return false;
657     TFE_OpSetAttrBool(op, key, value);
658   } else if (type == TF_ATTR_TYPE) {
659     int value;
660     if (!ParseTypeValue(key, py_value, status, &value)) return false;
661     TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value));
662   } else if (type == TF_ATTR_SHAPE) {
663     if (py_value == Py_None) {
664       TFE_OpSetAttrShape(op, key, nullptr, -1, status);
665     } else {
666       if (!PySequence_Check(py_value)) {
667         TF_SetStatus(status, TF_INVALID_ARGUMENT,
668                      tensorflow::strings::StrCat(
669                          "Expecting None or sequence value for attr", key,
670                          ", got ", py_value->ob_type->tp_name)
671                          .c_str());
672         return false;
673       }
674       const auto num_dims = TensorShapeNumDims(py_value);
675       if (num_dims == -1) {
676         TFE_OpSetAttrShape(op, key, nullptr, -1, status);
677         return true;
678       }
679       std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
680       for (int i = 0; i < num_dims; ++i) {
681         tensorflow::Safe_PyObjectPtr inner_py_value(
682             PySequence_ITEM(py_value, i));
683         if (inner_py_value.get() == Py_None) {
684           dims[i] = -1;
685         } else if (!ParseDimensionValue(key, inner_py_value.get(), status,
686                                         &dims[i])) {
687           return false;
688         }
689       }
690       TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status);
691     }
692     if (!status->status.ok()) return false;
693   } else if (type == TF_ATTR_FUNC) {
694     // Allow:
695     // (1) String function name, OR
696     // (2) A Python object with a .name attribute
697     //     (A crude test for being a
698     //     tensorflow.python.framework.function._DefinedFunction)
699     //     (which is what the various "defun" or "Defun" decorators do).
700     // And in the future also allow an object that can encapsulate
701     // the function name and its attribute values.
702     tensorflow::StringPiece func_name;
703     if (!ParseStringValue(key, py_value, status, &func_name)) {
704       PyObject* name_attr = PyObject_GetAttrString(py_value, "name");
705       if (name_attr == nullptr ||
706           !ParseStringValue(key, name_attr, status, &func_name)) {
707         TF_SetStatus(
708             status, TF_INVALID_ARGUMENT,
709             tensorflow::strings::StrCat(
710                 "unable to set function value attribute from a ",
711                 py_value->ob_type->tp_name,
712                 " object. If you think this is an error, please file an issue "
713                 "at https://github.com/tensorflow/tensorflow/issues/new")
714                 .c_str());
715         return false;
716       }
717     }
718     TF_SetStatus(status, TF_OK, "");
719     TFE_OpSetAttrFunctionName(op, key, func_name.data(), func_name.size());
720   } else {
721     TF_SetStatus(
722         status, TF_UNIMPLEMENTED,
723         tensorflow::strings::StrCat("Attr ", key, " has unhandled type ", type)
724             .c_str());
725     return false;
726   }
727   return true;
728 }
729 
730 void SetOpAttrScalarDefault(
731     TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value,
732     const char* attr_name,
733     tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
734     TF_Status* status) {
735   SetOpAttrValueScalar(ctx, op, default_value, attr_name, status);
736   if (default_value.value_case() == tensorflow::AttrValue::kI) {
737     (*attr_list_sizes)[attr_name] = default_value.i();
738   }
739 }
740 
741 // start_index is the index at which the Tuple/List attrs will start getting
742 // processed.
743 void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index,
744                 TF_Status* out_status) {
745   if (attrs == Py_None) return;
746   Py_ssize_t len = PyTuple_GET_SIZE(attrs) - start_index;
747   if ((len & 1) != 0) {
748     TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
749                  "Expecting attrs tuple to have even length.");
750     return;
751   }
752   // Parse attrs
753   for (Py_ssize_t i = 0; i < len; i += 2) {
754     PyObject* py_key = PyTuple_GET_ITEM(attrs, start_index + i);
755     PyObject* py_value = PyTuple_GET_ITEM(attrs, start_index + i + 1);
756 #if PY_MAJOR_VERSION >= 3
757     const char* key = PyBytes_Check(py_key) ? PyBytes_AsString(py_key)
758                                             : PyUnicode_AsUTF8(py_key);
759 #else
760     const char* key = PyBytes_AsString(py_key);
761 #endif
762     unsigned char is_list = 0;
763     const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status);
764     if (!out_status->status.ok()) return;
765     if (is_list != 0) {
766       if (!SetOpAttrList(ctx, op, key, py_value, type, nullptr, out_status))
767         return;
768     } else {
769       if (!SetOpAttrScalar(ctx, op, key, py_value, type, nullptr, out_status))
770         return;
771     }
772   }
773 }
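// Example of the expected attrs layout (hypothetical values): with
// start_index == 0, SetOpAttrs would accept a flat tuple such as
//
//   ("transpose_a", False, "transpose_b", True)
//
// i.e. alternating key/value pairs, which is why an odd remaining length is
// rejected above.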
774 
775 // This function sets the required op attrs. If an attr has the value of
776 // None, it reads the AttrDef to get the default value and sets that
777 // instead. Any failure in this function simply falls back to the slow
778 // path.
779 void SetOpAttrWithDefaults(
780     TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
781     const char* attr_name, PyObject* attr_value,
782     tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
783     TF_Status* status) {
784   unsigned char is_list = 0;
785   const TF_AttrType type = TFE_OpGetAttrType(op, attr_name, &is_list, status);
786   if (!status->status.ok()) return;
787   if (attr_value == Py_None) {
788     if (is_list != 0) {
789       SetOpAttrListDefault(ctx, op, attr, attr_name, type, attr_list_sizes,
790                            status);
791     } else {
792       SetOpAttrScalarDefault(ctx, op, attr.default_value(), attr_name,
793                              attr_list_sizes, status);
794     }
795   } else {
796     if (is_list != 0) {
797       SetOpAttrList(ctx, op, attr_name, attr_value, type, attr_list_sizes,
798                     status);
799     } else {
800       SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes,
801                       status);
802     }
803   }
804 }
805 
806 PyObject* GetPythonObjectFromInt(int num) {
807 #if PY_MAJOR_VERSION >= 3
808   return PyLong_FromLong(num);
809 #else
810   return PyInt_FromLong(num);
811 #endif
812 }
813 
814 // Python subclass of Exception that is raised when a Status is not OK.
815 tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
816 PyObject* exception_class TF_GUARDED_BY(exception_class_mutex) = nullptr;
817 
818 // Python subclass of Exception that is raised to signal fallback.
819 PyObject* fallback_exception_class = nullptr;
820 
821 // Python function that returns input gradients given output gradients.
822 PyObject* gradient_function = nullptr;
823 
824 // Python function that returns output gradients given input gradients.
825 PyObject* forward_gradient_function = nullptr;
826 
827 static std::atomic<int64_t> _uid;
828 
829 }  // namespace
830 
831 TF_Status* GetStatus() {
832   TF_Status* maybe_status = ReleaseThreadLocalStatus();
833   if (maybe_status) {
834     TF_SetStatus(maybe_status, TF_OK, "");
835     return maybe_status;
836   } else {
837     return TF_NewStatus();
838   }
839 }
840 
841 void ReturnStatus(TF_Status* status) {
842   TF_SetStatus(status, TF_OK, "");
843   thread_local_tf_status.reset(status);
844 }
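// GetStatus/ReturnStatus mirror GetOp/ReturnOp, recycling one TF_Status per
// thread to avoid an allocation per call on the hot path. A minimal sketch
// of the cycle:
//
//   TF_Status* status = GetStatus();  // reused if available, reset to TF_OK
//   // ... pass status to a TFE_* call ...
//   ReturnStatus(status);             // cached in thread_local_tf_status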
845 
846 void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
847                     const char* op_name, TFE_InputTensorHandles* inputs,
848                     PyObject* attrs, TFE_OutputTensorHandles* outputs,
849                     TF_Status* out_status) {
850   TFE_Py_ExecuteCancelable(ctx, device_name, op_name, inputs, attrs,
851                            /*cancellation_manager=*/nullptr, outputs,
852                            out_status);
853 }
854 
855 void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name,
856                               const char* op_name,
857                               TFE_InputTensorHandles* inputs, PyObject* attrs,
858                               TFE_CancellationManager* cancellation_manager,
859                               TFE_OutputTensorHandles* outputs,
860                               TF_Status* out_status) {
861   tensorflow::profiler::TraceMe activity(
862       "TFE_Py_ExecuteCancelable", tensorflow::profiler::TraceMeLevel::kInfo);
863 
864   TFE_Op* op = GetOp(ctx, op_name, device_name, out_status);
865 
866   auto cleaner = tensorflow::gtl::MakeCleanup([ctx, op] { ReturnOp(ctx, op); });
867   if (!out_status->status.ok()) return;
868 
869   tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace(
870       tensorflow::StackTrace::kStackTraceInitialSize));
871 
872   for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) {
873     TFE_OpAddInput(op, inputs->at(i), out_status);
874   }
875   if (cancellation_manager && out_status->status.ok()) {
876     TFE_OpSetCancellationManager(op, cancellation_manager, out_status);
877   }
878   if (out_status->status.ok()) {
879     SetOpAttrs(ctx, op, attrs, 0, out_status);
880   }
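  // Release the GIL below so other Python threads can make progress while
  // the (potentially long-running) kernel executes.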
881   Py_BEGIN_ALLOW_THREADS;
882 
883   int num_outputs = outputs->size();
884 
885   if (out_status->status.ok()) {
886     TFE_Execute(op, outputs->data(), &num_outputs, out_status);
887   }
888 
889   if (out_status->status.ok()) {
890     outputs->resize(num_outputs);
891   } else {
892     TF_SetStatus(out_status, TF_GetCode(out_status),
893                  tensorflow::strings::StrCat(TF_Message(out_status),
894                                              " [Op:", op_name, "]")
895                      .c_str());
896   }
897 
898   Py_END_ALLOW_THREADS;
899 }
900 
901 PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
902   tensorflow::mutex_lock l(exception_class_mutex);
903   if (exception_class != nullptr) {
904     Py_DECREF(exception_class);
905   }
906   if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
907     exception_class = nullptr;
908     PyErr_SetString(PyExc_TypeError,
909                     "TFE_Py_RegisterExceptionClass: "
910                     "Registered class should be subclass of Exception.");
911     return nullptr;
912   }
913 
914   Py_INCREF(e);
915   exception_class = e;
916   Py_RETURN_NONE;
917 }
918 
919 PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) {
920   if (fallback_exception_class != nullptr) {
921     Py_DECREF(fallback_exception_class);
922   }
923   if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
924     fallback_exception_class = nullptr;
925     PyErr_SetString(PyExc_TypeError,
926                     "TFE_Py_RegisterFallbackExceptionClass: "
927                     "Registered class should be subclass of Exception.");
928     return nullptr;
929   } else {
930     Py_INCREF(e);
931     fallback_exception_class = e;
932     Py_RETURN_NONE;
933   }
934 }
935 
936 PyObject* TFE_Py_RegisterGradientFunction(PyObject* e) {
937   if (gradient_function != nullptr) {
938     Py_DECREF(gradient_function);
939   }
940   if (!PyCallable_Check(e)) {
941     gradient_function = nullptr;
942     PyErr_SetString(PyExc_TypeError,
943                     "TFE_Py_RegisterGradientFunction: "
944                     "Registered object should be function.");
945     return nullptr;
946   } else {
947     Py_INCREF(e);
948     gradient_function = e;
949     Py_RETURN_NONE;
950   }
951 }
952 
953 PyObject* TFE_Py_RegisterJVPFunction(PyObject* e) {
954   if (forward_gradient_function != nullptr) {
955     Py_DECREF(forward_gradient_function);
956   }
957   if (!PyCallable_Check(e)) {
958     forward_gradient_function = nullptr;
959     PyErr_SetString(PyExc_TypeError,
960                     "TFE_Py_RegisterJVPFunction: "
961                     "Registered object should be function.");
962     return nullptr;
963   } else {
964     Py_INCREF(e);
965     forward_gradient_function = e;
966     Py_RETURN_NONE;
967   }
968 }
969 
970 void RaiseFallbackException(const char* message) {
971   if (fallback_exception_class != nullptr) {
972     PyErr_SetString(fallback_exception_class, message);
973     return;
974   }
975 
976   PyErr_SetString(
977       PyExc_RuntimeError,
978       tensorflow::strings::StrCat(
979           "Fallback exception type not set, attempting to fallback due to ",
980           message)
981           .data());
982 }
983 
984 // Format and return `status`' error message with the attached stack trace if
985 // available. `status` must have an error.
986 std::string FormatErrorStatusStackTrace(const tensorflow::Status& status) {
987   tensorflow::DCheckPyGilState();
988   DCHECK(!status.ok());
989 
990   if (status.stack_trace().empty()) return status.error_message();
991 
992   const std::vector<tensorflow::StackFrame>& stack_trace = status.stack_trace();
993 
994   PyObject* linecache = PyImport_ImportModule("linecache");
995   PyObject* getline =
996       PyObject_GetAttr(linecache, PyUnicode_FromString("getline"));
997   DCHECK(getline);
998 
999   std::ostringstream result;
1000   result << "Exception originated from\n\n";
1001 
1002   for (const tensorflow::StackFrame& stack_frame : stack_trace) {
1003     PyObject* line_str_obj = PyObject_CallFunction(
1004         getline, const_cast<char*>("si"), stack_frame.file_name.c_str(),
1005         stack_frame.line_number);
1006     tensorflow::StringPiece line_str = TFE_GetPythonString(line_str_obj);
1007     tensorflow::str_util::RemoveWhitespaceContext(&line_str);
1008     result << "  File \"" << stack_frame.file_name << "\", line "
1009            << stack_frame.line_number << ", in " << stack_frame.function_name
1010            << '\n';
1011 
1012     if (!line_str.empty()) result << "    " << line_str << '\n';
1013     Py_XDECREF(line_str_obj);
1014   }
1015 
1016   Py_DecRef(getline);
1017   Py_DecRef(linecache);
1018 
1019   result << '\n' << status.error_message();
1020   return result.str();
1021 }
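// Example of the resulting message shape (values are illustrative only):
//
//   Exception originated from
//
//     File "model.py", line 42, in call
//       y = tf.matmul(x, w)
//
//   <original status error message>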
1022 
1023 int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
1024   if (status->status.ok()) return 0;
1025   const char* msg = TF_Message(status);
1026   if (exception == nullptr) {
1027     tensorflow::mutex_lock l(exception_class_mutex);
1028     if (exception_class != nullptr) {
1029       tensorflow::Safe_PyObjectPtr val(Py_BuildValue(
1030           "si", FormatErrorStatusStackTrace(status->status).c_str(),
1031           TF_GetCode(status)));
1032       if (PyErr_Occurred()) {
1033         // NOTE: This hides the actual error (i.e. the reason `status` was not
1034         // TF_OK), but there is nothing we can do at this point since we can't
1035         // generate a reasonable error from the status.
1036         // Consider adding a message explaining this.
1037         return -1;
1038       }
1039       PyErr_SetObject(exception_class, val.get());
1040       return -1;
1041     } else {
1042       exception = PyExc_RuntimeError;
1043     }
1044   }
1045   // Maybe update an already-set exception.
1046   PyErr_SetString(exception, msg);
1047   return -1;
1048 }
1049 
1050 int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
1051                                   PyObject* exception) {
1052   if (status.ok()) return 0;
1053   const char* msg = status.error_message().c_str();
1054   if (exception == nullptr) {
1055     tensorflow::mutex_lock l(exception_class_mutex);
1056     if (exception_class != nullptr) {
1057       tensorflow::Safe_PyObjectPtr val(Py_BuildValue(
1058           "si", FormatErrorStatusStackTrace(status).c_str(), status.code()));
1059       PyErr_SetObject(exception_class, val.get());
1060       return -1;
1061     } else {
1062       exception = PyExc_RuntimeError;
1063     }
1064   }
1065   // Maybe update an already-set exception.
1066   PyErr_SetString(exception, msg);
1067   return -1;
1068 }
1069 
1070 const char* TFE_GetPythonString(PyObject* o) {
1071 #if PY_MAJOR_VERSION >= 3
1072   if (PyBytes_Check(o)) {
1073     return PyBytes_AsString(o);
1074   } else {
1075     return PyUnicode_AsUTF8(o);
1076   }
1077 #else
1078   return PyBytes_AsString(o);
1079 #endif
1080 }
1081 
1082 int64_t get_uid() { return _uid++; }
1083 
1084 PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }
1085 
1086 void TFE_DeleteContextCapsule(PyObject* context) {
1087   TFE_Context* ctx =
1088       reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr));
1089   auto op = ReleaseThreadLocalOp(ctx);
1090   op.reset();
1091   TFE_DeleteContext(ctx);
1092 }
1093 
1094 static tensorflow::int64 MakeInt(PyObject* integer) {
1095 #if PY_MAJOR_VERSION >= 3
1096   return PyLong_AsLong(integer);
1097 #else
1098   return PyInt_AsLong(integer);
1099 #endif
1100 }
1101 
1102 static tensorflow::int64 FastTensorId(PyObject* tensor) {
1103   if (EagerTensor_CheckExact(tensor)) {
1104     return PyEagerTensor_ID(tensor);
1105   }
1106   PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
1107   if (id_field == nullptr) {
1108     return -1;
1109   }
1110   int64_t id = MakeInt(id_field);
1111   Py_DECREF(id_field);
1112   return id;
1113 }
1114 
1115 namespace tensorflow {
1116 DataType PyTensor_DataType(PyObject* tensor) {
1117   if (EagerTensor_CheckExact(tensor)) {
1118     return PyEagerTensor_Dtype(tensor);
1119   } else {
1120 #if PY_MAJOR_VERSION < 3
1121     // Python 2.x:
1122     static PyObject* dtype_attr = PyString_InternFromString("dtype");
1123     static PyObject* type_enum_attr = PyString_InternFromString("_type_enum");
1124 #else
1125     // Python 3.x:
1126     static PyObject* dtype_attr = PyUnicode_InternFromString("dtype");
1127     static PyObject* type_enum_attr = PyUnicode_InternFromString("_type_enum");
1128 #endif
1129     Safe_PyObjectPtr dtype_field(PyObject_GetAttr(tensor, dtype_attr));
1130     if (!dtype_field) {
1131       return DT_INVALID;
1132     }
1133 
1134     Safe_PyObjectPtr enum_field(
1135         PyObject_GetAttr(dtype_field.get(), type_enum_attr));
1136     if (!enum_field) {
1137       return DT_INVALID;
1138     }
1139 
1140     return static_cast<DataType>(MakeInt(enum_field.get()));
1141   }
1142 }
1143 }  // namespace tensorflow
1144 
1145 class PyTapeTensor {
1146  public:
1147   PyTapeTensor(int64_t id, tensorflow::DataType dtype,
1148                const tensorflow::TensorShape& shape)
1149       : id_(id), dtype_(dtype), shape_(shape) {}
1150   PyTapeTensor(int64_t id, tensorflow::DataType dtype, PyObject* shape)
1151       : id_(id), dtype_(dtype), shape_(shape) {
1152     Py_INCREF(absl::get<1>(shape_));
1153   }
1154   PyTapeTensor(const PyTapeTensor& other) {
1155     id_ = other.id_;
1156     dtype_ = other.dtype_;
1157     shape_ = other.shape_;
1158     if (shape_.index() == 1) {
1159       Py_INCREF(absl::get<1>(shape_));
1160     }
1161   }
1162 
1163   ~PyTapeTensor() {
1164     if (shape_.index() == 1) {
1165       Py_DECREF(absl::get<1>(shape_));
1166     }
1167   }
1168   PyObject* GetShape() const;
1169   PyObject* GetPyDType() const { return PyLong_FromLong(dtype_); }
1170   tensorflow::int64 GetID() const { return id_; }
1171   tensorflow::DataType GetDType() const { return dtype_; }
1172 
1173   PyObject* OnesLike() const;
1174   PyObject* ZerosLike() const;
1175 
1176  private:
1177   tensorflow::int64 id_;
1178   tensorflow::DataType dtype_;
1179 
1180   // Note that if shape_.index() == 1, meaning shape_ contains a PyObject, that
1181   // PyObject is the tensor itself. This is used to support tf.shape(tensor) for
1182   // partially-defined shapes and tf.zeros_like(tensor) for variant-dtype
1183   // tensors.
1184   absl::variant<tensorflow::TensorShape, PyObject*> shape_;
1185 };
1186 
1187 static PyTapeTensor TapeTensorFromTensor(PyObject* tensor);
1188 
1189 class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
1190                                                   PyTapeTensor> {
1191  public:
1192   explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
1193     Py_INCREF(py_vspace_);
1194   }
1195 
1196   tensorflow::Status Initialize() {
1197     num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
1198     if (num_elements_ == nullptr) {
1199       return tensorflow::errors::InvalidArgument("invalid vspace");
1200     }
1201     aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
1202     if (aggregate_fn_ == nullptr) {
1203       return tensorflow::errors::InvalidArgument("invalid vspace");
1204     }
1205     zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn");
1206     if (zeros_fn_ == nullptr) {
1207       return tensorflow::errors::InvalidArgument("invalid vspace");
1208     }
1209     zeros_like_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_like_fn");
1210     if (zeros_like_fn_ == nullptr) {
1211       return tensorflow::errors::InvalidArgument("invalid vspace");
1212     }
1213     ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn");
1214     if (ones_fn_ == nullptr) {
1215       return tensorflow::errors::InvalidArgument("invalid vspace");
1216     }
1217     ones_like_fn_ = PyObject_GetAttrString(py_vspace_, "ones_like_fn");
1218     if (ones_like_fn_ == nullptr) {
1219       return tensorflow::errors::InvalidArgument("invalid vspace");
1220     }
1221     graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn");
1222     if (graph_shape_fn_ == nullptr) {
1223       return tensorflow::errors::InvalidArgument("invalid vspace");
1224     }
1225     return tensorflow::Status::OK();
1226   }
1227 
1228   ~PyVSpace() override {
1229     Py_XDECREF(num_elements_);
1230     Py_XDECREF(aggregate_fn_);
1231     Py_XDECREF(zeros_fn_);
1232     Py_XDECREF(zeros_like_fn_);
1233     Py_XDECREF(ones_fn_);
1234     Py_XDECREF(ones_like_fn_);
1235     Py_XDECREF(graph_shape_fn_);
1236 
1237     Py_DECREF(py_vspace_);
1238   }
1239 
1240   tensorflow::int64 NumElements(PyObject* tensor) const final {
1241     if (EagerTensor_CheckExact(tensor)) {
1242       return PyEagerTensor_NumElements(tensor);
1243     }
1244     PyObject* arglist =
1245         Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
1246     PyObject* result = PyEval_CallObject(num_elements_, arglist);
1247     Py_DECREF(arglist);
1248     if (result == nullptr) {
1249       // The caller detects whether a python exception has been raised.
1250       return -1;
1251     }
1252     int64_t r = MakeInt(result);
1253     Py_DECREF(result);
1254     return r;
1255   }
1256 
1257   PyObject* AggregateGradients(
1258       tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
1259     PyObject* list = PyList_New(gradient_tensors.size());
1260     for (int i = 0; i < gradient_tensors.size(); ++i) {
1261       // Note: stealing a reference to the gradient tensors.
1262       CHECK(gradient_tensors[i] != nullptr);
1263       CHECK(gradient_tensors[i] != Py_None);
1264       PyList_SET_ITEM(list, i,
1265                       reinterpret_cast<PyObject*>(gradient_tensors[i]));
1266     }
1267     PyObject* arglist = Py_BuildValue("(O)", list);
1268     CHECK(arglist != nullptr);
1269     PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
1270     Py_DECREF(arglist);
1271     Py_DECREF(list);
1272     return result;
1273   }
1274 
1275   tensorflow::int64 TensorId(PyObject* tensor) const final {
1276     return FastTensorId(tensor);
1277   }
1278 
1279   void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
1280 
1281   PyObject* Ones(PyObject* shape, PyObject* dtype) const {
1282     if (PyErr_Occurred()) {
1283       return nullptr;
1284     }
1285     PyObject* arg_list = Py_BuildValue("OO", shape, dtype);
1286     PyObject* result = PyEval_CallObject(ones_fn_, arg_list);
1287     Py_DECREF(arg_list);
1288     return result;
1289   }
1290 
1291   PyObject* OnesLike(PyObject* tensor) const {
1292     if (PyErr_Occurred()) {
1293       return nullptr;
1294     }
1295     return PyObject_CallFunctionObjArgs(ones_like_fn_, tensor, NULL);
1296   }
1297 
1298   // Builds a tensor filled with ones with the same shape and dtype as `t`.
1299   Status BuildOnesLike(const PyTapeTensor& t,
1300                        PyObject** result) const override {
1301     *result = t.OnesLike();
1302     return Status::OK();
1303   }
1304 
1305   PyObject* Zeros(PyObject* shape, PyObject* dtype) const {
1306     if (PyErr_Occurred()) {
1307       return nullptr;
1308     }
1309     PyObject* arg_list = Py_BuildValue("OO", shape, dtype);
1310     PyObject* result = PyEval_CallObject(zeros_fn_, arg_list);
1311     Py_DECREF(arg_list);
1312     return result;
1313   }
1314 
ZerosLike(PyObject * tensor) const1315   PyObject* ZerosLike(PyObject* tensor) const {
1316     if (PyErr_Occurred()) {
1317       return nullptr;
1318     }
1319     return PyObject_CallFunctionObjArgs(zeros_like_fn_, tensor, NULL);
1320   }
1321 
GraphShape(PyObject * tensor) const1322   PyObject* GraphShape(PyObject* tensor) const {
1323     PyObject* arg_list = Py_BuildValue("(O)", tensor);
1324     PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list);
1325     Py_DECREF(arg_list);
1326     return result;
1327   }
1328 
CallBackwardFunction(const string & op_type,PyBackwardFunction * backward_function,const std::vector<tensorflow::int64> & unneeded_gradients,tensorflow::gtl::ArraySlice<PyObject * > output_gradients,absl::Span<PyObject * > result) const1329   tensorflow::Status CallBackwardFunction(
1330       const string& op_type, PyBackwardFunction* backward_function,
1331       const std::vector<tensorflow::int64>& unneeded_gradients,
1332       tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
1333       absl::Span<PyObject*> result) const final {
1334     PyObject* grads = PyTuple_New(output_gradients.size());
1335     for (int i = 0; i < output_gradients.size(); ++i) {
1336       if (output_gradients[i] == nullptr) {
1337         Py_INCREF(Py_None);
1338         PyTuple_SET_ITEM(grads, i, Py_None);
1339       } else {
1340         PyTuple_SET_ITEM(grads, i,
1341                          reinterpret_cast<PyObject*>(output_gradients[i]));
1342       }
1343     }
1344     PyObject* py_result = (*backward_function)(grads, unneeded_gradients);
1345     Py_DECREF(grads);
1346     if (py_result == nullptr) {
1347       return tensorflow::errors::Internal("gradient function threw exceptions");
1348     }
1349     PyObject* seq =
1350         PySequence_Fast(py_result, "expected a sequence of gradients");
1351     if (seq == nullptr) {
1352       return tensorflow::errors::InvalidArgument(
1353           "gradient function did not return a list");
1354     }
1355     int len = PySequence_Fast_GET_SIZE(seq);
    if (len != result.size()) {
      return tensorflow::errors::Internal(
          "Recorded operation '", op_type,
          "' returned the wrong number of gradients. Expected ", result.size(),
          " but received ", len);
    }
    PyObject** seq_array = PySequence_Fast_ITEMS(seq);
    VLOG(1) << "Gradient length is " << len;
    for (int i = 0; i < len; ++i) {
      PyObject* item = seq_array[i];
      if (item == Py_None) {
        result[i] = nullptr;
      } else {
        Py_INCREF(item);
        result[i] = item;
      }
    }
    Py_DECREF(seq);
    Py_DECREF(py_result);
    return tensorflow::Status::OK();
  }

  void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }

  PyTapeTensor TapeTensorFromGradient(PyObject* tensor) const final {
    return TapeTensorFromTensor(tensor);
  }

 private:
  PyObject* py_vspace_;

  PyObject* num_elements_;
  PyObject* aggregate_fn_;
  PyObject* zeros_fn_;
  PyObject* zeros_like_fn_;
  PyObject* ones_fn_;
  PyObject* ones_like_fn_;
  PyObject* graph_shape_fn_;
};
PyVSpace* py_vspace = nullptr;

bool HasAccumulator();

PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
  if (py_vspace != nullptr) {
    if (HasAccumulator()) {
      // Accumulators reference py_vspace, so we can't swap it out while one is
      // active. This is unlikely to ever happen.
      MaybeRaiseExceptionFromStatus(
          tensorflow::errors::Internal(
              "Can't change the vspace implementation while a "
              "forward accumulator is active."),
          nullptr);
    }
    delete py_vspace;
  }

  py_vspace = new PyVSpace(e);
  auto status = py_vspace->Initialize();
  if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
    delete py_vspace;
    return nullptr;
  }

  Py_RETURN_NONE;
}

PyObject* PyTapeTensor::GetShape() const {
  if (shape_.index() == 0) {
    auto& shape = absl::get<0>(shape_);
    PyObject* py_shape = PyTuple_New(shape.dims());
    for (int i = 0; i < shape.dims(); ++i) {
      PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
    }

    return py_shape;
  }

  return py_vspace->GraphShape(absl::get<1>(shape_));
}

PyObject* PyTapeTensor::OnesLike() const {
  if (shape_.index() == 1) {
    PyObject* tensor = absl::get<1>(shape_);
    return py_vspace->OnesLike(tensor);
  }
  PyObject* py_shape = GetShape();
  PyObject* dtype_field = GetPyDType();
  PyObject* result = py_vspace->Ones(py_shape, dtype_field);
  Py_DECREF(dtype_field);
  Py_DECREF(py_shape);
  return result;
}

PyObject* PyTapeTensor::ZerosLike() const {
  if (GetDType() == tensorflow::DT_RESOURCE) {
    // Gradient functions for ops which return resource tensors accept
    // None. This is the behavior of py_vspace->Zeros, but checking here avoids
    // issues with ZerosLike.
    Py_RETURN_NONE;
  }
  if (shape_.index() == 1) {
    PyObject* tensor = absl::get<1>(shape_);
    return py_vspace->ZerosLike(tensor);
  }
  PyObject* py_shape = GetShape();
  PyObject* dtype_field = GetPyDType();
  PyObject* result = py_vspace->Zeros(py_shape, dtype_field);
  Py_DECREF(dtype_field);
  Py_DECREF(py_shape);
  return result;
}

// Keeps track of all variables that have been accessed during execution.
class VariableWatcher {
 public:
  VariableWatcher() {}

  ~VariableWatcher() {
    for (const IdAndVariable& v : watched_variables_) {
      Py_DECREF(v.variable);
    }
  }

  tensorflow::int64 WatchVariable(PyObject* v) {
    tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
    if (handle == nullptr) {
      return -1;
    }
    int64_t id = FastTensorId(handle.get());

    tensorflow::mutex_lock l(watched_variables_mu_);
    auto insert_result = watched_variables_.emplace(id, v);

    if (insert_result.second) {
      // Only increment the reference count if we aren't already watching this
      // variable.
      Py_INCREF(v);
    }

    return id;
  }

  PyObject* GetVariablesAsPyTuple() {
    tensorflow::mutex_lock l(watched_variables_mu_);
    PyObject* result = PyTuple_New(watched_variables_.size());
    Py_ssize_t pos = 0;
    for (const IdAndVariable& id_and_variable : watched_variables_) {
      PyTuple_SET_ITEM(result, pos++, id_and_variable.variable);
      Py_INCREF(id_and_variable.variable);
    }
    return result;
  }

 private:
  // We store an IdAndVariable in the set since the set needs to be locked
  // during insert, but should not call back into python during insert to avoid
  // deadlocking with the GIL.
  struct IdAndVariable {
    tensorflow::int64 id;
    PyObject* variable;

    IdAndVariable(int64_t id, PyObject* variable)
        : id(id), variable(variable) {}
  };
  struct CompareById {
    bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) const {
      return lhs.id < rhs.id;
    }
  };

  tensorflow::mutex watched_variables_mu_;
  std::set<IdAndVariable, CompareById> watched_variables_
      TF_GUARDED_BY(watched_variables_mu_);
};

class GradientTape
    : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
                                             PyTapeTensor> {
 public:
  explicit GradientTape(bool persistent, bool watch_accessed_variables)
      : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
                                        PyTapeTensor>(persistent),
        watch_accessed_variables_(watch_accessed_variables) {}

  virtual ~GradientTape() {}

  void VariableAccessed(PyObject* v) {
    if (watch_accessed_variables_) {
      WatchVariable(v);
    }
  }

  void WatchVariable(PyObject* v) {
    int64_t id = variable_watcher_.WatchVariable(v);

    if (!PyErr_Occurred()) {
      this->Watch(id);
    }
  }

  PyObject* GetVariablesAsPyTuple() {
    return variable_watcher_.GetVariablesAsPyTuple();
  }

 private:
  bool watch_accessed_variables_;
  VariableWatcher variable_watcher_;
};

typedef tensorflow::eager::ForwardAccumulator<PyObject, PyBackwardFunction,
                                              PyTapeTensor>
    ForwardAccumulator;

// Incremented when a GradientTape or accumulator is newly added to a set, and
// used to enforce an ordering between them.
std::atomic_uint_fast64_t tape_nesting_id_counter(0);

typedef struct {
  PyObject_HEAD
      /* Type-specific fields go here. */
      GradientTape* tape;
  // A nesting order between GradientTapes and ForwardAccumulators, used to
  // ensure that GradientTapes do not watch the products of outer
  // ForwardAccumulators.
  tensorflow::int64 nesting_id;
} TFE_Py_Tape;

static void TFE_Py_Tape_Delete(PyObject* tape) {
  delete reinterpret_cast<TFE_Py_Tape*>(tape)->tape;
  Py_TYPE(tape)->tp_free(tape);
}

static PyTypeObject TFE_Py_Tape_Type = {
    PyVarObject_HEAD_INIT(nullptr, 0) "tfe.Tape", /* tp_name */
    sizeof(TFE_Py_Tape),                          /* tp_basicsize */
    0,                                            /* tp_itemsize */
    &TFE_Py_Tape_Delete,                          /* tp_dealloc */
#if PY_VERSION_HEX < 0x03080000
    nullptr, /* tp_print */
#else
    0, /* tp_vectorcall_offset */
#endif
    nullptr,               /* tp_getattr */
    nullptr,               /* tp_setattr */
    nullptr,               /* tp_reserved */
    nullptr,               /* tp_repr */
    nullptr,               /* tp_as_number */
    nullptr,               /* tp_as_sequence */
    nullptr,               /* tp_as_mapping */
    nullptr,               /* tp_hash  */
    nullptr,               /* tp_call */
    nullptr,               /* tp_str */
    nullptr,               /* tp_getattro */
    nullptr,               /* tp_setattro */
    nullptr,               /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT,    /* tp_flags */
    "TFE_Py_Tape objects", /* tp_doc */
};

typedef struct {
  PyObject_HEAD
      /* Type-specific fields go here. */
      ForwardAccumulator* accumulator;
  // A nesting order between GradientTapes and ForwardAccumulators, used to
  // ensure that GradientTapes do not watch the products of outer
  // ForwardAccumulators.
  tensorflow::int64 nesting_id;
} TFE_Py_ForwardAccumulator;

static void TFE_Py_ForwardAccumulatorDelete(PyObject* accumulator) {
  delete reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)->accumulator;
  Py_TYPE(accumulator)->tp_free(accumulator);
}

static PyTypeObject TFE_Py_ForwardAccumulator_Type = {
    PyVarObject_HEAD_INIT(nullptr, 0) "ForwardAccumulator", /* tp_name */
    sizeof(TFE_Py_ForwardAccumulator),                      /* tp_basicsize */
    0,                                                      /* tp_itemsize */
    &TFE_Py_ForwardAccumulatorDelete,                       /* tp_dealloc */
#if PY_VERSION_HEX < 0x03080000
    nullptr, /* tp_print */
#else
    0, /* tp_vectorcall_offset */
#endif
    nullptr,                             /* tp_getattr */
    nullptr,                             /* tp_setattr */
    nullptr,                             /* tp_reserved */
    nullptr,                             /* tp_repr */
    nullptr,                             /* tp_as_number */
    nullptr,                             /* tp_as_sequence */
    nullptr,                             /* tp_as_mapping */
    nullptr,                             /* tp_hash  */
    nullptr,                             /* tp_call */
    nullptr,                             /* tp_str */
    nullptr,                             /* tp_getattro */
    nullptr,                             /* tp_setattro */
    nullptr,                             /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT,                  /* tp_flags */
    "TFE_Py_ForwardAccumulator objects", /* tp_doc */
};

typedef struct {
  PyObject_HEAD
      /* Type-specific fields go here. */
      VariableWatcher* variable_watcher;
} TFE_Py_VariableWatcher;

static void TFE_Py_VariableWatcher_Delete(PyObject* variable_watcher) {
  delete reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
      ->variable_watcher;
  Py_TYPE(variable_watcher)->tp_free(variable_watcher);
}

static PyTypeObject TFE_Py_VariableWatcher_Type = {
    PyVarObject_HEAD_INIT(nullptr, 0) "tfe.VariableWatcher", /* tp_name */
    sizeof(TFE_Py_VariableWatcher),                          /* tp_basicsize */
    0,                                                       /* tp_itemsize */
    &TFE_Py_VariableWatcher_Delete,                          /* tp_dealloc */
#if PY_VERSION_HEX < 0x03080000
    nullptr, /* tp_print */
#else
    0, /* tp_vectorcall_offset */
#endif
    nullptr,                          /* tp_getattr */
    nullptr,                          /* tp_setattr */
    nullptr,                          /* tp_reserved */
    nullptr,                          /* tp_repr */
    nullptr,                          /* tp_as_number */
    nullptr,                          /* tp_as_sequence */
    nullptr,                          /* tp_as_mapping */
    nullptr,                          /* tp_hash  */
    nullptr,                          /* tp_call */
    nullptr,                          /* tp_str */
    nullptr,                          /* tp_getattro */
    nullptr,                          /* tp_setattro */
    nullptr,                          /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT,               /* tp_flags */
    "TFE_Py_VariableWatcher objects", /* tp_doc */
};

// Note: in the current design no mutex is needed here because of the python
// GIL, which is always held when any TFE_Py_* methods are called. We should
// revisit this if/when we decide not to hold the GIL while manipulating the
// tape stack.
tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
  thread_local std::unique_ptr<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>
      tape_set = nullptr;
  if (tape_set == nullptr) {
    tape_set.reset(new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>);
  }
  return tape_set.get();
}

tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>*
GetVariableWatcherSet() {
  thread_local std::unique_ptr<
      tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>
      variable_watcher_set = nullptr;
  if (variable_watcher_set == nullptr) {
    variable_watcher_set.reset(
        new tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>);
  }
  return variable_watcher_set.get();
}

// A linked hash set, where iteration is in insertion order.
//
// Nested accumulators rely on op recording happening in insertion order, so an
// unordered data structure like CompactPointerSet is not suitable. Outer
// accumulators need to observe operations first so they know to watch the
// inner accumulator's jvp computation.
//
// Not thread safe.
class AccumulatorSet {
 public:
  // Returns true if `element` was newly inserted, false if it already exists.
  bool insert(TFE_Py_ForwardAccumulator* element) {
    if (map_.find(element) != map_.end()) {
      return false;
    }
    ListType::iterator it = ordered_.insert(ordered_.end(), element);
    map_.insert(std::make_pair(element, it));
    return true;
  }

  void erase(TFE_Py_ForwardAccumulator* element) {
    MapType::iterator existing = map_.find(element);
    if (existing == map_.end()) {
      return;
    }
    ListType::iterator list_position = existing->second;
    map_.erase(existing);
    ordered_.erase(list_position);
  }

  bool empty() const { return ordered_.empty(); }

  size_t size() const { return ordered_.size(); }

 private:
  typedef std::list<TFE_Py_ForwardAccumulator*> ListType;
  typedef tensorflow::gtl::FlatMap<TFE_Py_ForwardAccumulator*,
                                   ListType::iterator>
      MapType;

 public:
  typedef ListType::const_iterator const_iterator;
  typedef ListType::const_reverse_iterator const_reverse_iterator;

  const_iterator begin() const { return ordered_.begin(); }
  const_iterator end() const { return ordered_.end(); }

  const_reverse_iterator rbegin() const { return ordered_.rbegin(); }
  const_reverse_iterator rend() const { return ordered_.rend(); }

 private:
  MapType map_;
  ListType ordered_;
};
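// A sketch of the ordering contract above (editorial note, not additional
// behavior): an outer accumulator is entered, and therefore inserted, before
// a nested inner one, so forward iteration (begin()..end()) visits the outer
// accumulator first and it can observe ops early enough to watch the inner
// accumulator's jvp computation; reverse iteration (rbegin()..rend()) visits
// the innermost accumulator first.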

AccumulatorSet* GetAccumulatorSet() {
  thread_local std::unique_ptr<AccumulatorSet> accumulator_set{nullptr};
  if (accumulator_set == nullptr) {
    accumulator_set.reset(new AccumulatorSet);
  }
  return accumulator_set.get();
}

inline bool HasAccumulator() { return !GetAccumulatorSet()->empty(); }

inline bool HasGradientTape() { return !GetTapeSet()->empty(); }

inline bool HasAccumulatorOrTape() {
  return HasGradientTape() || HasAccumulator();
}

// A safe copy of a set, used for tapes and accumulators. The copy is not
// affected by other python threads changing the set of active tapes.
template <typename ContainerType>
class SafeSetCopy {
 public:
  explicit SafeSetCopy(const ContainerType& to_copy) : set_copy_(to_copy) {
    for (auto* member : set_copy_) {
      Py_INCREF(member);
    }
  }

  ~SafeSetCopy() {
    for (auto* member : set_copy_) {
      Py_DECREF(member);
    }
  }

  typename ContainerType::const_iterator begin() const {
    return set_copy_.begin();
  }

  typename ContainerType::const_iterator end() const { return set_copy_.end(); }

  bool empty() const { return set_copy_.empty(); }
  size_t size() const { return set_copy_.size(); }

 protected:
  ContainerType set_copy_;
};
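// Design note (editorial): copying the container and holding a strong
// reference to each member means that callbacks which mutate the live set
// during iteration (e.g. a tape removing itself from the active set) can
// neither invalidate this iteration nor destroy a member while it is in use.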

class SafeTapeSet
    : public SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>> {
 public:
  SafeTapeSet()
      : SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>(
            *GetTapeSet()) {}
};

class SafeAccumulatorSet : public SafeSetCopy<AccumulatorSet> {
 public:
  SafeAccumulatorSet() : SafeSetCopy<AccumulatorSet>(*GetAccumulatorSet()) {}

  typename AccumulatorSet::const_reverse_iterator rbegin() const {
    return set_copy_.rbegin();
  }

  typename AccumulatorSet::const_reverse_iterator rend() const {
    return set_copy_.rend();
  }
};

class SafeVariableWatcherSet
    : public SafeSetCopy<
          tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> {
 public:
  SafeVariableWatcherSet()
      : SafeSetCopy<
            tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>(
            *GetVariableWatcherSet()) {}
};

bool* ThreadTapeIsStopped() {
  thread_local bool thread_tape_is_stopped{false};
  return &thread_tape_is_stopped;
}

void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }

void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }
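// A sketch of the intended pairing from a hypothetical caller: recording on
// the current thread is suspended between the two calls.
//   TFE_Py_TapeSetStopOnThread();     // tapes/accumulators stop recording
//   ... run ops that should not be traced ...
//   TFE_Py_TapeSetRestartOnThread();  // recording resumes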

PyObject* TFE_Py_TapeSetIsStopped() {
  if (*ThreadTapeIsStopped()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}

PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
                            PyObject* watch_accessed_variables) {
  TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
  if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
  TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
  tape->tape = new GradientTape(persistent == Py_True,
                                watch_accessed_variables == Py_True);
  Py_INCREF(tape);
  tape->nesting_id = tape_nesting_id_counter.fetch_add(1);
  GetTapeSet()->insert(tape);
  return reinterpret_cast<PyObject*>(tape);
}

void TFE_Py_TapeSetAdd(PyObject* tape) {
  Py_INCREF(tape);
  TFE_Py_Tape* tfe_tape = reinterpret_cast<TFE_Py_Tape*>(tape);
  if (!GetTapeSet()->insert(tfe_tape).second) {
    // Already exists in the tape set.
    Py_DECREF(tape);
  } else {
    tfe_tape->nesting_id = tape_nesting_id_counter.fetch_add(1);
  }
}

PyObject* TFE_Py_TapeSetIsEmpty() {
  if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}

void TFE_Py_TapeSetRemove(PyObject* tape) {
  auto* stack = GetTapeSet();
  stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
  // We kept a reference to the tape in the set to ensure it wouldn't get
  // deleted under us; cleaning it up here.
  Py_DECREF(tape);
}

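// Converts a Python sequence of ints to a std::vector<int64>. Non-int entries
// (e.g. None) are recorded as -1; Py_None or a failed sequence conversion
// yields an empty vector.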
static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
  if (list == Py_None) {
    return {};
  }
  PyObject* seq = PySequence_Fast(list, "expected a sequence");
  if (seq == nullptr) {
    return {};
  }
  int len = PySequence_Size(list);
  PyObject** seq_array = PySequence_Fast_ITEMS(seq);
  std::vector<tensorflow::int64> tensor_ids;
  tensor_ids.reserve(len);
  for (int i = 0; i < len; ++i) {
    PyObject* item = seq_array[i];
#if PY_MAJOR_VERSION >= 3
    if (PyLong_Check(item)) {
#else
    if (PyLong_Check(item) || PyInt_Check(item)) {
#endif
      int64_t id = MakeInt(item);
      tensor_ids.push_back(id);
    } else {
      tensor_ids.push_back(-1);
    }
  }
  Py_DECREF(seq);
  return tensor_ids;
}

// Fill `tensor_ids` and `dtypes` from `tensors`, none of which may be
// null. Returns true on success and false on a Python exception.
bool TensorShapesAndDtypes(PyObject* tensors,
                           std::vector<tensorflow::int64>* tensor_ids,
                           std::vector<tensorflow::DataType>* dtypes) {
  tensorflow::Safe_PyObjectPtr seq(
      PySequence_Fast(tensors, "expected a sequence"));
  if (seq == nullptr) {
    return false;
  }
  int len = PySequence_Fast_GET_SIZE(seq.get());
  PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
  tensor_ids->reserve(len);
  dtypes->reserve(len);
  for (int i = 0; i < len; ++i) {
    PyObject* item = seq_array[i];
    tensor_ids->push_back(FastTensorId(item));
    dtypes->push_back(tensorflow::PyTensor_DataType(item));
  }
  return true;
}

bool TapeCouldPossiblyRecord(PyObject* tensors) {
  if (tensors == Py_None) {
    return false;
  }
  if (*ThreadTapeIsStopped()) {
    return false;
  }
  if (!HasAccumulatorOrTape()) {
    return false;
  }
  return true;
}

bool CouldBackprop() { return !*ThreadTapeIsStopped() && HasGradientTape(); }

bool CouldForwardprop() { return !*ThreadTapeIsStopped() && HasAccumulator(); }

PyObject* TFE_Py_TapeSetShouldRecordBackprop(PyObject* tensors) {
  if (!TapeCouldPossiblyRecord(tensors) || !CouldBackprop()) {
    Py_RETURN_FALSE;
  }
  // TODO(apassos) consider not building a list and changing the API to check
  // each tensor individually.
  std::vector<tensorflow::int64> tensor_ids;
  std::vector<tensorflow::DataType> dtypes;
  if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
    return nullptr;
  }
  auto tape_set = *GetTapeSet();
  for (TFE_Py_Tape* tape : tape_set) {
    if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
      Py_RETURN_TRUE;
    }
  }

  Py_RETURN_FALSE;
}

PyObject* TFE_Py_ForwardAccumulatorPushState() {
  auto forward_accumulators = *GetAccumulatorSet();
  for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
    accumulator->accumulator->PushState();
  }
  Py_RETURN_NONE;
}

PyObject* TFE_Py_ForwardAccumulatorPopState() {
  auto forward_accumulators = *GetAccumulatorSet();
  for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
    accumulator->accumulator->PopState();
  }
  Py_RETURN_NONE;
}

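// Returns, as a Python int, the kind of gradient that could be requested for
// `tensors` (the contract is implied by the returns below):
//   0: nothing is watching, so no tape gradients are possible.
//   1: exactly one non-persistent tape is watching; first-order gradients
//      only.
//   2: a persistent tape, or more than one tape/accumulator, is watching;
//      higher-order gradients are possible.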
PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) {
  if (!TapeCouldPossiblyRecord(tensors)) {
    return GetPythonObjectFromInt(0);
  }
  std::vector<tensorflow::int64> tensor_ids;
  std::vector<tensorflow::DataType> dtypes;
  if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
    return nullptr;
  }

  // If there is a persistent tape watching, or if there are multiple tapes
  // watching, we'll return immediately indicating that higher-order tape
  // gradients are possible.
  bool some_tape_watching = false;
  if (CouldBackprop()) {
    auto tape_set = *GetTapeSet();
    for (TFE_Py_Tape* tape : tape_set) {
      if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
        if (tape->tape->IsPersistent() || some_tape_watching) {
          // Either this is the second tape watching, or this tape is
          // persistent: higher-order gradients are possible.
          return GetPythonObjectFromInt(2);
        }
        some_tape_watching = true;
      }
    }
  }
  if (CouldForwardprop()) {
    auto forward_accumulators = *GetAccumulatorSet();
    for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
      if (accumulator->accumulator->ShouldRecord(tensor_ids, dtypes)) {
        if (some_tape_watching) {
          // This is the second tape watching: higher-order gradients are
          // possible. Note that there's no equivalent of persistence for
          // forward-mode.
          return GetPythonObjectFromInt(2);
        }
        some_tape_watching = true;
      }
    }
  }
  if (some_tape_watching) {
    // There's exactly one non-persistent tape. The user can request first-order
    // gradients but won't be able to get higher-order tape gradients.
    return GetPythonObjectFromInt(1);
  } else {
    // There are no tapes. The user can't request tape gradients.
    return GetPythonObjectFromInt(0);
  }
}

void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
  if (!CouldBackprop()) {
    return;
  }
  int64_t tensor_id = FastTensorId(tensor);
  if (PyErr_Occurred()) {
    return;
  }
  reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
}

bool ListContainsNone(PyObject* list) {
  if (list == Py_None) return true;
  tensorflow::Safe_PyObjectPtr seq(
      PySequence_Fast(list, "expected a sequence"));
  if (seq == nullptr) {
    return false;
  }

  int len = PySequence_Size(list);
  PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
  for (int i = 0; i < len; ++i) {
    PyObject* item = seq_array[i];
    if (item == Py_None) return true;
  }

  return false;
}

// As an optimization, the tape generally keeps only the shape and dtype of
// tensors, and uses this information to generate ones/zeros tensors. However,
// some tensors require OnesLike/ZerosLike because their gradients do not match
// their inference shape/dtype.
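// For example (editorial note): a DT_RESOURCE variable handle is a scalar,
// but its gradient has the shape and dtype of the variable's value, so a
// shape/dtype-only record would build the wrong ones/zeros tensor for it.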
bool DTypeNeedsHandleData(tensorflow::DataType dtype) {
  return dtype == tensorflow::DT_VARIANT || dtype == tensorflow::DT_RESOURCE;
}

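// Builds a PyTapeTensor (id, dtype, and shape or tensor reference) describing
// `tensor`. For eager tensors the shape is read directly from the tensor
// handle; for other tensors it falls back to Python attribute lookups
// (`dtype`, `_shape_tuple`). Tensors whose dtype needs handle data, or whose
// shape is partially unknown, keep a reference to the tensor itself so that
// OnesLike/ZerosLike can be used later instead of shape-based ones/zeros.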
static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
  if (EagerTensor_CheckExact(tensor)) {
    tensorflow::ImmediateExecutionTensorHandle* handle =
        tensorflow::unwrap(EagerTensor_Handle(tensor));
    int64_t id = PyEagerTensor_ID(tensor);
    tensorflow::DataType dtype =
        static_cast<tensorflow::DataType>(handle->DataType());
    if (DTypeNeedsHandleData(dtype)) {
      return PyTapeTensor(id, dtype, tensor);
    }

    tensorflow::TensorShape tensor_shape;
    int num_dims;
    tensorflow::Status status = handle->NumDims(&num_dims);
    if (status.ok()) {
      for (int i = 0; i < num_dims; ++i) {
        int64_t dim_size;
        status = handle->Dim(i, &dim_size);
        if (!status.ok()) break;
        tensor_shape.AddDim(dim_size);
      }
    }

    if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
      return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
                          tensorflow::TensorShape({}));
    } else {
      return PyTapeTensor(id, dtype, tensor_shape);
    }
  }
  int64_t id = FastTensorId(tensor);
  if (PyErr_Occurred()) {
    return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
                        tensorflow::TensorShape({}));
  }
  PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
  PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
  Py_DECREF(dtype_object);
  tensorflow::DataType dtype =
      static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
  Py_DECREF(dtype_enum);
  if (PyErr_Occurred()) {
    return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
                        tensorflow::TensorShape({}));
  }
  static char _shape_tuple[] = "_shape_tuple";
  tensorflow::Safe_PyObjectPtr shape_tuple(
      PyObject_CallMethod(tensor, _shape_tuple, nullptr));
  if (PyErr_Occurred()) {
    return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
                        tensorflow::TensorShape({}));
  }

  if (ListContainsNone(shape_tuple.get()) || DTypeNeedsHandleData(dtype)) {
    return PyTapeTensor(id, dtype, tensor);
  }

  auto l = MakeIntList(shape_tuple.get());
  // Replace -1 (which represents accidental Nones that can occur in graph
  // mode and can cause errors in shape construction) with 0s.
  for (auto& c : l) {
    if (c < 0) {
      c = 0;
    }
  }
  tensorflow::TensorShape shape(l);
  return PyTapeTensor(id, dtype, shape);
}

// Populates output_info from output_seq, which must come from PySequence_Fast.
//
// Does not take ownership of output_seq. Returns true on success and false if a
// Python exception has been set.
bool TapeTensorsFromTensorSequence(PyObject* output_seq,
                                   std::vector<PyTapeTensor>* output_info) {
  Py_ssize_t output_len = PySequence_Fast_GET_SIZE(output_seq);
  PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq);
  output_info->reserve(output_len);
  for (Py_ssize_t i = 0; i < output_len; ++i) {
    output_info->push_back(TapeTensorFromTensor(output_seq_array[i]));
    if (PyErr_Occurred() != nullptr) {
      return false;
    }
  }
  return true;
}

std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
  PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
  if (seq == nullptr) {
    return {};
  }
  int len = PySequence_Fast_GET_SIZE(seq);
  PyObject** seq_array = PySequence_Fast_ITEMS(seq);
  std::vector<tensorflow::int64> list;
  list.reserve(len);
  for (int i = 0; i < len; ++i) {
    PyObject* tensor = seq_array[i];
    list.push_back(FastTensorId(tensor));
    if (PyErr_Occurred()) {
      Py_DECREF(seq);
      return list;
    }
  }
  Py_DECREF(seq);
  return list;
}

void TFE_Py_TapeVariableAccessed(PyObject* variable) {
  if (!CouldBackprop()) {
    return;
  }
  for (TFE_Py_Tape* tape : SafeTapeSet()) {
    tape->tape->VariableAccessed(variable);
  }
}

void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) {
  if (!CouldBackprop()) {
    return;
  }
  reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable);
}

PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
  return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
}

PyObject* TFE_Py_VariableWatcherNew() {
  TFE_Py_VariableWatcher_Type.tp_new = PyType_GenericNew;
  if (PyType_Ready(&TFE_Py_VariableWatcher_Type) < 0) return nullptr;
  TFE_Py_VariableWatcher* variable_watcher =
      PyObject_NEW(TFE_Py_VariableWatcher, &TFE_Py_VariableWatcher_Type);
  variable_watcher->variable_watcher = new VariableWatcher();
  Py_INCREF(variable_watcher);
  GetVariableWatcherSet()->insert(variable_watcher);
  return reinterpret_cast<PyObject*>(variable_watcher);
}

void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher) {
  auto* stack = GetVariableWatcherSet();
  stack->erase(reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher));
  // We kept a reference to the variable watcher in the set to ensure it
  // wouldn't get deleted under us; cleaning it up here.
  Py_DECREF(variable_watcher);
}

void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable) {
  for (TFE_Py_VariableWatcher* variable_watcher : SafeVariableWatcherSet()) {
    variable_watcher->variable_watcher->WatchVariable(variable);
  }
}

PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher) {
  return reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
      ->variable_watcher->GetVariablesAsPyTuple();
}

namespace {
std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
  PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
  if (seq == nullptr) {
    return {};
  }
  int len = PySequence_Fast_GET_SIZE(seq);
  PyObject** seq_array = PySequence_Fast_ITEMS(seq);
  std::vector<tensorflow::DataType> list;
  list.reserve(len);
  for (int i = 0; i < len; ++i) {
    PyObject* tensor = seq_array[i];
    list.push_back(tensorflow::PyTensor_DataType(tensor));
  }
  Py_DECREF(seq);
  return list;
}

PyObject* ForwardAccumulatorDeleteGradient(PyObject* tensor_id,
                                           PyObject* weak_tensor_ref) {
  int64_t parsed_tensor_id = MakeInt(tensor_id);
  for (TFE_Py_ForwardAccumulator* accumulator : *GetAccumulatorSet()) {
    accumulator->accumulator->DeleteGradient(parsed_tensor_id);
  }
  Py_DECREF(weak_tensor_ref);
  Py_DECREF(tensor_id);
  Py_INCREF(Py_None);
  return Py_None;
}

static PyMethodDef forward_accumulator_delete_gradient_method_def = {
    "ForwardAccumulatorDeleteGradient", ForwardAccumulatorDeleteGradient,
    METH_O, "ForwardAccumulatorDeleteGradient"};

void RegisterForwardAccumulatorCleanup(PyObject* tensor, int64_t tensor_id) {
  tensorflow::Safe_PyObjectPtr callback(
      PyCFunction_New(&forward_accumulator_delete_gradient_method_def,
                      PyLong_FromLong(tensor_id)));
  // We need to keep a reference to the weakref active if we want our callback
  // called. The callback itself now owns the weakref object and the tensor ID
  // object.
  PyWeakref_NewRef(tensor, callback.get());
}
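// Note (editorial): the weak reference created above is deliberately not
// released here. When `tensor` is garbage collected the weakref invokes
// ForwardAccumulatorDeleteGradient, which drops the accumulators' jvp entries
// and then DECREFs both the weakref and the tensor-id object, closing the
// cycle described in the comment above.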

void TapeSetRecordBackprop(
    const string& op_type, const std::vector<PyTapeTensor>& output_info,
    const std::vector<tensorflow::int64>& input_ids,
    const std::vector<tensorflow::DataType>& input_dtypes,
    const std::function<PyBackwardFunction*()>& backward_function_getter,
    const std::function<void(PyBackwardFunction*)>& backward_function_killer,
    tensorflow::uint64 max_gradient_tape_id) {
  if (!CouldBackprop()) {
    return;
  }
  for (TFE_Py_Tape* tape : SafeTapeSet()) {
    if (tape->nesting_id < max_gradient_tape_id) {
      tape->tape->RecordOperation(op_type, output_info, input_ids, input_dtypes,
                                  backward_function_getter,
                                  backward_function_killer);
    }
  }
}

bool TapeSetRecordForwardprop(
    const string& op_type, PyObject* output_seq,
    const std::vector<PyTapeTensor>& output_info, PyObject* input_tensors,
    const std::vector<tensorflow::int64>& input_ids,
    const std::vector<tensorflow::DataType>& input_dtypes,
    const std::function<PyBackwardFunction*()>& backward_function_getter,
    const std::function<void(PyBackwardFunction*)>& backward_function_killer,
    const tensorflow::eager::ForwardFunction<PyObject>* forward_function,
    PyObject* forwardprop_output_indices,
    tensorflow::uint64* max_gradient_tape_id) {
  *max_gradient_tape_id = std::numeric_limits<tensorflow::uint64>::max();
  if (!CouldForwardprop()) {
    return true;
  }
  auto accumulator_set = SafeAccumulatorSet();
  tensorflow::Safe_PyObjectPtr input_seq(
      PySequence_Fast(input_tensors, "expected a sequence of tensors"));
  if (input_seq == nullptr || PyErr_Occurred()) return false;
  Py_ssize_t input_len = PySequence_Fast_GET_SIZE(input_seq.get());
  PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq);
  for (int i = 0; i < output_info.size(); ++i) {
    RegisterForwardAccumulatorCleanup(output_seq_array[i],
                                      output_info[i].GetID());
  }
  if (forwardprop_output_indices != nullptr &&
      forwardprop_output_indices != Py_None) {
    tensorflow::Safe_PyObjectPtr indices_fast(PySequence_Fast(
        forwardprop_output_indices, "Expected a sequence of indices"));
    if (indices_fast == nullptr || PyErr_Occurred()) {
      return false;
    }
    if (PySequence_Fast_GET_SIZE(indices_fast.get()) !=
        accumulator_set.size()) {
      MaybeRaiseExceptionFromStatus(
          tensorflow::errors::Internal(
              "Accumulators were added or removed from the active set "
              "between packing and unpacking."),
          nullptr);
    }
    PyObject** indices_fast_array = PySequence_Fast_ITEMS(indices_fast.get());
    Py_ssize_t accumulator_index = 0;
    for (AccumulatorSet::const_reverse_iterator it = accumulator_set.rbegin();
         it != accumulator_set.rend(); ++it, ++accumulator_index) {
      tensorflow::Safe_PyObjectPtr jvp_index_seq(
          PySequence_Fast(indices_fast_array[accumulator_index],
                          "Expected a sequence of jvp indices."));
      if (jvp_index_seq == nullptr || PyErr_Occurred()) {
        return false;
      }
      Py_ssize_t num_jvps = PySequence_Fast_GET_SIZE(jvp_index_seq.get());
      PyObject** jvp_index_seq_array =
          PySequence_Fast_ITEMS(jvp_index_seq.get());
      for (Py_ssize_t jvp_index = 0; jvp_index < num_jvps; ++jvp_index) {
        PyObject* tuple = jvp_index_seq_array[jvp_index];
        int64_t primal_tensor_id =
            output_info[MakeInt(PyTuple_GetItem(tuple, 0))].GetID();
        (*it)->accumulator->Watch(
            primal_tensor_id,
            output_seq_array[MakeInt(PyTuple_GetItem(tuple, 1))]);
      }
    }
  } else {
    std::vector<PyTapeTensor> input_info;
    input_info.reserve(input_len);
    PyObject** input_seq_array = PySequence_Fast_ITEMS(input_seq.get());
    for (Py_ssize_t i = 0; i < input_len; ++i) {
      input_info.push_back(TapeTensorFromTensor(input_seq_array[i]));
    }
    for (TFE_Py_ForwardAccumulator* accumulator : accumulator_set) {
      tensorflow::Status status = accumulator->accumulator->Accumulate(
          op_type, input_info, output_info, input_ids, input_dtypes,
          forward_function, backward_function_getter, backward_function_killer);
      if (PyErr_Occurred()) return false;  // Don't swallow Python exceptions.
      if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
        return false;
      }
      if (accumulator->accumulator->BusyAccumulating()) {
        // Ensure inner accumulators don't see outer accumulators' jvps. This
        // mostly happens on its own, with some potentially surprising
        // exceptions, so the blanket policy is for consistency.
        *max_gradient_tape_id = accumulator->nesting_id;
        break;
      }
    }
  }
  return true;
}

PyObject* TangentsAsPyTuple(const std::vector<PyObject*>& input_tangents) {
  PyObject* py_input_tangents = PyTuple_New(input_tangents.size());
  for (int i = 0; i < input_tangents.size(); ++i) {
    PyObject* element;
    if (input_tangents[i] == nullptr) {
      element = Py_None;
    } else {
      element = input_tangents[i];
    }
    Py_INCREF(element);
    PyTuple_SET_ITEM(py_input_tangents, i, element);
  }
  return py_input_tangents;
}

tensorflow::Status ParseTangentOutputs(
    PyObject* user_output, std::vector<PyObject*>* output_tangents) {
  if (user_output == Py_None) {
    // No connected gradients.
    return tensorflow::Status::OK();
  }
  tensorflow::Safe_PyObjectPtr fast_result(
      PySequence_Fast(user_output, "expected a sequence of forward gradients"));
  if (fast_result == nullptr) {
    return tensorflow::errors::InvalidArgument(
        "forward gradient function did not return a sequence.");
  }
  int len = PySequence_Fast_GET_SIZE(fast_result.get());
  PyObject** fast_result_array = PySequence_Fast_ITEMS(fast_result.get());
  output_tangents->reserve(len);
  for (int i = 0; i < len; ++i) {
    PyObject* item = fast_result_array[i];
    if (item == Py_None) {
      output_tangents->push_back(nullptr);
    } else {
      Py_INCREF(item);
      output_tangents->push_back(item);
    }
  }
  return tensorflow::Status::OK();
}

// Calls the registered forward_gradient_function, computing `output_tangents`
// from `input_tangents`. `output_tangents` must not be null.
//
// `op_name`, `attrs`, `inputs`, and `results` describe the operation for which
// the forward function is being called.
tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs,
                                   PyObject* inputs, PyObject* results,
                                   const std::vector<PyObject*>& input_tangents,
                                   std::vector<PyObject*>* output_tangents,
                                   bool use_batch) {
  if (forward_gradient_function == nullptr) {
    return tensorflow::errors::Internal(
        "No forward gradient function registered.");
  }
  tensorflow::Safe_PyObjectPtr py_input_tangents(
      TangentsAsPyTuple(input_tangents));

  // Normalize the input sequence to a tuple so it works with function
  // caching; otherwise it may be an opaque _InputList object.
  tensorflow::Safe_PyObjectPtr input_tuple(PySequence_Tuple(inputs));
  PyObject* to_batch = (use_batch) ? Py_True : Py_False;
  tensorflow::Safe_PyObjectPtr callback_args(
      Py_BuildValue("OOOOOO", op_name, attrs, input_tuple.get(), results,
                    py_input_tangents.get(), to_batch));
  tensorflow::Safe_PyObjectPtr py_result(
      PyObject_CallObject(forward_gradient_function, callback_args.get()));
  if (py_result == nullptr || PyErr_Occurred()) {
    return tensorflow::errors::Internal(
        "forward gradient function threw exceptions");
  }
  return ParseTangentOutputs(py_result.get(), output_tangents);
}

// Like CallJVPFunction, but calls a pre-bound forward function.
// These are passed in from a record_gradient argument.
tensorflow::Status CallOpSpecificJVPFunction(
    PyObject* op_specific_forward_function,
    const std::vector<PyObject*>& input_tangents,
    std::vector<PyObject*>* output_tangents) {
  tensorflow::Safe_PyObjectPtr py_input_tangents(
      TangentsAsPyTuple(input_tangents));

  tensorflow::Safe_PyObjectPtr py_result(PyObject_CallObject(
      op_specific_forward_function, py_input_tangents.get()));
  if (py_result == nullptr || PyErr_Occurred()) {
    return tensorflow::errors::Internal(
        "forward gradient function threw exceptions");
  }
  return ParseTangentOutputs(py_result.get(), output_tangents);
}

bool ParseOpTypeString(PyObject* op_type, string* op_type_string) {
  if (PyBytes_Check(op_type)) {
    *op_type_string = PyBytes_AsString(op_type);
  } else if (PyUnicode_Check(op_type)) {
#if PY_MAJOR_VERSION >= 3
    *op_type_string = PyUnicode_AsUTF8(op_type);
#else
    PyObject* py_str = PyUnicode_AsUTF8String(op_type);
    if (py_str == nullptr) {
      return false;
    }
    *op_type_string = PyBytes_AS_STRING(py_str);
    Py_DECREF(py_str);
#endif
  } else {
    PyErr_SetString(PyExc_RuntimeError, "op_type should be a string.");
    return false;
  }
  return true;
}

bool TapeSetRecordOperation(
    PyObject* op_type, PyObject* input_tensors, PyObject* output_tensors,
    const std::vector<tensorflow::int64>& input_ids,
    const std::vector<tensorflow::DataType>& input_dtypes,
    const std::function<PyBackwardFunction*()>& backward_function_getter,
    const std::function<void(PyBackwardFunction*)>& backward_function_killer,
    const tensorflow::eager::ForwardFunction<PyObject>* forward_function) {
  std::vector<PyTapeTensor> output_info;
  tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
      output_tensors, "expected a sequence of integer tensor ids"));
  if (PyErr_Occurred() ||
      !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
    return false;
  }
  string op_type_str;
  if (!ParseOpTypeString(op_type, &op_type_str)) {
    return false;
  }
  tensorflow::uint64 max_gradient_tape_id;
  if (!TapeSetRecordForwardprop(
          op_type_str, output_seq.get(), output_info, input_tensors, input_ids,
          input_dtypes, backward_function_getter, backward_function_killer,
          forward_function, nullptr /* No special-cased jvps. */,
          &max_gradient_tape_id)) {
    return false;
  }
  TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes,
                        backward_function_getter, backward_function_killer,
                        max_gradient_tape_id);
  return true;
}
}  // namespace
2557 
2558 PyObject* TFE_Py_TapeSetRecordOperation(PyObject* op_type,
2559                                         PyObject* output_tensors,
2560                                         PyObject* input_tensors,
2561                                         PyObject* backward_function,
2562                                         PyObject* forward_function) {
2563   if (!HasAccumulatorOrTape() || *ThreadTapeIsStopped()) {
2564     Py_RETURN_NONE;
2565   }
2566   std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
2567   if (PyErr_Occurred()) return nullptr;
2568 
2569   std::vector<tensorflow::DataType> input_dtypes =
2570       MakeTensorDtypeList(input_tensors);
2571   if (PyErr_Occurred()) return nullptr;
2572 
2573   std::function<PyBackwardFunction*()> backward_function_getter(
2574       [backward_function]() {
2575         Py_INCREF(backward_function);
2576         PyBackwardFunction* function = new PyBackwardFunction(
2577             [backward_function](PyObject* out_grads,
2578                                 const std::vector<tensorflow::int64>& unused) {
2579               return PyObject_CallObject(backward_function, out_grads);
2580             });
2581         return function;
2582       });
2583   std::function<void(PyBackwardFunction*)> backward_function_killer(
2584       [backward_function](PyBackwardFunction* py_backward_function) {
2585         Py_DECREF(backward_function);
2586         delete py_backward_function;
2587       });
2588 
2589   if (forward_function == Py_None) {
2590     if (!TapeSetRecordOperation(
2591             op_type, input_tensors, output_tensors, input_ids, input_dtypes,
2592             backward_function_getter, backward_function_killer,
2593             nullptr /* No special-cased forward function */)) {
2594       return nullptr;
2595     }
2596   } else {
2597     tensorflow::eager::ForwardFunction<PyObject> wrapped_forward_function(
2598         [forward_function](const std::vector<PyObject*>& input_tangents,
2599                            std::vector<PyObject*>* output_tangents,
2600                            bool use_batch = false) {
2601           return CallOpSpecificJVPFunction(forward_function, input_tangents,
2602                                            output_tangents);
2603         });
2604     if (!TapeSetRecordOperation(
2605             op_type, input_tensors, output_tensors, input_ids, input_dtypes,
2606             backward_function_getter, backward_function_killer,
2607             &wrapped_forward_function)) {
2608       return nullptr;
2609     }
2610   }
2611   Py_RETURN_NONE;
2612 }
2613 
2614 PyObject* TFE_Py_TapeSetRecordOperationForwardprop(
2615     PyObject* op_type, PyObject* output_tensors, PyObject* input_tensors,
2616     PyObject* backward_function, PyObject* forwardprop_output_indices) {
2617   if (!HasAccumulator() || *ThreadTapeIsStopped()) {
2618     Py_RETURN_NONE;
2619   }
2620   std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
2621   if (PyErr_Occurred()) return nullptr;
2622 
2623   std::vector<tensorflow::DataType> input_dtypes =
2624       MakeTensorDtypeList(input_tensors);
2625   if (PyErr_Occurred()) return nullptr;
2626 
2627   std::function<PyBackwardFunction*()> backward_function_getter(
2628       [backward_function]() {
2629         Py_INCREF(backward_function);
2630         PyBackwardFunction* function = new PyBackwardFunction(
2631             [backward_function](PyObject* out_grads,
2632                                 const std::vector<tensorflow::int64>& unused) {
2633               return PyObject_CallObject(backward_function, out_grads);
2634             });
2635         return function;
2636       });
2637   std::function<void(PyBackwardFunction*)> backward_function_killer(
2638       [backward_function](PyBackwardFunction* py_backward_function) {
2639         Py_DECREF(backward_function);
2640         delete py_backward_function;
2641       });
2642   std::vector<PyTapeTensor> output_info;
2643   tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2644       output_tensors, "expected a sequence of integer tensor ids"));
2645   if (PyErr_Occurred() ||
2646       !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2647     return nullptr;
2648   }
2649   string op_type_str;
2650   if (!ParseOpTypeString(op_type, &op_type_str)) {
2651     return nullptr;
2652   }
2653   tensorflow::uint64 max_gradient_tape_id;
2654   if (!TapeSetRecordForwardprop(
2655           op_type_str, output_seq.get(), output_info, input_tensors, input_ids,
2656           input_dtypes, backward_function_getter, backward_function_killer,
2657           nullptr /* no special-cased forward function */,
2658           forwardprop_output_indices, &max_gradient_tape_id)) {
2659     return nullptr;
2660   }
2661   Py_RETURN_NONE;
2662 }
2663 
2664 PyObject* TFE_Py_TapeSetRecordOperationBackprop(PyObject* op_type,
2665                                                 PyObject* output_tensors,
2666                                                 PyObject* input_tensors,
2667                                                 PyObject* backward_function) {
2668   if (!CouldBackprop()) {
2669     Py_RETURN_NONE;
2670   }
2671   std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
2672   if (PyErr_Occurred()) return nullptr;
2673 
2674   std::vector<tensorflow::DataType> input_dtypes =
2675       MakeTensorDtypeList(input_tensors);
2676   if (PyErr_Occurred()) return nullptr;
2677 
2678   std::function<PyBackwardFunction*()> backward_function_getter(
2679       [backward_function]() {
2680         Py_INCREF(backward_function);
2681         PyBackwardFunction* function = new PyBackwardFunction(
2682             [backward_function](PyObject* out_grads,
2683                                 const std::vector<tensorflow::int64>& unused) {
2684               return PyObject_CallObject(backward_function, out_grads);
2685             });
2686         return function;
2687       });
2688   std::function<void(PyBackwardFunction*)> backward_function_killer(
2689       [backward_function](PyBackwardFunction* py_backward_function) {
2690         Py_DECREF(backward_function);
2691         delete py_backward_function;
2692       });
2693   std::vector<PyTapeTensor> output_info;
2694   tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2695       output_tensors, "expected a sequence of integer tensor ids"));
2696   if (PyErr_Occurred() ||
2697       !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2698     return nullptr;
2699   }
2700   string op_type_str;
2701   if (!ParseOpTypeString(op_type, &op_type_str)) {
2702     return nullptr;
2703   }
2704   TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes,
2705                         backward_function_getter, backward_function_killer,
2706                         // No filtering based on relative ordering with forward
2707                         // accumulators.
2708                         std::numeric_limits<tensorflow::uint64>::max());
2709   Py_RETURN_NONE;
2710 }
2711 
2712 void TFE_Py_TapeSetDeleteTrace(int64_t tensor_id) {
2713   for (TFE_Py_Tape* tape : *GetTapeSet()) {
2714     tape->tape->DeleteTrace(tensor_id);
2715   }
2716 }
2717 
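// Note: the pointers returned by MakeTensorList below are borrowed
// references into `tensors`; callers that hand them to code which steals or
// consumes references must INCREF the items themselves (as
// TFE_Py_TapeGradient does for output_gradients).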
2718 std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
2719   PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
2720   if (seq == nullptr) {
2721     return {};
2722   }
2723   int len = PySequence_Fast_GET_SIZE(seq);
2724   PyObject** seq_array = PySequence_Fast_ITEMS(seq);
2725   std::vector<PyObject*> list(seq_array, seq_array + len);
2726   Py_DECREF(seq);
2727   return list;
2728 }
2729 
2730 PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
2731                               PyObject* sources, PyObject* output_gradients,
2732                               PyObject* sources_raw,
2733                               PyObject* unconnected_gradients,
2734                               TF_Status* status) {
2735   TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
2736   if (!tape_obj->tape->IsPersistent()) {
2737     auto* tape_set = GetTapeSet();
2738     if (tape_set->find(tape_obj) != tape_set->end()) {
2739       PyErr_SetString(PyExc_RuntimeError,
2740                       "gradient() cannot be invoked within the "
2741                       "GradientTape context (i.e., while operations are being "
2742                       "recorded). Either move the call to gradient() to be "
2743                       "outside the 'with tf.GradientTape' block, or "
2744                       "use a persistent tape: "
2745                       "'with tf.GradientTape(persistent=True)'");
2746       return nullptr;
2747     }
2748   }
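  // For reference, a minimal sketch of the Python-level usage this check
  // enforces (illustrative names, not part of this file):
  //
  //   with tf.GradientTape() as tape:        # non-persistent tape
  //     y = f(x)
  //   dy_dx = tape.gradient(y, x)            # called after the context exits
  //
  //   with tf.GradientTape(persistent=True) as tape:
  //     y = f(x)
  //     dy_dx = tape.gradient(y, x)          # allowed inside the context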
2749 
2750   std::vector<tensorflow::int64> target_vec = MakeTensorIDList(target);
2751   if (PyErr_Occurred()) {
2752     return nullptr;
2753   }
2754   std::vector<tensorflow::int64> sources_vec = MakeTensorIDList(sources);
2755   if (PyErr_Occurred()) {
2756     return nullptr;
2757   }
2758   tensorflow::gtl::FlatSet<tensorflow::int64> sources_set(sources_vec.begin(),
2759                                                           sources_vec.end());
2760 
2761   tensorflow::Safe_PyObjectPtr seq =
2762       tensorflow::make_safe(PySequence_Fast(target, "expected a sequence"));
2763   int len = PySequence_Fast_GET_SIZE(seq.get());
2764   PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
2765   std::unordered_map<tensorflow::int64, PyTapeTensor>
2766       source_tensors_that_are_targets;
2767   for (int i = 0; i < len; ++i) {
2768     int64_t target_id = target_vec[i];
2769     if (sources_set.find(target_id) != sources_set.end()) {
2770       auto tensor = seq_array[i];
2771       source_tensors_that_are_targets.insert(
2772           std::make_pair(target_id, TapeTensorFromTensor(tensor)));
2773     }
2774     if (PyErr_Occurred()) {
2775       return nullptr;
2776     }
2777   }
2778   if (PyErr_Occurred()) {
2779     return nullptr;
2780   }
2781 
2782   std::vector<PyObject*> outgrad_vec;
2783   if (output_gradients != Py_None) {
2784     outgrad_vec = MakeTensorList(output_gradients);
2785     if (PyErr_Occurred()) {
2786       return nullptr;
2787     }
2788     for (PyObject* tensor : outgrad_vec) {
2789       // Calling the backward function will eat a reference to the tensors in
2790       // outgrad_vec, so we need to increase their reference count.
2791       Py_INCREF(tensor);
2792     }
2793   }
2794   std::vector<PyObject*> result(sources_vec.size());
2795   status->status = tape_obj->tape->ComputeGradient(
2796       *py_vspace, target_vec, sources_vec, source_tensors_that_are_targets,
2797       outgrad_vec, absl::MakeSpan(result));
2798   if (!status->status.ok()) {
2799     if (PyErr_Occurred()) {
2800       // Do not propagate the erroneous status as that would swallow the
2801       // exception which caused the problem.
2802       status->status = tensorflow::Status::OK();
2803     }
2804     return nullptr;
2805   }
2806 
2807   bool unconnected_gradients_zero =
2808       strcmp(TFE_GetPythonString(unconnected_gradients), "zero") == 0;
2809   std::vector<PyObject*> sources_obj;
2810   if (unconnected_gradients_zero) {
2811     // Uses the "raw" sources here so it can properly make a zeros tensor even
2812     // if there are resource variables as sources.
2813     sources_obj = MakeTensorList(sources_raw);
2814   }
2815 
2816   if (!result.empty()) {
2817     PyObject* py_result = PyList_New(result.size());
2818     tensorflow::gtl::FlatSet<PyObject*> seen_results(result.size());
2819     for (int i = 0; i < result.size(); ++i) {
2820       if (result[i] == nullptr) {
2821         if (unconnected_gradients_zero) {
2822           // Generate a zeros tensor with the shape of sources[i].
2823           tensorflow::DataType dtype =
2824               tensorflow::PyTensor_DataType(sources_obj[i]);
2825           PyTapeTensor tensor =
2826               PyTapeTensor(sources_vec[i], dtype, sources_obj[i]);
2827           result[i] = tensor.ZerosLike();
2828         } else {
2829           Py_INCREF(Py_None);
2830           result[i] = Py_None;
2831         }
2832       } else if (seen_results.find(result[i]) != seen_results.end()) {
2833         Py_INCREF(result[i]);
2834       }
2835       seen_results.insert(result[i]);
2836       PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]));
2837     }
2838     return py_result;
2839   }
2840   return PyList_New(0);
2841 }
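// A sketch of the Python-level option handled above (illustrative names;
// tf.UnconnectedGradients.ZERO is what maps to the "zero" string):
//
//   dy = tape.gradient(y, [x_used, x_unused],
//                      unconnected_gradients=tf.UnconnectedGradients.ZERO)
//   # x_unused's slot comes back as a zeros tensor instead of None.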
2842 
2843 PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch) {
2844   TFE_Py_ForwardAccumulator_Type.tp_new = PyType_GenericNew;
2845   if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr;
2846   TFE_Py_ForwardAccumulator* accumulator =
2847       PyObject_NEW(TFE_Py_ForwardAccumulator, &TFE_Py_ForwardAccumulator_Type);
2848   if (py_vspace == nullptr) {
2849     MaybeRaiseExceptionFromStatus(
2850         tensorflow::errors::Internal(
2851             "ForwardAccumulator requires a PyVSpace to be registered."),
2852         nullptr);
         // The Python exception is set above; bail out instead of
         // dereferencing a null py_vspace below.
         return nullptr;
2853   }
2854   accumulator->accumulator = new ForwardAccumulator(*py_vspace, use_batch);
2855   return reinterpret_cast<PyObject*>(accumulator);
2856 }
2857 
2858 PyObject* TFE_Py_ForwardAccumulatorSetAdd(PyObject* accumulator) {
2859   TFE_Py_ForwardAccumulator* c_accumulator(
2860       reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
2861   c_accumulator->nesting_id = tape_nesting_id_counter.fetch_add(1);
2862   if (GetAccumulatorSet()->insert(c_accumulator)) {
2863     Py_INCREF(accumulator);
2864     Py_RETURN_NONE;
2865   } else {
2866     MaybeRaiseExceptionFromStatus(
2867         tensorflow::errors::Internal(
2868             "A ForwardAccumulator was added to the active set twice."),
2869         nullptr);
2870     return nullptr;
2871   }
2872 }
2873 
2874 void TFE_Py_ForwardAccumulatorSetRemove(PyObject* accumulator) {
2875   GetAccumulatorSet()->erase(
2876       reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
2877   Py_DECREF(accumulator);
2878 }
2879 
2880 void TFE_Py_ForwardAccumulatorWatch(PyObject* accumulator, PyObject* tensor,
2881                                     PyObject* tangent) {
2882   int64_t tensor_id = FastTensorId(tensor);
2883   reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
2884       ->accumulator->Watch(tensor_id, tangent);
2885   RegisterForwardAccumulatorCleanup(tensor, tensor_id);
2886 }
2887 
2888 // Returns a new reference to the JVP Tensor.
2889 PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator,
2890                                        PyObject* tensor) {
2891   PyObject* jvp = reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
2892                       ->accumulator->FetchJVP(FastTensorId(tensor));
2893   if (jvp == nullptr) {
2894     jvp = Py_None;
2895   }
2896   Py_INCREF(jvp);
2897   return jvp;
2898 }
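// A minimal sketch of the Python-level flow these accumulator bindings back,
// assuming the public tf.autodiff.ForwardAccumulator wrapper (names are
// illustrative):
//
//   with tf.autodiff.ForwardAccumulator(primals=x, tangents=v) as acc:
//     y = f(x)            # JVPs are computed eagerly alongside the primals
//   jvp = acc.jvp(y)      # fetched via TFE_Py_ForwardAccumulatorJVP above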
2899 
2900 PyObject* TFE_Py_PackJVPs(PyObject* tensors) {
2901   if (!TapeCouldPossiblyRecord(tensors)) {
2902     tensorflow::Safe_PyObjectPtr empty_tuple(PyTuple_New(0));
2903     tensorflow::Safe_PyObjectPtr empty_list(PyList_New(0));
2904     return PyTuple_Pack(2, empty_tuple.get(), empty_list.get());
2905   }
2906   auto accumulators = *GetAccumulatorSet();
2907   tensorflow::Safe_PyObjectPtr tensors_fast(
2908       PySequence_Fast(tensors, "Expected a sequence of input Tensors."));
2909   if (tensors_fast == nullptr || PyErr_Occurred()) {
2910     return nullptr;
2911   }
2912   std::vector<tensorflow::int64> augmented_input_ids;
2913   int len = PySequence_Fast_GET_SIZE(tensors_fast.get());
2914   PyObject** tensors_fast_array = PySequence_Fast_ITEMS(tensors_fast.get());
2915   for (Py_ssize_t position = 0; position < len; ++position) {
2916     PyObject* input = tensors_fast_array[position];
2917     if (input == Py_None) {
2918       continue;
2919     }
2920     tensorflow::DataType input_dtype(tensorflow::PyTensor_DataType(input));
2921     if (input_dtype == tensorflow::DT_INVALID) {
2922       return nullptr;
2923     }
2924     augmented_input_ids.push_back(FastTensorId(input));
2925   }
2926   if (PyErr_Occurred()) {
2927     return nullptr;
2928   }
2929   // Find the innermost accumulator such that all outer accumulators are
2930   // recording. Any more deeply nested accumulators will not have their JVPs
2931   // saved.
2932   AccumulatorSet::const_iterator innermost_all_recording = accumulators.begin();
2933   for (; innermost_all_recording != accumulators.end();
2934        ++innermost_all_recording) {
2935     if ((*innermost_all_recording)->accumulator->BusyAccumulating()) {
2936       break;
2937     }
2938   }
2939   AccumulatorSet::const_reverse_iterator reverse_innermost_all_recording(
2940       innermost_all_recording);
2941 
2942   bool saving_jvps = false;
2943   tensorflow::Safe_PyObjectPtr all_indices(PyTuple_New(accumulators.size()));
2944   std::vector<PyObject*> new_tensors;
2945   Py_ssize_t accumulator_index = 0;
2946   // Start with the innermost accumulators to give outer accumulators a chance
2947   // to find their higher-order JVPs.
2948   for (AccumulatorSet::const_reverse_iterator it = accumulators.rbegin();
2949        it != accumulators.rend(); ++it, ++accumulator_index) {
2950     std::vector<tensorflow::int64> new_input_ids;
2951     std::vector<std::pair<tensorflow::int64, tensorflow::int64>>
2952         accumulator_indices;
2953     if (it == reverse_innermost_all_recording) {
2954       saving_jvps = true;
2955     }
2956     if (saving_jvps) {
2957       for (int input_index = 0; input_index < augmented_input_ids.size();
2958            ++input_index) {
2959         int64_t existing_input = augmented_input_ids[input_index];
2960         PyObject* jvp = (*it)->accumulator->FetchJVP(existing_input);
2961         if (jvp != nullptr) {
2962           new_tensors.push_back(jvp);
2963           new_input_ids.push_back(FastTensorId(jvp));
2964           accumulator_indices.emplace_back(
2965               input_index,
2966               augmented_input_ids.size() + new_input_ids.size() - 1);
2967         }
2968       }
2969     }
2970     tensorflow::Safe_PyObjectPtr accumulator_indices_py(
2971         PyTuple_New(accumulator_indices.size()));
2972     for (int i = 0; i < accumulator_indices.size(); ++i) {
2973       tensorflow::Safe_PyObjectPtr from_index(
2974           GetPythonObjectFromInt(accumulator_indices[i].first));
2975       tensorflow::Safe_PyObjectPtr to_index(
2976           GetPythonObjectFromInt(accumulator_indices[i].second));
2977       PyTuple_SetItem(accumulator_indices_py.get(), i,
2978                       PyTuple_Pack(2, from_index.get(), to_index.get()));
2979     }
2980     PyTuple_SetItem(all_indices.get(), accumulator_index,
2981                     accumulator_indices_py.release());
2982     augmented_input_ids.insert(augmented_input_ids.end(), new_input_ids.begin(),
2983                                new_input_ids.end());
2984   }
2985 
2986   tensorflow::Safe_PyObjectPtr new_tensors_py(PyList_New(new_tensors.size()));
2987   for (int i = 0; i < new_tensors.size(); ++i) {
2988     PyObject* jvp = new_tensors[i];
2989     Py_INCREF(jvp);
2990     PyList_SET_ITEM(new_tensors_py.get(), i, jvp);
2991   }
2992   return PyTuple_Pack(2, all_indices.get(), new_tensors_py.get());
2993 }
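// The returned pair is (per-accumulator index tuples, new JVP tensors). As a
// worked example (hypothetical tensors): with one accumulator and inputs
// (t0, t1) where only t1 has a JVP, the result is (((1, 2),), [jvp_t1]) -
// input position 1 maps to augmented position 2, i.e. len(inputs) + 0.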
2994 
2995 namespace {
2996 
2997 // Indices for the "args" tuple that's passed to TFE_Py_FastPathExecute_C.
2998 enum FastPathExecuteArgIndex {
2999   FAST_PATH_EXECUTE_ARG_CONTEXT = 0,
3000   FAST_PATH_EXECUTE_ARG_OP_NAME = 1,
3001   FAST_PATH_EXECUTE_ARG_NAME = 2,
3002   FAST_PATH_EXECUTE_ARG_INPUT_START = 3
3003 };
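// For illustration, a hypothetical MatMul call would pass an args tuple like
//   (eager_context, "MatMul", name,                  # indices 0, 1, 2
//    t1, t2,                                         # inputs, from index 3
//    "transpose_a", False, "transpose_b", False)     # attr name/value pairs
// Inputs start at FAST_PATH_EXECUTE_ARG_INPUT_START and are followed by the
// non-inferred attrs as alternating name/value items.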
3004 
3005 PyObject* GetPythonObjectFromString(tensorflow::StringPiece s) {
3006 #if PY_MAJOR_VERSION >= 3
3007   return PyUnicode_FromStringAndSize(s.data(), s.size());
3008 #else
3009   return PyBytes_FromStringAndSize(s.data(), s.size());
3010 #endif
3011 }
3012 
3013 bool CheckResourceVariable(PyObject* item) {
3014   if (tensorflow::swig::IsResourceVariable(item)) {
3015     tensorflow::Safe_PyObjectPtr handle(
3016         PyObject_GetAttrString(item, "_handle"));
3017     return EagerTensor_CheckExact(handle.get());
3018   }
3019 
3020   return false;
3021 }
3022 
3023 bool IsNumberType(PyObject* item) {
3024 #if PY_MAJOR_VERSION >= 3
3025   return PyFloat_Check(item) || PyLong_Check(item);
3026 #else
3027   return PyFloat_Check(item) || PyInt_Check(item) || PyLong_Check(item);
3028 #endif
3029 }
3030 
3031 bool CheckOneInput(PyObject* item) {
3032   if (EagerTensor_CheckExact(item) || CheckResourceVariable(item) ||
3033       PyArray_Check(item) || IsNumberType(item)) {
3034     return true;
3035   }
3036 
3037   // Sequences are not properly handled. Sequences with purely python numeric
3038   // types work, but sequences with mixes of EagerTensors and python numeric
3039   // types don't work.
3040   // TODO(nareshmodi): fix
3041   return false;
3042 }
3043 
3044 bool CheckInputsOk(PyObject* seq, int start_index,
3045                    const tensorflow::OpDef& op_def) {
3046   for (int i = 0; i < op_def.input_arg_size(); i++) {
3047     PyObject* item = PyTuple_GET_ITEM(seq, i + start_index);
3048     if (!op_def.input_arg(i).number_attr().empty() ||
3049         !op_def.input_arg(i).type_list_attr().empty()) {
3050       // This item should be a seq input.
3051       if (!PySequence_Check(item)) {
3052         VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
3053                 << "\", Input \"" << op_def.input_arg(i).name()
3054                 << "\" since we expected a sequence, but got "
3055                 << item->ob_type->tp_name;
3056         return false;
3057       }
3058       tensorflow::Safe_PyObjectPtr fast_item(
3059           PySequence_Fast(item, "Could not parse sequence."));
3060       if (fast_item.get() == nullptr) {
3061         return false;
3062       }
3063       int len = PySequence_Fast_GET_SIZE(fast_item.get());
3064       PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get());
3065       for (Py_ssize_t j = 0; j < len; j++) {
3066         PyObject* inner_item = fast_item_array[j];
3067         if (!CheckOneInput(inner_item)) {
3068           VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
3069                   << "\", Input \"" << op_def.input_arg(i).name()
3070                   << "\", Index " << j
3071                   << " since we expected an EagerTensor/ResourceVariable, "
3072                      "but got "
3073                   << inner_item->ob_type->tp_name;
3074           return false;
3075         }
3076       }
3077     } else if (!CheckOneInput(item)) {
3078       VLOG(1)
3079           << "Falling back to slow path for Op \"" << op_def.name()
3080           << "\", Input \"" << op_def.input_arg(i).name()
3081           << "\" since we expected an EagerTensor/ResourceVariable, but got "
3082           << item->ob_type->tp_name;
3083       return false;
3084     }
3085   }
3086 
3087   return true;
3088 }
3089 
3090 tensorflow::DataType MaybeGetDType(PyObject* item) {
3091   if (EagerTensor_CheckExact(item) || CheckResourceVariable(item)) {
3092     return tensorflow::PyTensor_DataType(item);
3093   }
3094 
3095   return tensorflow::DT_INVALID;
3096 }
3097 
3098 tensorflow::DataType MaybeGetDTypeForAttr(const string& attr,
3099                                           FastPathOpExecInfo* op_exec_info) {
3100   auto cached_it = op_exec_info->cached_dtypes.find(attr);
3101   if (cached_it != op_exec_info->cached_dtypes.end()) {
3102     return cached_it->second;
3103   }
3104 
3105   auto it = op_exec_info->attr_to_inputs_map->find(attr);
3106   if (it == op_exec_info->attr_to_inputs_map->end()) {
3107     // No other inputs - this should never happen.
3108     return tensorflow::DT_INVALID;
3109   }
3110 
3111   for (const auto& input_info : it->second) {
3112     PyObject* item = PyTuple_GET_ITEM(
3113         op_exec_info->args, FAST_PATH_EXECUTE_ARG_INPUT_START + input_info.i);
3114     if (input_info.is_list) {
3115       tensorflow::Safe_PyObjectPtr fast_item(
3116           PySequence_Fast(item, "Unable to allocate"));
3117       int len = PySequence_Fast_GET_SIZE(fast_item.get());
3118       PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get());
3119       for (int i = 0; i < len; i++) {
3120         auto dtype = MaybeGetDType(fast_item_array[i]);
3121         if (dtype != tensorflow::DT_INVALID) return dtype;
3122       }
3123     } else {
3124       auto dtype = MaybeGetDType(item);
3125       if (dtype != tensorflow::DT_INVALID) return dtype;
3126     }
3127   }
3128 
3129   auto default_it = op_exec_info->default_dtypes->find(attr);
3130   if (default_it != op_exec_info->default_dtypes->end()) {
3131     return default_it->second;
3132   }
3133 
3134   return tensorflow::DT_INVALID;
3135 }
3136 
3137 PyObject* CopySequenceSettingIndicesToNull(
3138     PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) {
3139   tensorflow::Safe_PyObjectPtr fast_seq(
3140       PySequence_Fast(seq, "unable to allocate"));
3141   int len = PySequence_Fast_GET_SIZE(fast_seq.get());
3142   PyObject** fast_seq_array = PySequence_Fast_ITEMS(fast_seq.get());
3143   PyObject* result = PyTuple_New(len);
3144   for (int i = 0; i < len; i++) {
3145     PyObject* item;
3146     if (indices.find(i) != indices.end()) {
3147       item = Py_None;
3148     } else {
3149       item = fast_seq_array[i];
3150     }
3151     Py_INCREF(item);
3152     PyTuple_SET_ITEM(result, i, item);
3153   }
3154   return result;
3155 }
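// E.g. with seq = (a, b, c) and indices = {1}, this returns a new tuple
// (a, None, c); the surviving items share references with the originals.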
3156 
3157 PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
3158                          PyObject* results,
3159                          PyObject* forward_pass_name_scope = nullptr) {
3160   std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
3161   if (PyErr_Occurred()) return nullptr;
3162   std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
3163   if (PyErr_Occurred()) return nullptr;
3164 
3165   bool should_record = false;
3166   for (TFE_Py_Tape* tape : SafeTapeSet()) {
3167     if (tape->tape->ShouldRecord(input_ids, input_dtypes)) {
3168       should_record = true;
3169       break;
3170     }
3171   }
3172   if (!should_record) {
3173     for (TFE_Py_ForwardAccumulator* accumulator : SafeAccumulatorSet()) {
3174       if (accumulator->accumulator->ShouldRecord(input_ids, input_dtypes)) {
3175         should_record = true;
3176         break;
3177       }
3178     }
3179   }
3180   if (!should_record) Py_RETURN_NONE;
3181 
3182   string c_op_name = TFE_GetPythonString(op_name);
3183 
3184   PyObject* op_outputs;
3185   bool op_outputs_tuple_created = false;
3186 
3187   if (const auto unused_output_indices =
3188           OpGradientUnusedOutputIndices(c_op_name)) {
3189     if (unused_output_indices->empty()) {
3190       op_outputs = Py_None;
3191     } else {
3192       op_outputs_tuple_created = true;
3193       op_outputs =
3194           CopySequenceSettingIndicesToNull(results, *unused_output_indices);
3195     }
3196   } else {
3197     op_outputs = results;
3198   }
3199 
3200   PyObject* op_inputs;
3201   bool op_inputs_tuple_created = false;
3202 
3203   if (const auto unused_input_indices =
3204           OpGradientUnusedInputIndices(c_op_name)) {
3205     if (unused_input_indices->empty()) {
3206       op_inputs = Py_None;
3207     } else {
3208       op_inputs_tuple_created = true;
3209       op_inputs =
3210           CopySequenceSettingIndicesToNull(inputs, *unused_input_indices);
3211     }
3212   } else {
3213     op_inputs = inputs;
3214   }
3215 
3216   tensorflow::eager::ForwardFunction<PyObject> py_forward_function(
3217       [op_name, attrs, inputs, results](
3218           const std::vector<PyObject*>& input_tangents,
3219           std::vector<PyObject*>* output_tangents, bool use_batch) {
3220         return CallJVPFunction(op_name, attrs, inputs, results, input_tangents,
3221                                output_tangents, use_batch);
3222       });
3223   tensorflow::eager::ForwardFunction<PyObject>* forward_function;
3224   if (c_op_name == "While" || c_op_name == "StatelessWhile" ||
3225       c_op_name == "If" || c_op_name == "StatelessIf") {
3226     // Control flow contains non-hashable attributes. Handling them in Python is
3227     // a headache, so instead we'll stay as close to GradientTape's handling as
3228     // possible (a null forward function means the accumulator forwards to a
3229     // tape).
3230     //
3231     // This is safe to do since we'll only see control flow when graph building,
3232     // in which case we can rely on pruning.
3233     forward_function = nullptr;
3234   } else {
3235     forward_function = &py_forward_function;
3236   }
3237 
3238   PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));
3239 
3240   if (!forward_pass_name_scope) forward_pass_name_scope = Py_None;
3241 
3242   TapeSetRecordOperation(
3243       op_name, inputs, results, input_ids, input_dtypes,
3244       [op_name, attrs, num_inputs, op_inputs, op_outputs,
3245        forward_pass_name_scope]() {
3246         Py_INCREF(op_name);
3247         Py_INCREF(attrs);
3248         Py_INCREF(num_inputs);
3249         Py_INCREF(op_inputs);
3250         Py_INCREF(op_outputs);
3251         Py_INCREF(forward_pass_name_scope);
3252         PyBackwardFunction* function = new PyBackwardFunction(
3253             [op_name, attrs, num_inputs, op_inputs, op_outputs,
3254              forward_pass_name_scope](
3255                 PyObject* output_grads,
3256                 const std::vector<tensorflow::int64>& unneeded_gradients) {
3257               if (PyErr_Occurred()) {
3258                 return static_cast<PyObject*>(nullptr);
3259               }
3260               tensorflow::Safe_PyObjectPtr skip_input_indices;
3261               if (!unneeded_gradients.empty()) {
3262                 skip_input_indices.reset(
3263                     PyTuple_New(unneeded_gradients.size()));
3264                 for (int i = 0; i < unneeded_gradients.size(); i++) {
3265                   PyTuple_SET_ITEM(
3266                       skip_input_indices.get(), i,
3267                       GetPythonObjectFromInt(unneeded_gradients[i]));
3268                 }
3269               } else {
3270                 Py_INCREF(Py_None);
3271                 skip_input_indices.reset(Py_None);
3272               }
3273               tensorflow::Safe_PyObjectPtr callback_args(Py_BuildValue(
3274                   "OOOOOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs,
3275                   output_grads, skip_input_indices.get(),
3276                   forward_pass_name_scope));
3277 
3278               tensorflow::Safe_PyObjectPtr result(
3279                   PyObject_CallObject(gradient_function, callback_args.get()));
3280 
3281               if (PyErr_Occurred()) return static_cast<PyObject*>(nullptr);
3282 
3283               return tensorflow::swig::Flatten(result.get());
3284             });
3285         return function;
3286       },
3287       [op_name, attrs, num_inputs, op_inputs, op_outputs,
3288        forward_pass_name_scope](PyBackwardFunction* backward_function) {
3289         Py_DECREF(op_name);
3290         Py_DECREF(attrs);
3291         Py_DECREF(num_inputs);
3292         Py_DECREF(op_inputs);
3293         Py_DECREF(op_outputs);
3294         Py_DECREF(forward_pass_name_scope);
3295 
3296         delete backward_function;
3297       },
3298       forward_function);
3299 
3300   Py_DECREF(num_inputs);
3301   if (op_outputs_tuple_created) Py_DECREF(op_outputs);
3302   if (op_inputs_tuple_created) Py_DECREF(op_inputs);
3303 
3304   if (PyErr_Occurred()) {
3305     return nullptr;
3306   }
3307 
3308   Py_RETURN_NONE;
3309 }
3310 
3311 void MaybeNotifyVariableAccessed(PyObject* input) {
3312   DCHECK(CheckResourceVariable(input));
3313   DCHECK(PyObject_HasAttrString(input, "_trainable"));
3314 
3315   tensorflow::Safe_PyObjectPtr trainable(
3316       PyObject_GetAttrString(input, "_trainable"));
3317   if (trainable.get() == Py_False) return;
3318   TFE_Py_TapeVariableAccessed(input);
3319   TFE_Py_VariableWatcherVariableAccessed(input);
3320 }
3321 
3322 bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
3323                     PyObject* input, tensorflow::Safe_PyObjectPtr* output,
3324                     TF_Status* status) {
3325   MaybeNotifyVariableAccessed(input);
3326 
3327   TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status);
3328   auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); });
3329   if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
3330 
3331   TFE_OpSetDevice(op, parent_op_exec_info.device_name, status);
3332   if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
3333 
3334   // Set dtype
3335   DCHECK(PyObject_HasAttrString(input, "_dtype"));
3336   tensorflow::Safe_PyObjectPtr dtype(PyObject_GetAttrString(input, "_dtype"));
3337   int value;
3338   if (!ParseTypeValue("_dtype", dtype.get(), status, &value)) {
3339     return false;
3340   }
3341   TFE_OpSetAttrType(op, "dtype", static_cast<TF_DataType>(value));
3342 
3343   // Get handle
3344   tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(input, "_handle"));
3345   if (!EagerTensor_CheckExact(handle.get())) return false;
3346   TFE_OpAddInput(op, EagerTensor_Handle(handle.get()), status);
3347   if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
3348 
3349   int num_retvals = 1;
3350   TFE_TensorHandle* output_handle;
3351   TFE_Execute(op, &output_handle, &num_retvals, status);
3352   if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
3353 
3354   // Always create the py object (and correctly DECREF it) from the returned
3355   // value, else the data will leak.
3356   output->reset(EagerTensorFromHandle(output_handle));
3357 
3358   // TODO(nareshmodi): Should we run post exec callbacks here?
3359   if (parent_op_exec_info.run_gradient_callback) {
3360     tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(1));
3361     PyTuple_SET_ITEM(inputs.get(), 0, handle.release());
3362 
3363     tensorflow::Safe_PyObjectPtr outputs(PyTuple_New(1));
3364     Py_INCREF(output->get());  // stay alive after since tuple steals.
3365     PyTuple_SET_ITEM(outputs.get(), 0, output->get());
3366 
3367     tensorflow::Safe_PyObjectPtr op_string(
3368         GetPythonObjectFromString("ReadVariableOp"));
3369     if (!RecordGradient(op_string.get(), inputs.get(), Py_None,
3370                         outputs.get())) {
3371       return false;
3372     }
3373   }
3374 
3375   return true;
3376 }
3377 
3378 // Supports 3 cases at the moment:
3379 //  i) input is an EagerTensor.
3380 //  ii) input is a ResourceVariable - in this case the variable's value is
3381 //  read into an EagerTensor via ReadVariableOp.
3382 //  iii) input is an arbitrary python list/tuple (note, this handling doesn't
3383 //  support packing).
3384 //
3385 //  NOTE: dtype_hint_getter returns the dtype to use as a conversion hint,
3386 //  or DT_INVALID when no hint is available; dtype_setter is invoked with
3387 //  the dtype chosen once the conversion succeeds.
3388 //
3389 //  NOTE: This function sets a python error directly, and returns false.
3390 //  TF_Status is only passed since we don't want to have to reallocate it.
3391 bool ConvertToTensor(
3392     const FastPathOpExecInfo& op_exec_info, PyObject* input,
3393     tensorflow::Safe_PyObjectPtr* output_handle,
3394     // This gets a hint for this particular input.
3395     const std::function<tensorflow::DataType()>& dtype_hint_getter,
3396     // This sets the dtype after conversion is complete.
3397     const std::function<void(const tensorflow::DataType dtype)>& dtype_setter,
3398     TF_Status* status) {
3399   if (EagerTensor_CheckExact(input)) {
3400     Py_INCREF(input);
3401     output_handle->reset(input);
3402     return true;
3403   } else if (CheckResourceVariable(input)) {
3404     return ReadVariableOp(op_exec_info, input, output_handle, status);
3405   }
3406 
3407   // The hint comes from a supposedly similarly typed tensor.
3408   tensorflow::DataType dtype_hint = dtype_hint_getter();
3409 
3410   TFE_TensorHandle* handle = tensorflow::ConvertToEagerTensor(
3411       op_exec_info.ctx, input, dtype_hint, op_exec_info.device_name);
3412   if (handle == nullptr) {
3413     return MaybeRaiseExceptionFromTFStatus(status, nullptr);
3414   }
3415 
3416   output_handle->reset(EagerTensorFromHandle(handle));
3417   dtype_setter(
3418       static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(handle)));
3419 
3420   return true;
3421 }
3422 
3423 // Adds input and type attr to the op, and to the list of flattened
3424 // inputs/attrs.
3425 bool AddInputToOp(FastPathOpExecInfo* op_exec_info, PyObject* input,
3426                   const bool add_type_attr,
3427                   const tensorflow::OpDef::ArgDef& input_arg,
3428                   std::vector<tensorflow::Safe_PyObjectPtr>* flattened_attrs,
3429                   std::vector<tensorflow::Safe_PyObjectPtr>* flattened_inputs,
3430                   TFE_Op* op, TF_Status* status) {
3431   // py_eager_tensor's ownership is transferred to flattened_inputs if it is
3432   // required; otherwise it is DECREF'd when it goes out of scope at the end
3433   // of this function.
3434   tensorflow::Safe_PyObjectPtr py_eager_tensor = nullptr;
3435 
3436   if (!ConvertToTensor(
3437           *op_exec_info, input, &py_eager_tensor,
3438           [&]() {
3439             if (input_arg.type() != tensorflow::DataType::DT_INVALID) {
3440               return input_arg.type();
3441             }
3442             return MaybeGetDTypeForAttr(input_arg.type_attr(), op_exec_info);
3443           },
3444           [&](const tensorflow::DataType dtype) {
3445             op_exec_info->cached_dtypes[input_arg.type_attr()] = dtype;
3446           },
3447           status)) {
3448     return false;
3449   }
3450 
3451   TFE_TensorHandle* input_handle = EagerTensor_Handle(py_eager_tensor.get());
3452 
3453   if (add_type_attr && !input_arg.type_attr().empty()) {
3454     auto dtype = TFE_TensorHandleDataType(input_handle);
3455     TFE_OpSetAttrType(op, input_arg.type_attr().data(), dtype);
3456     if (flattened_attrs != nullptr) {
3457       flattened_attrs->emplace_back(
3458           GetPythonObjectFromString(input_arg.type_attr()));
3459       flattened_attrs->emplace_back(PyLong_FromLong(dtype));
3460     }
3461   }
3462 
3463   if (flattened_inputs != nullptr) {
3464     flattened_inputs->emplace_back(std::move(py_eager_tensor));
3465   }
3466 
3467   TFE_OpAddInput(op, input_handle, status);
3468   if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3469     return false;
3470   }
3471 
3472   return true;
3473 }
3474 
3475 const char* GetDeviceName(PyObject* py_device_name) {
3476   if (py_device_name != Py_None) {
3477     return TFE_GetPythonString(py_device_name);
3478   }
3479   return nullptr;
3480 }
3481 
3482 bool RaiseIfNotPySequence(PyObject* seq, const string& attr_name) {
3483   if (!PySequence_Check(seq)) {
3484     PyErr_SetString(PyExc_TypeError,
3485                     Printf("expected a sequence for attr %s, got %s instead",
3486                            attr_name.data(), seq->ob_type->tp_name)
3487                         .data());
3488 
3489     return false;
3490   }
3491   if (PyArray_Check(seq) &&
3492       PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)) != 1) {
3493     PyErr_SetString(PyExc_ValueError,
3494                     Printf("expected a sequence for attr %s, got an ndarray "
3495                            "with rank %d instead",
3496                            attr_name.data(),
3497                            PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)))
3498                         .data());
3499     return false;
3500   }
3501   return true;
3502 }
3503 
3504 bool RunCallbacks(
3505     const FastPathOpExecInfo& op_exec_info, PyObject* args,
3506     int num_inferred_attrs,
3507     const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_inputs,
3508     const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_attrs,
3509     PyObject* flattened_result) {
3510   DCHECK(op_exec_info.run_callbacks);
3511 
3512   tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(flattened_inputs.size()));
3513   for (int i = 0; i < flattened_inputs.size(); i++) {
3514     PyObject* input = flattened_inputs[i].get();
3515     Py_INCREF(input);
3516     PyTuple_SET_ITEM(inputs.get(), i, input);
3517   }
3518 
3519   int num_non_inferred_attrs = PyTuple_GET_SIZE(args) - num_inferred_attrs;
3520   int num_attrs = flattened_attrs.size() + num_non_inferred_attrs;
3521   tensorflow::Safe_PyObjectPtr attrs(PyTuple_New(num_attrs));
3522 
3523   for (int i = 0; i < num_non_inferred_attrs; i++) {
3524     auto* attr = PyTuple_GET_ITEM(args, num_inferred_attrs + i);
3525     Py_INCREF(attr);
3526     PyTuple_SET_ITEM(attrs.get(), i, attr);
3527   }
3528 
3529   for (int i = num_non_inferred_attrs; i < num_attrs; i++) {
3530     PyObject* attr_or_name =
3531         flattened_attrs.at(i - num_non_inferred_attrs).get();
3532     Py_INCREF(attr_or_name);
3533     PyTuple_SET_ITEM(attrs.get(), i, attr_or_name);
3534   }
3535 
3536   if (op_exec_info.run_gradient_callback) {
3537     if (!RecordGradient(op_exec_info.op_name, inputs.get(), attrs.get(),
3538                         flattened_result)) {
3539       return false;
3540     }
3541   }
3542 
3543   if (op_exec_info.run_post_exec_callbacks) {
3544     tensorflow::Safe_PyObjectPtr callback_args(
3545         Py_BuildValue("OOOOO", op_exec_info.op_name, inputs.get(), attrs.get(),
3546                       flattened_result, op_exec_info.name));
3547     for (Py_ssize_t i = 0; i < PyList_Size(op_exec_info.callbacks); i++) {
3548       PyObject* callback_fn = PyList_GET_ITEM(op_exec_info.callbacks, i);
3549       if (!PyCallable_Check(callback_fn)) {
3550         PyErr_SetString(
3551             PyExc_TypeError,
3552             Printf("expected a function for "
3553                    "post execution callback in index %ld, got %s instead",
3554                    i, callback_fn->ob_type->tp_name)
3555                 .c_str());
3556         return false;
3557       }
3558       PyObject* callback_result =
3559           PyObject_CallObject(callback_fn, callback_args.get());
3560       if (!callback_result) {
3561         return false;
3562       }
3563       Py_DECREF(callback_result);
3564     }
3565   }
3566 
3567   return true;
3568 }
3569 
3570 }  // namespace
3571 
3572 PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
3573   tensorflow::profiler::TraceMe activity(
3574       "TFE_Py_FastPathExecute_C", tensorflow::profiler::TraceMeLevel::kInfo);
3575   Py_ssize_t args_size = PyTuple_GET_SIZE(args);
3576   if (args_size < FAST_PATH_EXECUTE_ARG_INPUT_START) {
3577     PyErr_SetString(
3578         PyExc_ValueError,
3579         Printf("There must be at least %d items in the input tuple.",
3580                FAST_PATH_EXECUTE_ARG_INPUT_START)
3581             .c_str());
3582     return nullptr;
3583   }
3584 
3585   FastPathOpExecInfo op_exec_info;
3586 
3587   PyObject* py_eager_context =
3588       PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_CONTEXT);
3589 
3590   // TODO(edoper): Use interned string here
3591   PyObject* eager_context_handle =
3592       PyObject_GetAttrString(py_eager_context, "_context_handle");
3593 
3594   TFE_Context* ctx = reinterpret_cast<TFE_Context*>(
3595       PyCapsule_GetPointer(eager_context_handle, nullptr));
3596   op_exec_info.ctx = ctx;
3597   op_exec_info.args = args;
3598 
3599   if (ctx == nullptr) {
3600     // The context hasn't been initialized; it will be initialized in the slow path.
3601     RaiseFallbackException(
3602         "This function does not handle the case where not all inputs are "
3603         "already EagerTensors.");
3604     return nullptr;
3605   }
3606 
3607   auto* tld = tensorflow::GetEagerContextThreadLocalData(py_eager_context);
3608   if (tld == nullptr) {
3609     return nullptr;
3610   }
3611   op_exec_info.device_name = GetDeviceName(tld->device_name.get());
3612   op_exec_info.callbacks = tld->op_callbacks.get();
3613 
3614   op_exec_info.op_name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_OP_NAME);
3615   op_exec_info.name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_NAME);
3616 
3617   // TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks
3618   // (similar to benchmark_tf_gradient_function_*). Also consider using an
3619   // InlinedVector for flattened_attrs and flattened_inputs if the benchmarks
3620   // point out problems with heap allocs.
3621   op_exec_info.run_gradient_callback =
3622       !*ThreadTapeIsStopped() && HasAccumulatorOrTape();
3623   op_exec_info.run_post_exec_callbacks =
3624       op_exec_info.callbacks != Py_None &&
3625       PyList_Size(op_exec_info.callbacks) > 0;
3626   op_exec_info.run_callbacks = op_exec_info.run_gradient_callback ||
3627                                op_exec_info.run_post_exec_callbacks;
3628 
3629   TF_Status* status = GetStatus();
3630   const char* op_name = TFE_GetPythonString(op_exec_info.op_name);
3631   if (op_name == nullptr) {
3632     PyErr_SetString(PyExc_TypeError,
3633                     Printf("expected a string for op_name, got %s instead",
3634                            op_exec_info.op_name->ob_type->tp_name)
3635                         .c_str());
3636     return nullptr;
3637   }
3638 
3639   TFE_Op* op = GetOp(ctx, op_name, op_exec_info.device_name, status);
3640 
3641   auto cleaner = tensorflow::gtl::MakeCleanup([status, ctx, op] {
3642     ReturnStatus(status);
3643     ReturnOp(ctx, op);
3644   });
3645 
3646   if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3647     return nullptr;
3648   }
3649 
3650   tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace(
3651       tensorflow::StackTrace::kStackTraceInitialSize));
3652 
3653   const tensorflow::OpDef* op_def = tensorflow::unwrap(op)->OpDef();
3654   if (op_def == nullptr) return nullptr;
3655 
3656   if (args_size <
3657       FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size()) {
3658     PyErr_SetString(
3659         PyExc_ValueError,
3660         Printf("Tuple size smaller than intended. Expected to be at least %d, "
3661                "was %ld",
3662                FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
3663                args_size)
3664             .c_str());
3665     return nullptr;
3666   }
3667 
3668   if (!CheckInputsOk(args, FAST_PATH_EXECUTE_ARG_INPUT_START, *op_def)) {
3669     RaiseFallbackException(
3670         "This function does not handle the case where not all inputs are "
3671         "already EagerTensors.");
3672     return nullptr;
3673   }
3674 
3675   op_exec_info.attr_to_inputs_map = GetAttrToInputsMapHoldingGIL(*op_def);
3676   op_exec_info.default_dtypes = GetAttrToDefaultsMapHoldingGIL(*op_def);
3677 
3678   // Mapping of attr name to size - used to calculate the number of values
3679   // to be expected by the TFE_Execute run.
3680   tensorflow::gtl::FlatMap<string, tensorflow::int64> attr_list_sizes;
3681 
3682   // Set non-inferred attrs, including setting defaults if the attr is passed in
3683   // as None.
3684   for (int i = FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size();
3685        i < args_size; i += 2) {
3686     PyObject* py_attr_name = PyTuple_GET_ITEM(args, i);
3687     const char* attr_name = TFE_GetPythonString(py_attr_name);
3688     PyObject* py_attr_value = PyTuple_GET_ITEM(args, i + 1);
3689 
3690     // Not creating an index since most of the time there are not more than a
3691     // few attrs.
3692     // TODO(nareshmodi): Maybe include the index as part of the
3693     // OpRegistrationData.
3694     for (const auto& attr : op_def->attr()) {
3695       if (tensorflow::StringPiece(attr_name) == attr.name()) {
3696         SetOpAttrWithDefaults(ctx, op, attr, attr_name, py_attr_value,
3697                               &attr_list_sizes, status);
3698 
3699         if (!status->status.ok()) {
3700           VLOG(1) << "Falling back to slow path for Op \"" << op_def->name()
3701                   << "\" since we are unable to set the value for attr \""
3702                   << attr.name() << "\" due to: " << TF_Message(status);
3703           RaiseFallbackException(TF_Message(status));
3704           return nullptr;
3705         }
3706 
3707         break;
3708       }
3709     }
3710   }
3711 
3712   // Flat attrs and inputs as required by the record_gradient call. The attrs
3713   // here only contain inferred attrs (non-inferred attrs are added directly
3714   // from the input args).
3715   // All items in flattened_attrs and flattened_inputs contain
3716   // Safe_PyObjectPtr - any time something steals a reference to this, it must
3717   // INCREF.
3718   // TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work
3719   // directly.
3720   std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_attrs =
3721       nullptr;
3722   std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_inputs =
3723       nullptr;
3724 
3725   // TODO(nareshmodi): Encapsulate callbacks information into a struct.
3726   if (op_exec_info.run_callbacks) {
3727     flattened_attrs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
3728     flattened_inputs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
3729   }
3730 
3731   // Add inferred attrs and inputs.
3732   // The following code might set duplicate type attrs. This will result in
3733   // the CacheKey for the generated AttrBuilder possibly differing from
3734   // those where the type attrs are correctly set. Inconsistent CacheKeys
3735   // for ops means that there might be unnecessarily duplicated kernels.
3736   // TODO(nareshmodi): Fix this.
3737   for (int i = 0; i < op_def->input_arg_size(); i++) {
3738     const auto& input_arg = op_def->input_arg(i);
3739 
3740     PyObject* input =
3741         PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_INPUT_START + i);
3742     if (!input_arg.number_attr().empty()) {
3743       // The item is a homogeneous list.
3744       if (!RaiseIfNotPySequence(input, input_arg.number_attr())) return nullptr;
3745       tensorflow::Safe_PyObjectPtr fast_input(
3746           PySequence_Fast(input, "Could not parse sequence."));
3747       if (fast_input.get() == nullptr) {
3748         return nullptr;
3749       }
3750       Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
3751       PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get());
3752 
3753       TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len);
3754       if (op_exec_info.run_callbacks) {
3755         flattened_attrs->emplace_back(
3756             GetPythonObjectFromString(input_arg.number_attr()));
3757         flattened_attrs->emplace_back(PyLong_FromLong(len));
3758       }
3759       attr_list_sizes[input_arg.number_attr()] = len;
3760 
3761       if (len > 0) {
3762         // First item adds the type attr.
3763         if (!AddInputToOp(&op_exec_info, fast_input_array[0], true, input_arg,
3764                           flattened_attrs.get(), flattened_inputs.get(), op,
3765                           status)) {
3766           return nullptr;
3767         }
3768 
3769         for (Py_ssize_t j = 1; j < len; j++) {
3770           // Since the list is homogeneous, we don't need to re-add the attr.
3771           if (!AddInputToOp(&op_exec_info, fast_input_array[j], false,
3772                             input_arg, nullptr /* flattened_attrs */,
3773                             flattened_inputs.get(), op, status)) {
3774             return nullptr;
3775           }
3776         }
3777       }
3778     } else if (!input_arg.type_list_attr().empty()) {
3779       // The item is a heterogeneous list.
3780       if (!RaiseIfNotPySequence(input, input_arg.type_list_attr())) {
3781         return nullptr;
3782       }
3783       tensorflow::Safe_PyObjectPtr fast_input(
3784           PySequence_Fast(input, "Could not parse sequence."));
3785       if (fast_input.get() == nullptr) {
3786         return nullptr;
3787       }
3788       const string& attr_name = input_arg.type_list_attr();
3789       Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
3790       PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get());
3791       tensorflow::gtl::InlinedVector<TF_DataType, 4> attr_value(len);
3792       PyObject* py_attr_value = nullptr;
3793       if (op_exec_info.run_callbacks) {
3794         py_attr_value = PyTuple_New(len);
3795       }
3796       for (Py_ssize_t j = 0; j < len; j++) {
3797         PyObject* py_input = fast_input_array[j];
3798         tensorflow::Safe_PyObjectPtr py_eager_tensor;
3799         if (!ConvertToTensor(
3800                 op_exec_info, py_input, &py_eager_tensor,
3801                 []() { return tensorflow::DT_INVALID; },
3802                 [](const tensorflow::DataType dtype) {}, status)) {
3803           return nullptr;
3804         }
3805 
3806         TFE_TensorHandle* input_handle =
3807             EagerTensor_Handle(py_eager_tensor.get());
3808 
3809         attr_value[j] = TFE_TensorHandleDataType(input_handle);
3810 
3811         TFE_OpAddInput(op, input_handle, status);
3812         if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3813           return nullptr;
3814         }
3815 
3816         if (op_exec_info.run_callbacks) {
3817           flattened_inputs->emplace_back(std::move(py_eager_tensor));
3818 
3819           PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j]));
3820         }
3821       }
3822       if (op_exec_info.run_callbacks) {
3823         flattened_attrs->emplace_back(GetPythonObjectFromString(attr_name));
3824         flattened_attrs->emplace_back(py_attr_value);
3825       }
3826       TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(),
3827                             attr_value.size());
3828       attr_list_sizes[attr_name] = len;
3829     } else {
3830       // The item is a single item.
3831       if (!AddInputToOp(&op_exec_info, input, true, input_arg,
3832                         flattened_attrs.get(), flattened_inputs.get(), op,
3833                         status)) {
3834         return nullptr;
3835       }
3836     }
3837   }
3838 
3839   int64_t num_outputs = 0;
3840   for (int i = 0; i < op_def->output_arg_size(); i++) {
3841     const auto& output_arg = op_def->output_arg(i);
3842     int64_t delta = 1;
3843     if (!output_arg.number_attr().empty()) {
3844       delta = attr_list_sizes[output_arg.number_attr()];
3845     } else if (!output_arg.type_list_attr().empty()) {
3846       delta = attr_list_sizes[output_arg.type_list_attr()];
3847     }
3848     if (delta < 0) {
3849       RaiseFallbackException(
3850           "Attributes suggest that the size of an output list is less than 0");
3851       return nullptr;
3852     }
3853     num_outputs += delta;
3854   }
3855 
3856   // If number of retvals is larger than int32, we error out.
3857   if (static_cast<int64_t>(static_cast<int32_t>(num_outputs)) != num_outputs) {
3858     PyErr_SetString(
3859         PyExc_ValueError,
3860         Printf("Number of outputs is too big: %ld", num_outputs).c_str());
3861     return nullptr;
3862   }
3863   int num_retvals = num_outputs;
3864 
3865   tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);
3866 
3867   Py_BEGIN_ALLOW_THREADS;
3868   TFE_Execute(op, retvals.data(), &num_retvals, status);
3869   Py_END_ALLOW_THREADS;
3870 
3871   if (!status->status.ok()) {
3872     // Augment the status with the op_name for easier debugging, similar to
3873     // TFE_Py_Execute.
3874     std::vector<tensorflow::StackFrame> stack_trace =
3875         status->status.stack_trace();
3876     status->status = tensorflow::Status(
3877         status->status.code(),
3878         tensorflow::strings::StrCat(
3879             TF_Message(status),
3880             " [Op:", TFE_GetPythonString(op_exec_info.op_name), "]"),
3881         std::move(stack_trace));
3882 
3883     MaybeRaiseExceptionFromTFStatus(status, nullptr);
3884     return nullptr;
3885   }
3886 
3887   tensorflow::Safe_PyObjectPtr flat_result(PyList_New(num_retvals));
3888   for (int i = 0; i < num_retvals; ++i) {
3889     PyList_SET_ITEM(flat_result.get(), i, EagerTensorFromHandle(retvals[i]));
3890   }
3891 
3892   if (op_exec_info.run_callbacks) {
3893     if (!RunCallbacks(
3894             op_exec_info, args,
3895             FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
3896             *flattened_inputs, *flattened_attrs, flat_result.get())) {
3897       return nullptr;
3898     }
3899   }
3900 
3901   // Unflatten results.
3902   if (op_def->output_arg_size() == 0) {
3903     Py_RETURN_NONE;
3904   }
3905 
3906   if (op_def->output_arg_size() == 1) {
3907     if (!op_def->output_arg(0).number_attr().empty() ||
3908         !op_def->output_arg(0).type_list_attr().empty()) {
3909       return flat_result.release();
3910     } else {
3911       auto* result = PyList_GET_ITEM(flat_result.get(), 0);
3912       Py_INCREF(result);
3913       return result;
3914     }
3915   }
3916 
3917   // Unflatten the results into one entry per output arg (the Python caller may wrap them in a namedtuple).
3918   PyObject* result = PyList_New(op_def->output_arg_size());
3919   int flat_result_index = 0;
3920   for (int i = 0; i < op_def->output_arg_size(); i++) {
3921     if (!op_def->output_arg(i).number_attr().empty()) {
3922       int list_length = attr_list_sizes[op_def->output_arg(i).number_attr()];
3923       PyObject* inner_list = PyList_New(list_length);
3924       for (int j = 0; j < list_length; j++) {
3925         PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
3926         Py_INCREF(obj);
3927         PyList_SET_ITEM(inner_list, j, obj);
3928       }
3929       PyList_SET_ITEM(result, i, inner_list);
3930     } else if (!op_def->output_arg(i).type_list_attr().empty()) {
3931       int list_length = attr_list_sizes[op_def->output_arg(i).type_list_attr()];
3932       PyObject* inner_list = PyList_New(list_length);
3933       for (int j = 0; j < list_length; j++) {
3934         PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
3935         Py_INCREF(obj);
3936         PyList_SET_ITEM(inner_list, j, obj);
3937       }
3938       PyList_SET_ITEM(result, i, inner_list);
3939     } else {
3940       PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
3941       Py_INCREF(obj);
3942       PyList_SET_ITEM(result, i, obj);
3943     }
3944   }
3945   return result;
3946 }
3947 
3948 PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
3949                                 PyObject* attrs, PyObject* results,
3950                                 PyObject* forward_pass_name_scope) {
3951   if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
3952     Py_RETURN_NONE;
3953   }
3954 
3955   return RecordGradient(op_name, inputs, attrs, results,
3956                         forward_pass_name_scope);
3957 }
3958 
3959 namespace {
3960 const char kTensor[] = "T";
3961 const char kList[] = "L";
3962 const char kListEnd[] = "l";
3963 const char kTuple[] = "U";
3964 const char kTupleEnd[] = "u";
3965 const char kDIter[] = "I";
3966 const char kDict[] = "D";
3967 const char kRaw[] = "R";
3968 const char kShape[] = "s";
3969 const char kShapeDelim[] = "-";
3970 const char kDType[] = "d";
3971 const char kNone[] = "n";
3972 const char kCompositeTensor[] = "C";
3973 const char kAttrs[] = "A";
3974 const char kAttrsEnd[] = "a";
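// As an illustration of the encoding language above (assuming DT_FLOAT == 1
// in types.pb.h): a tuple holding one float32 Tensor of shape [2, 3] encodes
// as "UTd1s2-3-u" - kTuple, kTensor, kDType + dtype enum, kShape followed by
// one "<dim><kShapeDelim>" group per dimension, then kTupleEnd.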
3975 
3976 struct EncodeResult {
3977   string str;
3978   std::vector<PyObject*> objects;
3979 
3980   PyObject* ToPyTuple() {
3981     PyObject* result = PyTuple_New(2);
3982 
3983     PyTuple_SET_ITEM(result, 0, GetPythonObjectFromString(str));
3984 
3985     if (objects.empty()) {
3986       Py_INCREF(Py_None);
3987       PyTuple_SET_ITEM(result, 1, Py_None);
3988     } else {
3989       PyObject* objects_tuple = PyTuple_New(objects.size());
3990 
3991       for (int i = 0; i < objects.size(); i++) {
3992         PyTuple_SET_ITEM(objects_tuple, i, objects[i]);
3993       }
3994 
3995       PyTuple_SET_ITEM(result, 1, objects_tuple);
3996     }
3997 
3998     return result;
3999   }
4000 };
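// Note: ToPyTuple transfers ownership of the collected `objects` into the
// returned tuple (PyTuple_SET_ITEM steals each reference), so it must be
// called at most once per EncodeResult.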
4001 
4002 tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg,
4003                                        bool include_tensor_ranks_only,
4004                                        EncodeResult* result) {
4005   if (EagerTensor_CheckExact(arg)) {
4006     tensorflow::ImmediateExecutionTensorHandle* handle =
4007         tensorflow::unwrap(EagerTensor_Handle(arg));
4008 
4009     absl::StrAppend(&result->str, kDType,
4010                     static_cast<tensorflow::DataType>(handle->DataType()));
4011     absl::StrAppend(&result->str, kShape);
4012 
4013     int num_dims;
4014     tensorflow::Status status = handle->NumDims(&num_dims);
4015     if (!status.ok()) return status;
4016 
4017     if (include_tensor_ranks_only) {
4018       absl::StrAppend(&result->str, num_dims);
4019     } else {
4020       for (int i = 0; i < num_dims; ++i) {
4021         int64_t dim_size;
4022         status = handle->Dim(i, &dim_size);
4023         if (!status.ok()) return status;
4024         absl::StrAppend(&result->str, dim_size, kShapeDelim);
4025       }
4026     }
4027     return tensorflow::Status::OK();
4028   }
4029 
4030   tensorflow::Safe_PyObjectPtr dtype_object(
4031       PyObject_GetAttrString(arg, "dtype"));
4032 
4033   if (dtype_object == nullptr) {
4034     return tensorflow::errors::InvalidArgument(
4035         "ops.Tensor object doesn't have a dtype attr.");
4036   }
4037 
4038   tensorflow::Safe_PyObjectPtr dtype_enum(
4039       PyObject_GetAttrString(dtype_object.get(), "_type_enum"));
4040 
4041   if (dtype_enum == nullptr) {
4042     return tensorflow::errors::InvalidArgument(
4043         "ops.Tensor's dtype object doesn't have a _type_enum attr.");
4044   }
4045 
4046   tensorflow::DataType dtype =
4047       static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get()));
4048 
4049   absl::StrAppend(&result->str, kDType, dtype);
4050 
4051   static char _shape_tuple[] = "_shape_tuple";
4052   tensorflow::Safe_PyObjectPtr shape_tuple(
4053       PyObject_CallMethod(arg, _shape_tuple, nullptr));
4054 
4055   if (shape_tuple == nullptr) {
4056     return tensorflow::errors::InvalidArgument(
4057         "ops.Tensor object doesn't have _shape_tuple() method.");
4058   }
4059 
4060   if (shape_tuple.get() == Py_None) {
4061     // Unknown shape, encode that directly.
4062     absl::StrAppend(&result->str, kNone);
4063     return tensorflow::Status::OK();
4064   }
4065 
4066   absl::StrAppend(&result->str, kShape);
4067   tensorflow::Safe_PyObjectPtr shape_seq(PySequence_Fast(
4068       shape_tuple.get(), "shape_tuple didn't return a sequence"));
4069 
4070   int len = PySequence_Fast_GET_SIZE(shape_seq.get());
4071   PyObject** shape_seq_array = PySequence_Fast_ITEMS(shape_seq.get());
4072 
4073   if (include_tensor_ranks_only) {
4074     absl::StrAppend(&result->str, len);
4075   } else {
4076     for (int i = 0; i < len; ++i) {
4077       PyObject* item = shape_seq_array[i];
4078       if (item == Py_None) {
4079         absl::StrAppend(&result->str, kNone);
4080       } else {
4081         absl::StrAppend(&result->str, MakeInt(item));
4082       }
4083     }
4084   }
4085   return tensorflow::Status::OK();
4086 }
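// Illustrative encodings (assuming DT_FLOAT == 1) for a float32 tensor of
// shape [2, 3]:
//   include_tensor_ranks_only == false  ->  "d1s2-3-"  (dtype, each dim)
//   include_tensor_ranks_only == true   ->  "d1s2"     (dtype, rank only)
// An ops.Tensor whose shape is fully unknown encodes the shape as kNone,
// yielding "d1n".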
4087 
4088 tensorflow::Status TFE_Py_EncodeArgHelperInternal(
4089     PyObject* arg, bool include_tensor_ranks_only, std::vector<int>& res_vec,
4090     absl::flat_hash_map<int, int>& res_map, int& cur_res, EncodeResult* result);
4091 
4092 // Encodes a sequence; the caller supplies the enclosing markers via 'type'
4092 // and 'end_type' (e.g. kList/kListEnd), which are appended around the items.
4093 tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type,
4094                                          const char* end_type,
4095                                          bool include_tensor_ranks_only,
4096                                          std::vector<int>& res_vec,
4097                                          absl::flat_hash_map<int, int>& res_map,
4098                                          int& cur_res, EncodeResult* result) {
4099   tensorflow::Safe_PyObjectPtr arg_seq(
4100       PySequence_Fast(arg, "unable to create seq from list/tuple"));
4101 
4102   absl::StrAppend(&result->str, type);
4103   int len = PySequence_Fast_GET_SIZE(arg_seq.get());
4104   PyObject** arg_seq_array = PySequence_Fast_ITEMS(arg_seq.get());
4105   for (int i = 0; i < len; ++i) {
4106     PyObject* item = arg_seq_array[i];
4107     if (item == Py_None) {
4108       absl::StrAppend(&result->str, kNone);
4109     } else {
4110       TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelperInternal(
4111           item, include_tensor_ranks_only, res_vec, res_map, cur_res, result));
4112     }
4113   }
4114   absl::StrAppend(&result->str, end_type);
4115 
4116   return tensorflow::Status::OK();
4117 }
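// Hedged example: encoding the list [x, None] appends
//   kList, <encoding of x>, kNone, kListEnd
// i.e. "L...nl"; tuples go through the same helper with "U...u".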
4118 
4119 void UpdateResourceCount(int res_id, std::vector<int>& res_vec,
4120                          absl::flat_hash_map<int, int>& res_map, int& cur_res) {
4121   const auto& it = res_map.find(res_id);
4122   if (it == res_map.end()) {
4123     res_map[res_id] = cur_res;
4124     res_vec.push_back(cur_res);
4125     ++cur_res;
4126   } else {
4127     res_vec.push_back(it->second);
4128   }
4129 }
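// A minimal usage sketch with hypothetical resource ids: the first occurrence
// of an id is assigned the next dense index; repeats reuse that index, so
// aliasing between arguments is captured in res_vec.
//
//   std::vector<int> res_vec;
//   absl::flat_hash_map<int, int> res_map;
//   int cur_res = 0;
//   UpdateResourceCount(42, res_vec, res_map, cur_res);  // res_vec: {0}
//   UpdateResourceCount(42, res_vec, res_map, cur_res);  // res_vec: {0, 0}
//   UpdateResourceCount(7, res_vec, res_map, cur_res);   // res_vec: {0, 0, 1}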
4130 
4131 tensorflow::Status TFE_Py_EncodeArgHelperInternal(
4132     PyObject* arg, bool include_tensor_ranks_only, std::vector<int>& res_vec,
4133     absl::flat_hash_map<int, int>& res_map, int& cur_res,
4134     EncodeResult* result) {
4135   if (tensorflow::swig::IsTensor(arg)) {
4136     absl::StrAppend(&result->str, kTensor);
4137     TF_RETURN_IF_ERROR(
4138         TFE_Py_EncodeTensor(arg, include_tensor_ranks_only, result));
4139   } else if (tensorflow::swig::IsOwnedIterator(arg)) {
4140     // TODO(jiaweix): distinguish other resource types
4141     // Similar to IsCompositeTensor below, plus resource id
4142     PyObject* type_spec(PyObject_GetAttrString(arg, "_type_spec"));
4143     if (type_spec == nullptr) {
4144       return tensorflow::errors::InvalidArgument(
4145           "Error while reading OwnedIterator._type_spec.");
4146     }
4147     result->objects.push_back(type_spec);
4148 
4149     // Add resource tracking
4150     tensorflow::Safe_PyObjectPtr itr_res(
4151         PyObject_GetAttrString(arg, "_iterator_resource"));
4152     if (itr_res == nullptr) {
4153       return tensorflow::errors::InvalidArgument(
4154           "Error while reading Dataset iterator resource.");
4155     }
4156     // OwnedIterator does not always have a unique resource id,
4157     // because a Dataset object is not required for OwnedIterator.__init__.
4158     // As a result we check whether '_iterator_resource' is a Tensor.
4159     if (tensorflow::swig::IsTensor(itr_res.get())) {
4160       absl::StrAppend(&result->str, kDIter);
4161       tensorflow::Safe_PyObjectPtr p_res_id(
4162           PyObject_GetAttrString(itr_res.get(), "_id"));
4163       if (p_res_id == nullptr) {
4164         return tensorflow::errors::InvalidArgument(
4165             "Error while reading Dataset iterator resouce id.");
4166       }
4167       Py_ssize_t res_id = PyLong_AsSsize_t(p_res_id.get());
4168       if (res_id < 0) {
4169         return tensorflow::errors::InvalidArgument("PyLong_AsSsize_t failure");
4170       }
4171       UpdateResourceCount(res_id, res_vec, res_map, cur_res);
4172     } else {
4173       // If '_iterator_resource' is not a Tensor, there is no resource id.
4174       // Instead we treat it the same way as a CompositeTensor
4175       absl::StrAppend(&result->str, kCompositeTensor);
4176     }
4177   } else if (PyList_Check(arg)) {
4178     TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(arg, kList, kListEnd,
4179                                              include_tensor_ranks_only, res_vec,
4180                                              res_map, cur_res, result));
4181   } else if (tensorflow::swig::IsTuple(arg)) {
4182     TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(arg, kTuple, kTupleEnd,
4183                                              include_tensor_ranks_only, res_vec,
4184                                              res_map, cur_res, result));
4185   } else if (tensorflow::swig::IsMapping(arg)) {
4186     tensorflow::Safe_PyObjectPtr keys(tensorflow::swig::MappingKeys(arg));
4187     if (PyList_Sort(keys.get()) == -1) {
4188       return tensorflow::errors::Internal("Unable to sort keys");
4189     }
4190 
4191     absl::StrAppend(&result->str, kDict);
4192     int len = PyList_Size(keys.get());
4193 
4194     for (int i = 0; i < len; i++) {
4195       PyObject* key = PyList_GetItem(keys.get(), i);
4196       TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelperInternal(
4197           key, include_tensor_ranks_only, res_vec, res_map, cur_res, result));
4198       tensorflow::Safe_PyObjectPtr value(PyObject_GetItem(arg, key));
4199       TF_RETURN_IF_ERROR(
4200           TFE_Py_EncodeArgHelperInternal(value.get(), include_tensor_ranks_only,
4201                                          res_vec, res_map, cur_res, result));
4202     }
4203   } else if (tensorflow::swig::IsCompositeTensor(arg)) {
4204     absl::StrAppend(&result->str, kCompositeTensor);
4205 
4206     // Add the typespec to the list of objects.  (Do *not* use a weakref,
4207     // since the type spec is often a temporary object.)
4208     PyObject* type_spec(PyObject_GetAttrString(arg, "_type_spec"));
4209     if (type_spec == nullptr) {
4210       return tensorflow::errors::InvalidArgument(
4211           "Error while reading CompositeTensor._type_spec.");
4212     }
4213     result->objects.push_back(type_spec);
4214   } else if (tensorflow::swig::IsTypeSpec(arg)) {
4215     // Add the typespec (not a weakref) in case it's a temporary object.
4216     absl::StrAppend(&result->str, kRaw);
4217     Py_INCREF(arg);
4218     result->objects.push_back(arg);
4219   } else if (tensorflow::swig::IsAttrs(arg)) {
4220     absl::StrAppend(&result->str, kAttrs);
4221     tensorflow::Safe_PyObjectPtr attrs(
4222         PyObject_GetAttrString(arg, "__attrs_attrs__"));
4223     tensorflow::Safe_PyObjectPtr iter(PyObject_GetIter(attrs.get()));
4224     for (tensorflow::Safe_PyObjectPtr item(PyIter_Next(iter.get())); item;
4225          item.reset(PyIter_Next(iter.get()))) {
4226       tensorflow::Safe_PyObjectPtr name(
4227           PyObject_GetAttrString(item.get(), "name"));
4228       tensorflow::Safe_PyObjectPtr attr_arg(PyObject_GetAttr(arg, name.get()));
4229       TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelperInternal(
4230           attr_arg.get(), include_tensor_ranks_only, res_vec, res_map, cur_res,
4231           result));
4232     }
4233     absl::StrAppend(&result->str, kAttrsEnd);
4234   } else {
4235     PyObject* object = PyWeakref_NewRef(arg, nullptr);
4236 
4237     if (object == nullptr) {
4238       PyErr_Clear();
4239 
4240       object = arg;
4241       Py_INCREF(object);
4242     }
4243 
4244     absl::StrAppend(&result->str, kRaw);
4245     result->objects.push_back(object);
4246   }
4247 
4248   return tensorflow::Status::OK();
4249 }
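// Note on the fallback branch above (hedged): many plain Python values (ints,
// strings) don't support weak references, so PyWeakref_NewRef returns nullptr
// with a TypeError set; the code clears the error and retains a strong
// reference instead, e.g.:
//
//   PyObject* ref = PyWeakref_NewRef(py_int, nullptr);  // nullptr for an int
//   if (ref == nullptr) { PyErr_Clear(); /* fall back to Py_INCREF */ }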
4250 
4251 tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
4252                                           bool include_tensor_ranks_only,
4253                                           EncodeResult* result) {
4254   std::vector<int> res_vec;
4255   absl::flat_hash_map<int, int> res_map;
4256   int cur_res = 0;
4257   auto status = TFE_Py_EncodeArgHelperInternal(
4258       arg, include_tensor_ranks_only, res_vec, res_map, cur_res, result);
4259 
4260   // Append the dense encoding of resource ids (if any) as an extra object.
4261   std::string str_resource_encoding;
4262   for (auto&& i : res_vec) {
4263     str_resource_encoding.append(std::to_string(i));
4264     str_resource_encoding.append("_");
4265   }
4266   if (!str_resource_encoding.empty()) {
4267     result->objects.push_back(
4268         PyUnicode_FromString(str_resource_encoding.c_str()));
4269   }
4270 
4271   return status;
4272 }
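// Hedged example: if res_vec ended up as {0, 0, 1}, the string "0_0_1_" is
// appended to result->objects, so two argument signatures that alias
// resources differently produce different cache keys.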
4273 
4274 }  // namespace
4275 
4276 // `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes
4277 // are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes
4278 // are used both for performance, since much TensorFlow code specializes on
4279 // known shapes to produce slimmer graphs, and for correctness, since some
4280 // high-level APIs require shapes to be fully known.
4281 //
4282 // `include_tensor_ranks_only` allows caching on arguments excluding shape info,
4283 // so that a slow path using relaxed shape can rely on a cache key that excludes
4284 // shapes.
4285 PyObject* TFE_Py_EncodeArg(PyObject* arg, bool include_tensor_ranks_only) {
4286   EncodeResult result;
4287   const auto status =
4288       TFE_Py_EncodeArgHelper(arg, include_tensor_ranks_only, &result);
4289   if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
4290     return nullptr;
4291   }
4292 
4293   return result.ToPyTuple();
4294 }
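// A hedged usage sketch (the exact Python binding name is an assumption): the
// function-cache machinery calls something like
//   encoded = pywrap_tfe.TFE_Py_EncodeArg(args_nest, include_tensor_ranks_only)
// and uses the resulting (string, objects) pair as part of a cache key.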
4295 
4296 // A method that prints incoming messages directly to Python's
4297 // stdout using Python's C API. This is necessary in Jupyter notebooks
4298 // and Colabs, where messages written to the C-level stdout don't reach
4299 // the notebook cell outputs, but calls to Python's stdout do.
4300 void PrintToPythonStdout(const char* msg) {
4301   if (Py_IsInitialized()) {
4302     PyGILState_STATE py_threadstate;
4303     py_threadstate = PyGILState_Ensure();
4304 
4305     string string_msg = msg;
4306     // PySys_WriteStdout truncates strings over 1000 bytes, so
4307     // we write the message in chunks small enough to not be truncated.
4308     constexpr size_t kChunkSize = 900;
4309     const size_t len = string_msg.length();
4310     for (size_t i = 0; i < len; i += kChunkSize) {
4311       PySys_WriteStdout("%s", string_msg.substr(i, kChunkSize).c_str());
4312     }
4313 
4314     // Force a flush so that printed output isn't interleaved with other
4315     // writes in some Colab environments.
4316     PyRun_SimpleString("import sys; sys.stdout.flush()");
4317 
4318     PyGILState_Release(py_threadstate);
4319   }
4320 }
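// Hedged example of the chunking above: a 2000-byte message is emitted as
// three PySys_WriteStdout calls covering offsets [0, 900), [900, 1800), and
// [1800, 2000), each safely under the ~1000-byte truncation limit.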
4321 
4322 // Registers PrintToPythonStdout as a log listener so that log messages
4323 // show up in Colab and Jupyter notebook cell outputs.
4324 void TFE_Py_EnableInteractivePythonLogging() {
4325   static bool enabled_interactive_logging = false;
4326   if (!enabled_interactive_logging) {
4327     enabled_interactive_logging = true;
4328     TF_RegisterLogListener(PrintToPythonStdout);
4329   }
4330 }
4331 
4332 namespace {
4333 // Weak reference to the currently active Python eager Context object.
4334 PyObject* weak_eager_context = nullptr;
4335 }  // namespace
4336 
4337 PyObject* TFE_Py_SetEagerContext(PyObject* py_context) {
4338   Py_XDECREF(weak_eager_context);
4339   weak_eager_context = PyWeakref_NewRef(py_context, nullptr);
4340   if (weak_eager_context == nullptr) {
4341     return nullptr;
4342   }
4343   Py_RETURN_NONE;
4344 }
4345 
4346 PyObject* GetPyEagerContext() {
4347   if (weak_eager_context == nullptr) {
4348     PyErr_SetString(PyExc_RuntimeError, "Python eager context is not set");
4349     return nullptr;
4350   }
4351   PyObject* py_context = PyWeakref_GET_OBJECT(weak_eager_context);
4352   if (py_context == Py_None) {
4353     PyErr_SetString(PyExc_RuntimeError, "Eager context has been destroyed");
4354     return nullptr;
4355   }
4356   Py_INCREF(py_context);
4357   return py_context;
4358 }
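// A hedged lifecycle sketch of the two functions above: the Python side
// stores a weak reference once per context activation, and lookups either
// revive it or raise:
//
//   TFE_Py_SetEagerContext(py_context);   // stores a weakref, returns None
//   PyObject* ctx = GetPyEagerContext();  // new strong ref, or sets
//                                         // RuntimeError and returns nullptr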
4359 
4360 namespace {
4361 
4362 // Default values for thread_local_data fields.
4363 struct EagerContextThreadLocalDataDefaults {
4364   tensorflow::Safe_PyObjectPtr is_eager;
4365   tensorflow::Safe_PyObjectPtr device_spec;
4366 };
4367 
4368 // Maps each py_eager_context object to its thread_local_data.
4369 //
4370 // Note: we need to use the python Context object as the key here (and not
4371 // its handle object), because the handle object isn't created until the
4372 // context is initialized; but thread_local_data is potentially accessed
4373 // before then.
4374 using EagerContextThreadLocalDataMap = absl::flat_hash_map<
4375     PyObject*, std::unique_ptr<tensorflow::EagerContextThreadLocalData>>;
4376 thread_local EagerContextThreadLocalDataMap*
4377     eager_context_thread_local_data_map = nullptr;
4378 
4379 // Maps each py_eager_context object to default values.
4380 using EagerContextThreadLocalDataDefaultsMap =
4381     absl::flat_hash_map<PyObject*, EagerContextThreadLocalDataDefaults>;
4382 EagerContextThreadLocalDataDefaultsMap*
4383     eager_context_thread_local_data_defaults = nullptr;
4384 
4385 }  // namespace
4386 
4387 namespace tensorflow {
4388 
4389 void MakeEagerContextThreadLocalData(PyObject* py_eager_context,
4390                                      PyObject* is_eager,
4391                                      PyObject* device_spec) {
4392   DCheckPyGilState();
4393   if (eager_context_thread_local_data_defaults == nullptr) {
4394     absl::LeakCheckDisabler disabler;
4395     eager_context_thread_local_data_defaults =
4396         new EagerContextThreadLocalDataDefaultsMap();
4397   }
4398   if (eager_context_thread_local_data_defaults->count(py_eager_context) > 0) {
4399     PyErr_SetString(PyExc_AssertionError,
4400                     "MakeEagerContextThreadLocalData may not be called "
4401                     "twice on the same eager Context object.");
         return;  // Bail out rather than overwriting the existing defaults.
4402   }
4403 
4404   auto& defaults =
4405       (*eager_context_thread_local_data_defaults)[py_eager_context];
4406   Py_INCREF(is_eager);
4407   defaults.is_eager.reset(is_eager);
4408   Py_INCREF(device_spec);
4409   defaults.device_spec.reset(device_spec);
4410 }
4411 
4412 EagerContextThreadLocalData* GetEagerContextThreadLocalData(
4413     PyObject* py_eager_context) {
4414   if (eager_context_thread_local_data_defaults == nullptr) {
4415     PyErr_SetString(PyExc_AssertionError,
4416                     "MakeEagerContextThreadLocalData must be called "
4417                     "before GetEagerContextThreadLocalData.");
4418     return nullptr;
4419   }
4420   auto defaults =
4421       eager_context_thread_local_data_defaults->find(py_eager_context);
4422   if (defaults == eager_context_thread_local_data_defaults->end()) {
4423     PyErr_SetString(PyExc_AssertionError,
4424                     "MakeEagerContextThreadLocalData must be called "
4425                     "before GetEagerContextThreadLocalData.");
4426     return nullptr;
4427   }
4428 
4429   if (eager_context_thread_local_data_map == nullptr) {
4430     absl::LeakCheckDisabler disabler;
4431     eager_context_thread_local_data_map = new EagerContextThreadLocalDataMap();
4432   }
4433   auto& thread_local_data =
4434       (*eager_context_thread_local_data_map)[py_eager_context];
4435 
4436   if (!thread_local_data) {
4437     thread_local_data.reset(new EagerContextThreadLocalData());
4438 
4439     Safe_PyObjectPtr is_eager(
4440         PyObject_CallFunctionObjArgs(defaults->second.is_eager.get(), nullptr));
4441     if (!is_eager) return nullptr;
4442     thread_local_data->is_eager = PyObject_IsTrue(is_eager.get());
4443 
4444 #if PY_MAJOR_VERSION >= 3
4445     PyObject* scope_name = PyUnicode_FromString("");
4446 #else
4447     PyObject* scope_name = PyString_FromString("");
4448 #endif
4449     thread_local_data->scope_name.reset(scope_name);
4450 
4451 #if PY_MAJOR_VERSION >= 3
4452     PyObject* device_name = PyUnicode_FromString("");
4453 #else
4454     PyObject* device_name = PyString_FromString("");
4455 #endif
4456     thread_local_data->device_name.reset(device_name);
4457 
4458     Py_INCREF(defaults->second.device_spec.get());
4459     thread_local_data->device_spec.reset(defaults->second.device_spec.get());
4460 
4461     Py_INCREF(Py_None);
4462     thread_local_data->function_call_options.reset(Py_None);
4463 
4464     Py_INCREF(Py_None);
4465     thread_local_data->executor.reset(Py_None);
4466 
4467     thread_local_data->op_callbacks.reset(PyList_New(0));
4468   }
4469   return thread_local_data.get();
4470 }
4471 
4472 void DestroyEagerContextThreadLocalData(PyObject* py_eager_context) {
4473   DCheckPyGilState();
4474   if (eager_context_thread_local_data_defaults) {
4475     eager_context_thread_local_data_defaults->erase(py_eager_context);
4476   }
4477   if (eager_context_thread_local_data_map) {
4478     eager_context_thread_local_data_map->erase(py_eager_context);
4479   }
4480 }
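// Hedged lifecycle sketch (the caller wiring is an assumption): the Python
// Context object is expected to drive these in order:
//
//   MakeEagerContextThreadLocalData(ctx, is_eager_fn, device_spec);  // once
//   auto* tld = GetEagerContextThreadLocalData(ctx);  // per thread, lazy init
//   DestroyEagerContextThreadLocalData(ctx);          // when ctx is destroyed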
4481 
4482 }  // namespace tensorflow
4483