/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <atomic>
#include <cstring>
#include <unordered_map>

#include "absl/debugging/leak_check.h"
#include "absl/strings/str_cat.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tape.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/managed_stack_trace.h"
#include "tensorflow/python/eager/pywrap_gradient_exclusions.h"
#include "tensorflow/python/eager/pywrap_tensor.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/py_util.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
#include "tensorflow/python/util/stack_trace.h"
#include "tensorflow/python/util/util.h"

using tensorflow::Status;
using tensorflow::string;
using tensorflow::strings::Printf;

namespace {
// NOTE: Items are retrieved from and returned to these unique_ptrs, which
// act as arenas. This matters when the same thread requests two items
// without releasing the first one. The following sequence of events on the
// same thread still succeeds:
// - GetOp <- Returns the existing op.
// - GetOp <- Allocates and returns a new pointer.
// - ReleaseOp <- Sets the item in the unique_ptr.
// - ReleaseOp <- Sets the item in the unique_ptr, deleting the old one.
// This occurs when a PyFunc kernel is run. It makes that case safe, as well
// as the case where Python reuses the same underlying C++ thread for two
// Python threads.
struct OpDeleter {
  void operator()(TFE_Op* op) const { TFE_DeleteOp(op); }
};
thread_local std::unordered_map<TFE_Context*,
                                std::unique_ptr<TFE_Op, OpDeleter>>
    thread_local_eager_operation_map;                             // NOLINT
thread_local std::unique_ptr<TF_Status> thread_local_tf_status =  // NOLINT
    nullptr;

std::unique_ptr<TFE_Op, OpDeleter> ReleaseThreadLocalOp(TFE_Context* ctx) {
  auto it = thread_local_eager_operation_map.find(ctx);
  if (it == thread_local_eager_operation_map.end()) {
    return nullptr;
  }
  return std::move(it->second);
}

TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
              const char* raw_device_name, TF_Status* status) {
  auto op = ReleaseThreadLocalOp(ctx);
  if (!op) {
    op.reset(tensorflow::wrap(tensorflow::unwrap(ctx)->CreateOperation()));
  }
  status->status =
      tensorflow::unwrap(op.get())->Reset(op_or_function_name, raw_device_name);
  if (!status->status.ok()) {
    op.reset();
  }
  return op.release();
}

void ReturnOp(TFE_Context* ctx, TFE_Op* op) {
  if (op) {
    tensorflow::unwrap(op)->Clear();
    thread_local_eager_operation_map[ctx].reset(op);
  }
}
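
// A minimal usage sketch of the arena pattern described above (illustrative
// only; the real call sites are in TFE_Py_ExecuteCancelable below, and the
// op name here is hypothetical):
//
//   TF_Status* status = GetStatus();
//   TFE_Op* op = GetOp(ctx, "Identity", /*raw_device_name=*/"", status);
//   if (status->status.ok()) {
//     // ... TFE_OpAddInput(op, input, status); TFE_Execute(op, ...); ...
//   }
//   ReturnOp(ctx, op);     // Clears the op and parks it for reuse.
//   ReturnStatus(status);  // Likewise recycles the thread-local status.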

TF_Status* ReleaseThreadLocalStatus() {
  if (thread_local_tf_status == nullptr) {
    return nullptr;
  }
  return thread_local_tf_status.release();
}

struct InputInfo {
  InputInfo(int i, bool is_list) : i(i), is_list(is_list) {}

  int i;
  bool is_list = false;
};

// Takes in output gradients, returns input gradients.
typedef std::function<PyObject*(PyObject*,
                                const std::vector<tensorflow::int64>&)>
    PyBackwardFunction;

using AttrToInputsMap =
    tensorflow::gtl::FlatMap<string,
                             tensorflow::gtl::InlinedVector<InputInfo, 4>>;

tensorflow::gtl::FlatMap<string, AttrToInputsMap*>* GetAllAttrToInputsMaps() {
  static auto* all_attr_to_input_maps =
      new tensorflow::gtl::FlatMap<string, AttrToInputsMap*>;
  return all_attr_to_input_maps;
}

// This function doesn't use a lock, since we depend on the GIL directly.
AttrToInputsMap* GetAttrToInputsMapHoldingGIL(const tensorflow::OpDef& op_def) {
#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4
  DCHECK(PyGILState_Check())
      << "This function needs to hold the GIL when called.";
#endif
  auto* all_attr_to_input_maps = GetAllAttrToInputsMaps();
  auto* output =
      tensorflow::gtl::FindPtrOrNull(*all_attr_to_input_maps, op_def.name());
  if (output != nullptr) {
    return output;
  }

  std::unique_ptr<AttrToInputsMap> m(new AttrToInputsMap);

  // Store a list of InputIndex -> List of corresponding inputs.
  for (int i = 0; i < op_def.input_arg_size(); i++) {
    if (!op_def.input_arg(i).type_attr().empty()) {
      auto it = m->find(op_def.input_arg(i).type_attr());
      if (it == m->end()) {
        it = m->insert({op_def.input_arg(i).type_attr(), {}}).first;
      }
      it->second.emplace_back(i, !op_def.input_arg(i).number_attr().empty());
    }
  }

  auto* retval = m.get();
  (*all_attr_to_input_maps)[op_def.name()] = m.release();

  return retval;
}
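
// For example, given a hypothetical OpDef with inputs `a: T, b: T, n: N * T`
// (where `n` carries a number_attr), the resulting map is
//   {"T": [InputInfo(0, false), InputInfo(1, false), InputInfo(2, true)]},
// which lets the fast path infer the dtype for attr "T" from whichever of
// these inputs already has a concrete dtype.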

// This function doesn't use a lock, since we depend on the GIL directly.
tensorflow::gtl::FlatMap<
    string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>*
GetAllAttrToDefaultsMaps() {
  static auto* all_attr_to_defaults_maps = new tensorflow::gtl::FlatMap<
      string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>;
  return all_attr_to_defaults_maps;
}

tensorflow::gtl::FlatMap<string, tensorflow::DataType>*
GetAttrToDefaultsMapHoldingGIL(const tensorflow::OpDef& op_def) {
#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4
  DCHECK(PyGILState_Check())
      << "This function needs to hold the GIL when called.";
#endif
  auto* all_attr_to_defaults_maps = GetAllAttrToDefaultsMaps();
  auto* output =
      tensorflow::gtl::FindPtrOrNull(*all_attr_to_defaults_maps, op_def.name());
  if (output != nullptr) {
    return output;
  }

  auto* new_map = new tensorflow::gtl::FlatMap<string, tensorflow::DataType>;

  for (const auto& attr : op_def.attr()) {
    if (attr.type() == "type" && attr.has_default_value()) {
      new_map->insert({attr.name(), attr.default_value().type()});
    }
  }

  (*all_attr_to_defaults_maps)[op_def.name()] = new_map;

  return new_map;
}

struct FastPathOpExecInfo {
  TFE_Context* ctx;
  const char* device_name;

  bool run_callbacks;
  bool run_post_exec_callbacks;
  bool run_gradient_callback;

  // The op name of the main op being executed.
  PyObject* name;
  // The op type name of the main op being executed.
  PyObject* op_name;
  PyObject* callbacks;

  // All the args passed into the FastPathOpExecInfo.
  PyObject* args;

  // DTypes can come from another input that has the same attr. So build that
  // map.
  const AttrToInputsMap* attr_to_inputs_map;
  const tensorflow::gtl::FlatMap<string, tensorflow::DataType>* default_dtypes;
  tensorflow::gtl::FlatMap<string, tensorflow::DataType> cached_dtypes;
};

#define PARSE_VALUE(fn_name, type, check_fn, parse_fn)                       \
  bool fn_name(const string& key, PyObject* py_value, TF_Status* status,     \
               type* value) {                                                \
    if (check_fn(py_value)) {                                                \
      *value = static_cast<type>(parse_fn(py_value));                        \
      return true;                                                           \
    } else {                                                                 \
      TF_SetStatus(status, TF_INVALID_ARGUMENT,                              \
                   tensorflow::strings::StrCat(                              \
                       "Expecting " #type " value for attr ", key, ", got ", \
                       py_value->ob_type->tp_name)                           \
                       .c_str());                                            \
      return false;                                                          \
    }                                                                        \
  }

#if PY_MAJOR_VERSION >= 3
PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLongLong)
#else
PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
#endif
PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)
#undef PARSE_VALUE
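
// Illustrative expansion (Python 3 branch): the PARSE_VALUE invocation for
// ParseIntValue above generates, in effect:
//
//   bool ParseIntValue(const string& key, PyObject* py_value,
//                      TF_Status* status, int* value) {
//     if (PyLong_Check(py_value)) {
//       *value = static_cast<int>(PyLong_AsLong(py_value));
//       return true;
//     }
//     // ...otherwise sets TF_INVALID_ARGUMENT on `status`, returns false.
//   }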

#if PY_MAJOR_VERSION < 3
bool ParseInt64Value(const string& key, PyObject* py_value, TF_Status* status,
                     int64_t* value) {
  if (PyInt_Check(py_value)) {
    *value = static_cast<int64_t>(PyInt_AsLong(py_value));
    return true;
  } else if (PyLong_Check(py_value)) {
    *value = static_cast<int64_t>(PyLong_AsLong(py_value));
    return true;
  }
  TF_SetStatus(
      status, TF_INVALID_ARGUMENT,
      tensorflow::strings::StrCat("Expecting int or long value for attr ", key,
                                  ", got ", py_value->ob_type->tp_name)
          .c_str());
  return false;
}
#endif

Py_ssize_t TensorShapeNumDims(PyObject* value) {
  const auto size = PySequence_Size(value);
  if (size == -1) {
    // TensorShape.__len__ raises an error when the shape is unknown; that
    // error needs to be cleared here.
    // TODO(nareshmodi): ensure that this is actually a TensorShape.
    PyErr_Clear();
  }
  return size;
}

bool IsInteger(PyObject* py_value) {
#if PY_MAJOR_VERSION >= 3
  return PyLong_Check(py_value);
#else
  return PyInt_Check(py_value) || PyLong_Check(py_value);
#endif
}

// This function considers a Dimension._value of None to be valid, and sets
// the value to be -1 in that case.
bool ParseDimensionValue(const string& key, PyObject* py_value,
                         TF_Status* status, int64_t* value) {
  if (IsInteger(py_value)) {
    return ParseInt64Value(key, py_value, status, value);
  }

  tensorflow::Safe_PyObjectPtr dimension_value(
      PyObject_GetAttrString(py_value, "_value"));
  if (dimension_value == nullptr) {
    PyErr_Clear();
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        tensorflow::strings::StrCat("Expecting a Dimension for attr ", key,
                                    ", got ", py_value->ob_type->tp_name)
            .c_str());
    return false;
  }

  if (dimension_value.get() == Py_None) {
    *value = -1;
    return true;
  }

  return ParseInt64Value(key, dimension_value.get(), status, value);
}
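
// Accepted inputs, sketched with hypothetical values:
//
//   ParseDimensionValue(key, 3, ...)               -> *value = 3
//   ParseDimensionValue(key, Dimension(7), ...)    -> *value = 7  (via _value)
//   ParseDimensionValue(key, Dimension(None), ...) -> *value = -1 (unknown)
//
// Anything that is neither an integer nor carries a `_value` attribute is
// rejected with TF_INVALID_ARGUMENT.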

bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status,
                      tensorflow::StringPiece* value) {
  if (PyBytes_Check(py_value)) {
    Py_ssize_t size = 0;
    char* buf = nullptr;
    if (PyBytes_AsStringAndSize(py_value, &buf, &size) < 0) return false;
    *value = tensorflow::StringPiece(buf, size);
    return true;
  }
#if PY_MAJOR_VERSION >= 3
  if (PyUnicode_Check(py_value)) {
    Py_ssize_t size = 0;
    const char* buf = PyUnicode_AsUTF8AndSize(py_value, &size);
    if (buf == nullptr) return false;
    *value = tensorflow::StringPiece(buf, size);
    return true;
  }
#endif
  TF_SetStatus(
      status, TF_INVALID_ARGUMENT,
      tensorflow::strings::StrCat("Expecting a string value for attr ", key,
                                  ", got ", py_value->ob_type->tp_name)
          .c_str());
  return false;
}

bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status,
                    unsigned char* value) {
  *value = PyObject_IsTrue(py_value);
  return true;
}

// The passed in py_value is expected to be an object of the python type
// dtypes.DType or an int.
bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status,
                    int* value) {
  if (IsInteger(py_value)) {
    return ParseIntValue(key, py_value, status, value);
  }

  tensorflow::Safe_PyObjectPtr py_type_enum(
      PyObject_GetAttrString(py_value, "_type_enum"));
  if (py_type_enum == nullptr) {
    PyErr_Clear();
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        tensorflow::strings::StrCat("Expecting a DType.dtype for attr ", key,
                                    ", got ", py_value->ob_type->tp_name)
            .c_str());
    return false;
  }

  return ParseIntValue(key, py_type_enum.get(), status, value);
}

bool SetOpAttrList(
    TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_list,
    TF_AttrType type,
    tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    TF_Status* status) {
  if (!PySequence_Check(py_list)) {
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        tensorflow::strings::StrCat("Expecting sequence value for attr ", key,
                                    ", got ", py_list->ob_type->tp_name)
            .c_str());
    return false;
  }
  const int num_values = PySequence_Size(py_list);
  if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values;

#define PARSE_LIST(c_type, parse_fn)                                      \
  std::unique_ptr<c_type[]> values(new c_type[num_values]);               \
  for (int i = 0; i < num_values; ++i) {                                  \
    tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));   \
    if (!parse_fn(key, py_value.get(), status, &values[i])) return false; \
  }

  if (type == TF_ATTR_STRING) {
    std::unique_ptr<const void*[]> values(new const void*[num_values]);
    std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
    for (int i = 0; i < num_values; ++i) {
      tensorflow::StringPiece value;
      tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
      if (!ParseStringValue(key, py_value.get(), status, &value)) return false;
      values[i] = value.data();
      lengths[i] = value.size();
    }
    TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
  } else if (type == TF_ATTR_INT) {
    PARSE_LIST(int64_t, ParseInt64Value);
    TFE_OpSetAttrIntList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_FLOAT) {
    PARSE_LIST(float, ParseFloatValue);
    TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_BOOL) {
    PARSE_LIST(unsigned char, ParseBoolValue);
    TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_TYPE) {
    PARSE_LIST(int, ParseTypeValue);
    TFE_OpSetAttrTypeList(op, key,
                          reinterpret_cast<const TF_DataType*>(values.get()),
                          num_values);
  } else if (type == TF_ATTR_SHAPE) {
    // Make one pass through the input counting the total number of
    // dims across all the input lists.
    int total_dims = 0;
    for (int i = 0; i < num_values; ++i) {
      tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
      if (py_value.get() != Py_None) {
        if (!PySequence_Check(py_value.get())) {
          TF_SetStatus(
              status, TF_INVALID_ARGUMENT,
              tensorflow::strings::StrCat(
                  "Expecting None or sequence value for element ", i,
                  " of attr ", key, ", got ", py_value->ob_type->tp_name)
                  .c_str());
          return false;
        }
        const auto size = TensorShapeNumDims(py_value.get());
        if (size >= 0) {
          total_dims += size;
        }
      }
    }
    // Allocate a buffer that can fit all of the dims together.
    std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
    // Copy the input dims into the buffer and set dims to point to
    // the start of each list's dims.
    std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
    std::unique_ptr<int[]> num_dims(new int[num_values]);
    int64_t* offset = buffer.get();
    for (int i = 0; i < num_values; ++i) {
      tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
      if (py_value.get() == Py_None) {
        dims[i] = nullptr;
        num_dims[i] = -1;
      } else {
        const auto size = TensorShapeNumDims(py_value.get());
        if (size == -1) {
          dims[i] = nullptr;
          num_dims[i] = -1;
          continue;
        }
        dims[i] = offset;
        num_dims[i] = size;
        for (int j = 0; j < size; ++j) {
          tensorflow::Safe_PyObjectPtr inner_py_value(
              PySequence_ITEM(py_value.get(), j));
          if (inner_py_value.get() == Py_None) {
            *offset = -1;
          } else if (!ParseDimensionValue(key, inner_py_value.get(), status,
                                          offset)) {
            return false;
          }
          ++offset;
        }
      }
    }
    TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
                           status);
    if (!status->status.ok()) return false;
  } else if (type == TF_ATTR_FUNC) {
    std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
    for (int i = 0; i < num_values; ++i) {
      tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
      // Allow:
      // (1) String function name, OR
      // (2) A Python object with a .name attribute
      //     (A crude test for being a
      //     tensorflow.python.framework.function._DefinedFunction)
      //     (which is what the various "defun" or "Defun" decorators do).
      // And in the future also allow an object that can encapsulate
      // the function name and its attribute values.
      tensorflow::StringPiece func_name;
      if (!ParseStringValue(key, py_value.get(), status, &func_name)) {
        PyObject* name_attr = PyObject_GetAttrString(py_value.get(), "name");
        if (name_attr == nullptr ||
            !ParseStringValue(key, name_attr, status, &func_name)) {
          TF_SetStatus(
              status, TF_INVALID_ARGUMENT,
              tensorflow::strings::StrCat(
                  "unable to set function value attribute from a ",
                  py_value.get()->ob_type->tp_name,
                  " object. If you think this is an error, please file an "
                  "issue at "
                  "https://github.com/tensorflow/tensorflow/issues/new")
                  .c_str());
          return false;
        }
      }
      funcs[i] = TFE_NewOp(ctx, func_name.data(), status);
      if (!status->status.ok()) return false;
    }
    TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
    if (!status->status.ok()) return false;
  } else {
    TF_SetStatus(status, TF_UNIMPLEMENTED,
                 tensorflow::strings::StrCat("Attr ", key,
                                             " has unhandled list type ", type)
                     .c_str());
    return false;
  }
#undef PARSE_LIST
  return true;
}
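
// The TF_ATTR_SHAPE branch above, sketched with a hypothetical attr value
// from Python:
//
//   py_list = [[2, 3], None, [4, None]]
//   -> num_dims = {2, -1, 2}
//      dims     = {&buffer[0] /* {2, 3} */, nullptr, &buffer[2] /* {4, -1} */}
//
// None as a whole element means unknown rank (-1); None as an inner entry
// becomes an unknown dimension (-1).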

TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
                TF_Status* status) {
  TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
  for (const auto& attr : func.attr()) {
    if (!status->status.ok()) return nullptr;
    SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
    if (!status->status.ok()) return nullptr;
  }
  return func_op;
}

void SetOpAttrListDefault(
    TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
    const char* key, TF_AttrType type,
    tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    TF_Status* status) {
  if (type == TF_ATTR_STRING) {
    int num_values = attr.default_value().list().s_size();
    std::unique_ptr<const void*[]> values(new const void*[num_values]);
    std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
    (*attr_list_sizes)[key] = num_values;
    for (int i = 0; i < num_values; i++) {
      const string& v = attr.default_value().list().s(i);
      values[i] = v.data();
      lengths[i] = v.size();
    }
    TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
  } else if (type == TF_ATTR_INT) {
    int num_values = attr.default_value().list().i_size();
    std::unique_ptr<int64_t[]> values(new int64_t[num_values]);
    (*attr_list_sizes)[key] = num_values;
    for (int i = 0; i < num_values; i++) {
      values[i] = attr.default_value().list().i(i);
    }
    TFE_OpSetAttrIntList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_FLOAT) {
    int num_values = attr.default_value().list().f_size();
    std::unique_ptr<float[]> values(new float[num_values]);
    (*attr_list_sizes)[key] = num_values;
    for (int i = 0; i < num_values; i++) {
      values[i] = attr.default_value().list().f(i);
    }
    TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_BOOL) {
    int num_values = attr.default_value().list().b_size();
    std::unique_ptr<unsigned char[]> values(new unsigned char[num_values]);
    (*attr_list_sizes)[key] = num_values;
    for (int i = 0; i < num_values; i++) {
      values[i] = attr.default_value().list().b(i);
    }
    TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_TYPE) {
    int num_values = attr.default_value().list().type_size();
    std::unique_ptr<int[]> values(new int[num_values]);
    (*attr_list_sizes)[key] = num_values;
    for (int i = 0; i < num_values; i++) {
      values[i] = attr.default_value().list().type(i);
    }
    TFE_OpSetAttrTypeList(op, key,
                          reinterpret_cast<const TF_DataType*>(values.get()),
                          attr.default_value().list().type_size());
  } else if (type == TF_ATTR_SHAPE) {
    int num_values = attr.default_value().list().shape_size();
    (*attr_list_sizes)[key] = num_values;
    int total_dims = 0;
    for (int i = 0; i < num_values; ++i) {
      if (!attr.default_value().list().shape(i).unknown_rank()) {
        total_dims += attr.default_value().list().shape(i).dim_size();
      }
    }
    // Allocate a buffer that can fit all of the dims together.
    std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
    // Copy the input dims into the buffer and set dims to point to
    // the start of each list's dims.
    std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
    std::unique_ptr<int[]> num_dims(new int[num_values]);
    int64_t* offset = buffer.get();
    for (int i = 0; i < num_values; ++i) {
      const auto& shape = attr.default_value().list().shape(i);
      if (shape.unknown_rank()) {
        dims[i] = nullptr;
        num_dims[i] = -1;
      } else {
        // Point this entry at its run of dims in the shared buffer
        // (mirrors the handling in SetOpAttrList above).
        dims[i] = offset;
        num_dims[i] = shape.dim_size();
        for (int j = 0; j < shape.dim_size(); j++) {
          *offset = shape.dim(j).size();
          ++offset;
        }
      }
    }
    TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
                           status);
  } else if (type == TF_ATTR_FUNC) {
    int num_values = attr.default_value().list().func_size();
    (*attr_list_sizes)[key] = num_values;
    std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
    for (int i = 0; i < num_values; i++) {
      funcs[i] = GetFunc(ctx, attr.default_value().list().func(i), status);
    }
    TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
  } else {
    TF_SetStatus(status, TF_UNIMPLEMENTED,
                 "Lists of tensors are not yet implemented for default valued "
                 "attributes for an operation.");
  }
}

bool SetOpAttrScalar(
    TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_value,
    TF_AttrType type,
    tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    TF_Status* status) {
  if (type == TF_ATTR_STRING) {
    tensorflow::StringPiece value;
    if (!ParseStringValue(key, py_value, status, &value)) return false;
    TFE_OpSetAttrString(op, key, value.data(), value.size());
  } else if (type == TF_ATTR_INT) {
    int64_t value;
    if (!ParseInt64Value(key, py_value, status, &value)) return false;
    TFE_OpSetAttrInt(op, key, value);
    // attr_list_sizes is set for all int attributes (since at this point we
    // are not aware if that attribute might be used to calculate the size of
    // an output list or not).
    if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = value;
  } else if (type == TF_ATTR_FLOAT) {
    float value;
    if (!ParseFloatValue(key, py_value, status, &value)) return false;
    TFE_OpSetAttrFloat(op, key, value);
  } else if (type == TF_ATTR_BOOL) {
    unsigned char value;
    if (!ParseBoolValue(key, py_value, status, &value)) return false;
    TFE_OpSetAttrBool(op, key, value);
  } else if (type == TF_ATTR_TYPE) {
    int value;
    if (!ParseTypeValue(key, py_value, status, &value)) return false;
    TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value));
  } else if (type == TF_ATTR_SHAPE) {
    if (py_value == Py_None) {
      TFE_OpSetAttrShape(op, key, nullptr, -1, status);
    } else {
      if (!PySequence_Check(py_value)) {
        TF_SetStatus(status, TF_INVALID_ARGUMENT,
                     tensorflow::strings::StrCat(
                         "Expecting None or sequence value for attr ", key,
                         ", got ", py_value->ob_type->tp_name)
                         .c_str());
        return false;
      }
      const auto num_dims = TensorShapeNumDims(py_value);
      if (num_dims == -1) {
        TFE_OpSetAttrShape(op, key, nullptr, -1, status);
        return true;
      }
      std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
      for (int i = 0; i < num_dims; ++i) {
        tensorflow::Safe_PyObjectPtr inner_py_value(
            PySequence_ITEM(py_value, i));
        if (inner_py_value.get() == Py_None) {
          dims[i] = -1;
        } else if (!ParseDimensionValue(key, inner_py_value.get(), status,
                                        &dims[i])) {
          return false;
        }
      }
      TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status);
    }
    if (!status->status.ok()) return false;
  } else if (type == TF_ATTR_FUNC) {
    // Allow:
    // (1) String function name, OR
    // (2) A Python object with a .name attribute
    //     (A crude test for being a
    //     tensorflow.python.framework.function._DefinedFunction)
    //     (which is what the various "defun" or "Defun" decorators do).
    // And in the future also allow an object that can encapsulate
    // the function name and its attribute values.
    tensorflow::StringPiece func_name;
    if (!ParseStringValue(key, py_value, status, &func_name)) {
      PyObject* name_attr = PyObject_GetAttrString(py_value, "name");
      if (name_attr == nullptr ||
          !ParseStringValue(key, name_attr, status, &func_name)) {
        TF_SetStatus(
            status, TF_INVALID_ARGUMENT,
            tensorflow::strings::StrCat(
                "unable to set function value attribute from a ",
                py_value->ob_type->tp_name,
                " object. If you think this is an error, please file an issue "
                "at https://github.com/tensorflow/tensorflow/issues/new")
                .c_str());
        return false;
      }
    }
    TF_SetStatus(status, TF_OK, "");
    TFE_OpSetAttrFunctionName(op, key, func_name.data(), func_name.size());
  } else {
    TF_SetStatus(
        status, TF_UNIMPLEMENTED,
        tensorflow::strings::StrCat("Attr ", key, " has unhandled type ", type)
            .c_str());
    return false;
  }
  return true;
}

void SetOpAttrScalarDefault(
    TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value,
    const char* attr_name,
    tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    TF_Status* status) {
  SetOpAttrValueScalar(ctx, op, default_value, attr_name, status);
  if (default_value.value_case() == tensorflow::AttrValue::kI) {
    (*attr_list_sizes)[attr_name] = default_value.i();
  }
}

// start_index is the index at which the Tuple/List attrs will start getting
// processed.
void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index,
                TF_Status* out_status) {
  if (attrs == Py_None) return;
  Py_ssize_t len = PyTuple_GET_SIZE(attrs) - start_index;
  if ((len & 1) != 0) {
    TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                 "Expecting attrs tuple to have even length.");
    return;
  }
  // Parse attrs
  for (Py_ssize_t i = 0; i < len; i += 2) {
    PyObject* py_key = PyTuple_GET_ITEM(attrs, start_index + i);
    PyObject* py_value = PyTuple_GET_ITEM(attrs, start_index + i + 1);
#if PY_MAJOR_VERSION >= 3
    const char* key = PyBytes_Check(py_key) ? PyBytes_AsString(py_key)
                                            : PyUnicode_AsUTF8(py_key);
#else
    const char* key = PyBytes_AsString(py_key);
#endif
    unsigned char is_list = 0;
    const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status);
    if (!out_status->status.ok()) return;
    if (is_list != 0) {
      if (!SetOpAttrList(ctx, op, key, py_value, type, nullptr, out_status))
        return;
    } else {
      if (!SetOpAttrScalar(ctx, op, key, py_value, type, nullptr, out_status))
        return;
    }
  }
}
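
// The attrs argument consumed by SetOpAttrs is a flat (name, value, name,
// value, ...) tuple built on the Python side; a hypothetical example for a
// Cast op:
//
//   attrs = ("SrcT", float32._type_enum, "DstT", int32._type_enum)
//
// start_index lets callers skip leading non-attr entries in the same tuple.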

// This function will set the op attrs required. If an attr has the value of
// None, then it will read the AttrDef to get the default value and set that
// instead. Any failure in this function will simply fall back to the slow
// path.
void SetOpAttrWithDefaults(
    TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
    const char* attr_name, PyObject* attr_value,
    tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    TF_Status* status) {
  unsigned char is_list = 0;
  const TF_AttrType type = TFE_OpGetAttrType(op, attr_name, &is_list, status);
  if (!status->status.ok()) return;
  if (attr_value == Py_None) {
    if (is_list != 0) {
      SetOpAttrListDefault(ctx, op, attr, attr_name, type, attr_list_sizes,
                           status);
    } else {
      SetOpAttrScalarDefault(ctx, op, attr.default_value(), attr_name,
                             attr_list_sizes, status);
    }
  } else {
    if (is_list != 0) {
      SetOpAttrList(ctx, op, attr_name, attr_value, type, attr_list_sizes,
                    status);
    } else {
      SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes,
                      status);
    }
  }
}

PyObject* GetPythonObjectFromInt(int num) {
#if PY_MAJOR_VERSION >= 3
  return PyLong_FromLong(num);
#else
  return PyInt_FromLong(num);
#endif
}

// Python subclass of Exception that is raised when a Status is not OK.
tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
PyObject* exception_class TF_GUARDED_BY(exception_class_mutex) = nullptr;

// Python subclass of Exception that is raised to signal fallback.
PyObject* fallback_exception_class = nullptr;

// Python function that returns input gradients given output gradients.
PyObject* gradient_function = nullptr;

// Python function that returns output gradients given input gradients.
PyObject* forward_gradient_function = nullptr;

static std::atomic<int64_t> _uid;

}  // namespace

TF_Status* GetStatus() {
  TF_Status* maybe_status = ReleaseThreadLocalStatus();
  if (maybe_status) {
    TF_SetStatus(maybe_status, TF_OK, "");
    return maybe_status;
  } else {
    return TF_NewStatus();
  }
}

void ReturnStatus(TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  thread_local_tf_status.reset(status);
}

void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
                    const char* op_name, TFE_InputTensorHandles* inputs,
                    PyObject* attrs, TFE_OutputTensorHandles* outputs,
                    TF_Status* out_status) {
  TFE_Py_ExecuteCancelable(ctx, device_name, op_name, inputs, attrs,
                           /*cancellation_manager=*/nullptr, outputs,
                           out_status);
}

void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name,
                              const char* op_name,
                              TFE_InputTensorHandles* inputs, PyObject* attrs,
                              TFE_CancellationManager* cancellation_manager,
                              TFE_OutputTensorHandles* outputs,
                              TF_Status* out_status) {
  tensorflow::profiler::TraceMe activity(
      "TFE_Py_ExecuteCancelable", tensorflow::profiler::TraceMeLevel::kInfo);

  TFE_Op* op = GetOp(ctx, op_name, device_name, out_status);

  auto cleaner = tensorflow::gtl::MakeCleanup([ctx, op] { ReturnOp(ctx, op); });
  if (!out_status->status.ok()) return;

  tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace(
      tensorflow::StackTrace::kStackTraceInitialSize));

  for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) {
    TFE_OpAddInput(op, inputs->at(i), out_status);
  }
  if (cancellation_manager && out_status->status.ok()) {
    TFE_OpSetCancellationManager(op, cancellation_manager, out_status);
  }
  if (out_status->status.ok()) {
    SetOpAttrs(ctx, op, attrs, 0, out_status);
  }
  Py_BEGIN_ALLOW_THREADS;

  int num_outputs = outputs->size();

  if (out_status->status.ok()) {
    TFE_Execute(op, outputs->data(), &num_outputs, out_status);
  }

  if (out_status->status.ok()) {
    outputs->resize(num_outputs);
  } else {
    TF_SetStatus(out_status, TF_GetCode(out_status),
                 tensorflow::strings::StrCat(TF_Message(out_status),
                                             " [Op:", op_name, "]")
                     .c_str());
  }

  Py_END_ALLOW_THREADS;
}

PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
  tensorflow::mutex_lock l(exception_class_mutex);
  if (exception_class != nullptr) {
    Py_DECREF(exception_class);
  }
  if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
    exception_class = nullptr;
    PyErr_SetString(PyExc_TypeError,
                    "TFE_Py_RegisterExceptionClass: "
                    "Registered class should be subclass of Exception.");
    return nullptr;
  }

  Py_INCREF(e);
  exception_class = e;
  Py_RETURN_NONE;
}
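
// Registration happens once from the Python eager runtime at import time;
// a hedged sketch (the class shown is hypothetical, but it must accept the
// (message, code) pair that MaybeRaiseExceptionFromTFStatus below passes in):
//
//   class _NotOkStatusException(Exception):
//     def __init__(self, message, code):
//       super().__init__(message)
//       self.code = code
//
//   pywrap_tfe.TFE_Py_RegisterExceptionClass(_NotOkStatusException)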

PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) {
  if (fallback_exception_class != nullptr) {
    Py_DECREF(fallback_exception_class);
  }
  if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
    fallback_exception_class = nullptr;
    PyErr_SetString(PyExc_TypeError,
                    "TFE_Py_RegisterFallbackExceptionClass: "
                    "Registered class should be subclass of Exception.");
    return nullptr;
  } else {
    Py_INCREF(e);
    fallback_exception_class = e;
    Py_RETURN_NONE;
  }
}

PyObject* TFE_Py_RegisterGradientFunction(PyObject* e) {
  if (gradient_function != nullptr) {
    Py_DECREF(gradient_function);
  }
  if (!PyCallable_Check(e)) {
    gradient_function = nullptr;
    PyErr_SetString(PyExc_TypeError,
                    "TFE_Py_RegisterGradientFunction: "
                    "Registered object should be function.");
    return nullptr;
  } else {
    Py_INCREF(e);
    gradient_function = e;
    Py_RETURN_NONE;
  }
}

PyObject* TFE_Py_RegisterJVPFunction(PyObject* e) {
  if (forward_gradient_function != nullptr) {
    Py_DECREF(forward_gradient_function);
  }
  if (!PyCallable_Check(e)) {
    forward_gradient_function = nullptr;
    PyErr_SetString(PyExc_TypeError,
                    "TFE_Py_RegisterJVPFunction: "
                    "Registered object should be function.");
    return nullptr;
  } else {
    Py_INCREF(e);
    forward_gradient_function = e;
    Py_RETURN_NONE;
  }
}

void RaiseFallbackException(const char* message) {
  if (fallback_exception_class != nullptr) {
    PyErr_SetString(fallback_exception_class, message);
    return;
  }

  PyErr_SetString(
      PyExc_RuntimeError,
      tensorflow::strings::StrCat(
          "Fallback exception type not set, attempting to fallback due to ",
          message)
          .data());
}

// Format and return `status`' error message with the attached stack trace if
// available. `status` must have an error.
std::string FormatErrorStatusStackTrace(const tensorflow::Status& status) {
  tensorflow::DCheckPyGilState();
  DCHECK(!status.ok());

  if (status.stack_trace().empty()) return status.error_message();

  const std::vector<tensorflow::StackFrame>& stack_trace = status.stack_trace();

  PyObject* linecache = PyImport_ImportModule("linecache");
  PyObject* getline =
      PyObject_GetAttr(linecache, PyUnicode_FromString("getline"));
  DCHECK(getline);

  std::ostringstream result;
  result << "Exception originated from\n\n";

  for (const tensorflow::StackFrame& stack_frame : stack_trace) {
    PyObject* line_str_obj = PyObject_CallFunction(
        getline, const_cast<char*>("si"), stack_frame.file_name.c_str(),
        stack_frame.line_number);
    tensorflow::StringPiece line_str = TFE_GetPythonString(line_str_obj);
    tensorflow::str_util::RemoveWhitespaceContext(&line_str);
    result << "  File \"" << stack_frame.file_name << "\", line "
           << stack_frame.line_number << ", in " << stack_frame.function_name
           << '\n';

    if (!line_str.empty()) result << "    " << line_str << '\n';
    Py_XDECREF(line_str_obj);
  }

  Py_DecRef(getline);
  Py_DecRef(linecache);

  result << '\n' << status.error_message();
  return result.str();
}
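
// Example of the message produced above (hypothetical frames and status):
//
//   Exception originated from
//
//     File "train.py", line 42, in step
//       loss = model(x)
//
//   <the original status error message>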

int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
  if (status->status.ok()) return 0;
  const char* msg = TF_Message(status);
  if (exception == nullptr) {
    tensorflow::mutex_lock l(exception_class_mutex);
    if (exception_class != nullptr) {
      tensorflow::Safe_PyObjectPtr val(Py_BuildValue(
          "si", FormatErrorStatusStackTrace(status->status).c_str(),
          TF_GetCode(status)));
      if (PyErr_Occurred()) {
        // NOTE: This hides the actual error (i.e. the reason `status` was not
        // TF_OK), but there is nothing we can do at this point since we can't
        // generate a reasonable error from the status.
        // Consider adding a message explaining this.
        return -1;
      }
      PyErr_SetObject(exception_class, val.get());
      return -1;
    } else {
      exception = PyExc_RuntimeError;
    }
  }
  // Maybe update an already-set exception.
  PyErr_SetString(exception, msg);
  return -1;
}

int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
                                  PyObject* exception) {
  if (status.ok()) return 0;
  const char* msg = status.error_message().c_str();
  if (exception == nullptr) {
    tensorflow::mutex_lock l(exception_class_mutex);
    if (exception_class != nullptr) {
      tensorflow::Safe_PyObjectPtr val(Py_BuildValue(
          "si", FormatErrorStatusStackTrace(status).c_str(), status.code()));
      PyErr_SetObject(exception_class, val.get());
      return -1;
    } else {
      exception = PyExc_RuntimeError;
    }
  }
  // Maybe update an already-set exception.
  PyErr_SetString(exception, msg);
  return -1;
}

const char* TFE_GetPythonString(PyObject* o) {
#if PY_MAJOR_VERSION >= 3
  if (PyBytes_Check(o)) {
    return PyBytes_AsString(o);
  } else {
    return PyUnicode_AsUTF8(o);
  }
#else
  return PyBytes_AsString(o);
#endif
}

int64_t get_uid() { return _uid++; }

PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }

void TFE_DeleteContextCapsule(PyObject* context) {
  TFE_Context* ctx =
      reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr));
  auto op = ReleaseThreadLocalOp(ctx);
  op.reset();
  TFE_DeleteContext(ctx);
}

static tensorflow::int64 MakeInt(PyObject* integer) {
#if PY_MAJOR_VERSION >= 3
  return PyLong_AsLong(integer);
#else
  return PyInt_AsLong(integer);
#endif
}

static tensorflow::int64 FastTensorId(PyObject* tensor) {
  if (EagerTensor_CheckExact(tensor)) {
    return PyEagerTensor_ID(tensor);
  }
  PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
  if (id_field == nullptr) {
    return -1;
  }
  tensorflow::int64 id = MakeInt(id_field);
  Py_DECREF(id_field);
  return id;
}

namespace tensorflow {
DataType PyTensor_DataType(PyObject* tensor) {
  if (EagerTensor_CheckExact(tensor)) {
    return PyEagerTensor_Dtype(tensor);
  } else {
#if PY_MAJOR_VERSION < 3
    // Python 2.x:
    static PyObject* dtype_attr = PyString_InternFromString("dtype");
    static PyObject* type_enum_attr = PyString_InternFromString("_type_enum");
#else
    // Python 3.x:
    static PyObject* dtype_attr = PyUnicode_InternFromString("dtype");
    static PyObject* type_enum_attr = PyUnicode_InternFromString("_type_enum");
#endif
    Safe_PyObjectPtr dtype_field(PyObject_GetAttr(tensor, dtype_attr));
    if (!dtype_field) {
      return DT_INVALID;
    }

    Safe_PyObjectPtr enum_field(
        PyObject_GetAttr(dtype_field.get(), type_enum_attr));
    if (!enum_field) {
      return DT_INVALID;
    }

    return static_cast<DataType>(MakeInt(enum_field.get()));
  }
}
}  // namespace tensorflow

class PyTapeTensor {
 public:
  PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
               const tensorflow::TensorShape& shape)
      : id_(id), dtype_(dtype), shape_(shape) {}
  PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
               PyObject* shape)
      : id_(id), dtype_(dtype), shape_(shape) {
    Py_INCREF(absl::get<1>(shape_));
  }
  PyTapeTensor(const PyTapeTensor& other) {
    id_ = other.id_;
    dtype_ = other.dtype_;
    shape_ = other.shape_;
    if (shape_.index() == 1) {
      Py_INCREF(absl::get<1>(shape_));
    }
  }

  ~PyTapeTensor() {
    if (shape_.index() == 1) {
      Py_DECREF(absl::get<1>(shape_));
    }
  }
  PyObject* GetShape() const;
  PyObject* GetPyDType() const { return PyLong_FromLong(dtype_); }
  tensorflow::int64 GetID() const { return id_; }
  tensorflow::DataType GetDType() const { return dtype_; }

  PyObject* OnesLike() const;
  PyObject* ZerosLike() const;

 private:
  tensorflow::int64 id_;
  tensorflow::DataType dtype_;

  // Note that if shape_.index() == 1, meaning shape_ contains a PyObject,
  // that PyObject is the tensor itself. This is used to support
  // tf.shape(tensor) for partially-defined shapes and tf.zeros_like(tensor)
  // for variant-dtype tensors.
  absl::variant<tensorflow::TensorShape, PyObject*> shape_;
};
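
// Illustrative construction (hypothetical values): the first form stores a
// concrete TensorShape; the second keeps a reference to the tensor itself so
// the shape or a zeros/ones tensor can be derived from it lazily:
//
//   PyTapeTensor a(/*id=*/1, tensorflow::DT_FLOAT,
//                  tensorflow::TensorShape({2, 2}));
//   PyTapeTensor b(/*id=*/2, tensorflow::DT_VARIANT, py_tensor);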
1187 
1188 static PyTapeTensor TapeTensorFromTensor(PyObject* tensor);
1189 
1190 class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
1191                                                   PyTapeTensor> {
1192  public:
PyVSpace(PyObject * py_vspace)1193   explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
1194     Py_INCREF(py_vspace_);
1195   }
1196 
Initialize()1197   tensorflow::Status Initialize() {
1198     num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
1199     if (num_elements_ == nullptr) {
1200       return tensorflow::errors::InvalidArgument("invalid vspace");
1201     }
1202     aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
1203     if (aggregate_fn_ == nullptr) {
1204       return tensorflow::errors::InvalidArgument("invalid vspace");
1205     }
1206     zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn");
1207     if (zeros_fn_ == nullptr) {
1208       return tensorflow::errors::InvalidArgument("invalid vspace");
1209     }
1210     zeros_like_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_like_fn");
1211     if (zeros_like_fn_ == nullptr) {
1212       return tensorflow::errors::InvalidArgument("invalid vspace");
1213     }
1214     ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn");
1215     if (ones_fn_ == nullptr) {
1216       return tensorflow::errors::InvalidArgument("invalid vspace");
1217     }
1218     ones_like_fn_ = PyObject_GetAttrString(py_vspace_, "ones_like_fn");
1219     if (ones_like_fn_ == nullptr) {
1220       return tensorflow::errors::InvalidArgument("invalid vspace");
1221     }
1222     graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn");
1223     if (graph_shape_fn_ == nullptr) {
1224       return tensorflow::errors::InvalidArgument("invalid vspace");
1225     }
1226     return tensorflow::Status::OK();
1227   }
1228 
~PyVSpace()1229   ~PyVSpace() override {
1230     Py_XDECREF(num_elements_);
1231     Py_XDECREF(aggregate_fn_);
1232     Py_XDECREF(zeros_fn_);
1233     Py_XDECREF(zeros_like_fn_);
1234     Py_XDECREF(ones_fn_);
1235     Py_XDECREF(ones_like_fn_);
1236     Py_XDECREF(graph_shape_fn_);
1237 
1238     Py_DECREF(py_vspace_);
1239   }
1240 
NumElements(PyObject * tensor) const1241   tensorflow::int64 NumElements(PyObject* tensor) const final {
1242     if (EagerTensor_CheckExact(tensor)) {
1243       return PyEagerTensor_NumElements(tensor);
1244     }
1245     PyObject* arglist =
1246         Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
1247     PyObject* result = PyEval_CallObject(num_elements_, arglist);
1248     Py_DECREF(arglist);
1249     if (result == nullptr) {
1250       // The caller detects whether a python exception has been raised.
1251       return -1;
1252     }
1253     tensorflow::int64 r = MakeInt(result);
1254     Py_DECREF(result);
1255     return r;
1256   }
1257 
AggregateGradients(tensorflow::gtl::ArraySlice<PyObject * > gradient_tensors) const1258   PyObject* AggregateGradients(
1259       tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
1260     PyObject* list = PyList_New(gradient_tensors.size());
1261     for (int i = 0; i < gradient_tensors.size(); ++i) {
1262       // Note: stealing a reference to the gradient tensors.
1263       CHECK(gradient_tensors[i] != nullptr);
1264       CHECK(gradient_tensors[i] != Py_None);
1265       PyList_SET_ITEM(list, i,
1266                       reinterpret_cast<PyObject*>(gradient_tensors[i]));
1267     }
1268     PyObject* arglist = Py_BuildValue("(O)", list);
1269     CHECK(arglist != nullptr);
1270     PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
1271     Py_DECREF(arglist);
1272     Py_DECREF(list);
1273     return result;
1274   }
1275 
TensorId(PyObject * tensor) const1276   tensorflow::int64 TensorId(PyObject* tensor) const final {
1277     return FastTensorId(tensor);
1278   }
1279 
MarkAsResult(PyObject * gradient) const1280   void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
1281 
Ones(PyObject * shape,PyObject * dtype) const1282   PyObject* Ones(PyObject* shape, PyObject* dtype) const {
1283     if (PyErr_Occurred()) {
1284       return nullptr;
1285     }
1286     PyObject* arg_list = Py_BuildValue("OO", shape, dtype);
1287     PyObject* result = PyEval_CallObject(ones_fn_, arg_list);
1288     Py_DECREF(arg_list);
1289     return result;
1290   }
1291 
OnesLike(PyObject * tensor) const1292   PyObject* OnesLike(PyObject* tensor) const {
1293     if (PyErr_Occurred()) {
1294       return nullptr;
1295     }
1296     return PyObject_CallFunctionObjArgs(ones_like_fn_, tensor, NULL);
1297   }
1298 
1299   // Builds a tensor filled with ones with the same shape and dtype as `t`.
BuildOnesLike(const PyTapeTensor & t,PyObject ** result) const1300   Status BuildOnesLike(const PyTapeTensor& t,
1301                        PyObject** result) const override {
1302     *result = t.OnesLike();
1303     return Status::OK();
1304   }
1305 
Zeros(PyObject * shape,PyObject * dtype) const1306   PyObject* Zeros(PyObject* shape, PyObject* dtype) const {
1307     if (PyErr_Occurred()) {
1308       return nullptr;
1309     }
1310     PyObject* arg_list = Py_BuildValue("OO", shape, dtype);
1311     PyObject* result = PyEval_CallObject(zeros_fn_, arg_list);
1312     Py_DECREF(arg_list);
1313     return result;
1314   }
1315 
1316   PyObject* ZerosLike(PyObject* tensor) const {
1317     if (PyErr_Occurred()) {
1318       return nullptr;
1319     }
1320     return PyObject_CallFunctionObjArgs(zeros_like_fn_, tensor, NULL);
1321   }
1322 
1323   PyObject* GraphShape(PyObject* tensor) const {
1324     PyObject* arg_list = Py_BuildValue("(O)", tensor);
1325     PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list);
1326     Py_DECREF(arg_list);
1327     return result;
1328   }
1329 
1330   tensorflow::Status CallBackwardFunction(
1331       const string& op_type, PyBackwardFunction* backward_function,
1332       const std::vector<tensorflow::int64>& unneeded_gradients,
1333       tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
1334       absl::Span<PyObject*> result) const final {
1335     PyObject* grads = PyTuple_New(output_gradients.size());
1336     for (int i = 0; i < output_gradients.size(); ++i) {
1337       if (output_gradients[i] == nullptr) {
1338         Py_INCREF(Py_None);
1339         PyTuple_SET_ITEM(grads, i, Py_None);
1340       } else {
1341         PyTuple_SET_ITEM(grads, i,
1342                          reinterpret_cast<PyObject*>(output_gradients[i]));
1343       }
1344     }
1345     PyObject* py_result = (*backward_function)(grads, unneeded_gradients);
1346     Py_DECREF(grads);
1347     if (py_result == nullptr) {
1348       return tensorflow::errors::Internal("gradient function threw an exception");
1349     }
1350     PyObject* seq =
1351         PySequence_Fast(py_result, "expected a sequence of gradients");
1352     if (seq == nullptr) {
1353       return tensorflow::errors::InvalidArgument(
1354           "gradient function did not return a list");
1355     }
1356     int len = PySequence_Fast_GET_SIZE(seq);
1357     if (len != result.size()) {
1358       return tensorflow::errors::Internal(
1359           "Recorded operation '", op_type,
1360           "' returned too few gradients. Expected ", result.size(),
1361           " but received ", len);
1362     }
1363     PyObject** seq_array = PySequence_Fast_ITEMS(seq);
1364     VLOG(1) << "Gradient length is " << len;
1365     for (int i = 0; i < len; ++i) {
1366       PyObject* item = seq_array[i];
1367       if (item == Py_None) {
1368         result[i] = nullptr;
1369       } else {
1370         Py_INCREF(item);
1371         result[i] = item;
1372       }
1373     }
1374     Py_DECREF(seq);
1375     Py_DECREF(py_result);
1376     return tensorflow::Status::OK();
1377   }
1378 
1379   void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
1380 
1381   PyTapeTensor TapeTensorFromGradient(PyObject* tensor) const final {
1382     return TapeTensorFromTensor(tensor);
1383   }
1384 
1385  private:
1386   PyObject* py_vspace_;
1387 
1388   PyObject* num_elements_;
1389   PyObject* aggregate_fn_;
1390   PyObject* zeros_fn_;
1391   PyObject* zeros_like_fn_;
1392   PyObject* ones_fn_;
1393   PyObject* ones_like_fn_;
1394   PyObject* graph_shape_fn_;
1395 };
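// A sketch of the gradient calling convention implemented by
// CallBackwardFunction above (variable names hypothetical): a missing output
// gradient is handed to the Python backward function as Py_None, and a None
// returned by Python comes back to C++ as a nullptr entry.
//
//   std::vector<PyObject*> output_gradients = {nullptr};  // passed as None
//   std::vector<PyObject*> input_gradients(num_inputs);
//   // After the call, input_gradients[i] is nullptr wherever the Python
//   // function returned None, and an owned reference otherwise.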
1396 PyVSpace* py_vspace = nullptr;
1397 
1398 bool HasAccumulator();
1399 
1400 PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
1401   if (py_vspace != nullptr) {
1402     if (HasAccumulator()) {
1403       // Accumulators reference py_vspace, so we can't swap it out while one is
1404       // active. This is unlikely to ever happen.
1405       MaybeRaiseExceptionFromStatus(
1406           tensorflow::errors::Internal(
1407               "Can't change the vspace implementation while a "
1408               "forward accumulator is active."),
1409           nullptr);
      return nullptr;
1410     }
1411     delete py_vspace;
1412   }
1413 
1414   py_vspace = new PyVSpace(e);
1415   auto status = py_vspace->Initialize();
1416   if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
1417     delete py_vspace;
1418     return nullptr;
1419   }
1420 
1421   Py_RETURN_NONE;
1422 }
1423 
1424 PyObject* PyTapeTensor::GetShape() const {
1425   if (shape_.index() == 0) {
1426     auto& shape = absl::get<0>(shape_);
1427     PyObject* py_shape = PyTuple_New(shape.dims());
1428     for (int i = 0; i < shape.dims(); ++i) {
1429       PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
1430     }
1431 
1432     return py_shape;
1433   }
1434 
1435   return py_vspace->GraphShape(absl::get<1>(shape_));
1436 }
1437 
1438 PyObject* PyTapeTensor::OnesLike() const {
1439   if (shape_.index() == 1) {
1440     PyObject* tensor = absl::get<1>(shape_);
1441     return py_vspace->OnesLike(tensor);
1442   }
1443   PyObject* py_shape = GetShape();
1444   PyObject* dtype_field = GetPyDType();
1445   PyObject* result = py_vspace->Ones(py_shape, dtype_field);
1446   Py_DECREF(dtype_field);
1447   Py_DECREF(py_shape);
1448   return result;
1449 }
1450 
1451 PyObject* PyTapeTensor::ZerosLike() const {
1452   if (shape_.index() == 1) {
1453     PyObject* tensor = absl::get<1>(shape_);
1454     return py_vspace->ZerosLike(tensor);
1455   }
1456   PyObject* py_shape = GetShape();
1457   PyObject* dtype_field = GetPyDType();
1458   PyObject* result = py_vspace->Zeros(py_shape, dtype_field);
1459   Py_DECREF(dtype_field);
1460   Py_DECREF(py_shape);
1461   return result;
1462 }
1463 
1464 // Keeps track of all variables that have been accessed during execution.
1465 class VariableWatcher {
1466  public:
1467   VariableWatcher() {}
1468 
1469   ~VariableWatcher() {
1470     for (const IdAndVariable& v : watched_variables_) {
1471       Py_DECREF(v.variable);
1472     }
1473   }
1474 
1475   tensorflow::int64 WatchVariable(PyObject* v) {
1476     tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
1477     if (handle == nullptr) {
1478       return -1;
1479     }
1480     tensorflow::int64 id = FastTensorId(handle.get());
1481 
1482     tensorflow::mutex_lock l(watched_variables_mu_);
1483     auto insert_result = watched_variables_.emplace(id, v);
1484 
1485     if (insert_result.second) {
1486       // Only increment the reference count if we aren't already watching this
1487       // variable.
1488       Py_INCREF(v);
1489     }
1490 
1491     return id;
1492   }
1493 
1494   PyObject* GetVariablesAsPyTuple() {
1495     tensorflow::mutex_lock l(watched_variables_mu_);
1496     PyObject* result = PyTuple_New(watched_variables_.size());
1497     Py_ssize_t pos = 0;
1498     for (const IdAndVariable& id_and_variable : watched_variables_) {
1499       PyTuple_SET_ITEM(result, pos++, id_and_variable.variable);
1500       Py_INCREF(id_and_variable.variable);
1501     }
1502     return result;
1503   }
1504 
1505  private:
1506   // We store an IdAndVariable in the map since the map needs to be locked
1507   // during insert, but should not call back into python during insert to avoid
1508   // deadlocking with the GIL.
1509   struct IdAndVariable {
1510     tensorflow::int64 id;
1511     PyObject* variable;
1512 
1513     IdAndVariable(tensorflow::int64 id, PyObject* variable)
1514         : id(id), variable(variable) {}
1515   };
1516   struct CompareById {
1517     bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) const {
1518       return lhs.id < rhs.id;
1519     }
1520   };
1521 
1522   tensorflow::mutex watched_variables_mu_;
1523   std::set<IdAndVariable, CompareById> watched_variables_
1524       TF_GUARDED_BY(watched_variables_mu_);
1525 };
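// Usage sketch (illustrative only; `v` is any Python variable object with a
// `handle` attribute): the watcher dedupes by the handle's tensor id, so
// watching the same variable twice holds a single reference.
//
//   VariableWatcher watcher;
//   watcher.WatchVariable(v);  // Py_INCREFs v, keyed by FastTensorId(handle)
//   watcher.WatchVariable(v);  // no-op: id already present
//   PyObject* vars = watcher.GetVariablesAsPyTuple();  // new references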
1526 
1527 class GradientTape
1528     : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
1529                                              PyTapeTensor> {
1530  public:
1531   explicit GradientTape(bool persistent, bool watch_accessed_variables)
1532       : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
1533                                         PyTapeTensor>(persistent),
1534         watch_accessed_variables_(watch_accessed_variables) {}
1535 
1536   virtual ~GradientTape() {}
1537 
1538   void VariableAccessed(PyObject* v) {
1539     if (watch_accessed_variables_) {
1540       WatchVariable(v);
1541     }
1542   }
1543 
1544   void WatchVariable(PyObject* v) {
1545     tensorflow::int64 id = variable_watcher_.WatchVariable(v);
1546 
1547     if (!PyErr_Occurred()) {
1548       this->Watch(id);
1549     }
1550   }
1551 
1552   PyObject* GetVariablesAsPyTuple() {
1553     return variable_watcher_.GetVariablesAsPyTuple();
1554   }
1555 
1556  private:
1557   bool watch_accessed_variables_;
1558   VariableWatcher variable_watcher_;
1559 };
1560 
1561 typedef tensorflow::eager::ForwardAccumulator<PyObject, PyBackwardFunction,
1562                                               PyTapeTensor>
1563     ForwardAccumulator;
1564 
1565 // Incremented when a GradientTape or accumulator is newly added to a set, and
1566 // used to enforce an ordering between them.
1567 std::atomic_uint_fast64_t tape_nesting_id_counter(0);
1568 
1569 typedef struct {
1570   PyObject_HEAD
1571       /* Type-specific fields go here. */
1572       GradientTape* tape;
1573   // A nesting order between GradientTapes and ForwardAccumulators, used to
1574   // ensure that GradientTapes do not watch the products of outer
1575   // ForwardAccumulators.
1576   tensorflow::int64 nesting_id;
1577 } TFE_Py_Tape;
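// Illustration of the nesting-id ordering (values hypothetical): entries
// created later receive strictly larger ids, so a tape entered inside an
// outer accumulator has the larger id and is filtered out while that
// accumulator is busy accumulating (see TapeSetRecordBackprop below).
//
//   accumulator->nesting_id = tape_nesting_id_counter.fetch_add(1);  // e.g. 7
//   tape->nesting_id = tape_nesting_id_counter.fetch_add(1);         // e.g. 8
//   // 8 < 7 is false, so the inner tape skips ops recorded while the
//   // outer accumulator's jvp computation is in flight.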
1578 
1579 static void TFE_Py_Tape_Delete(PyObject* tape) {
1580   delete reinterpret_cast<TFE_Py_Tape*>(tape)->tape;
1581   Py_TYPE(tape)->tp_free(tape);
1582 }
1583 
1584 static PyTypeObject TFE_Py_Tape_Type = {
1585     PyVarObject_HEAD_INIT(nullptr, 0) "tfe.Tape", /* tp_name */
1586     sizeof(TFE_Py_Tape),                          /* tp_basicsize */
1587     0,                                            /* tp_itemsize */
1588     &TFE_Py_Tape_Delete,                          /* tp_dealloc */
1589 #if PY_VERSION_HEX < 0x03080000
1590     nullptr, /* tp_print */
1591 #else
1592     0, /* tp_vectorcall_offset */
1593 #endif
1594     nullptr,               /* tp_getattr */
1595     nullptr,               /* tp_setattr */
1596     nullptr,               /* tp_reserved */
1597     nullptr,               /* tp_repr */
1598     nullptr,               /* tp_as_number */
1599     nullptr,               /* tp_as_sequence */
1600     nullptr,               /* tp_as_mapping */
1601     nullptr,               /* tp_hash  */
1602     nullptr,               /* tp_call */
1603     nullptr,               /* tp_str */
1604     nullptr,               /* tp_getattro */
1605     nullptr,               /* tp_setattro */
1606     nullptr,               /* tp_as_buffer */
1607     Py_TPFLAGS_DEFAULT,    /* tp_flags */
1608     "TFE_Py_Tape objects", /* tp_doc */
1609 };
1610 
1611 typedef struct {
1612   PyObject_HEAD
1613       /* Type-specific fields go here. */
1614       ForwardAccumulator* accumulator;
1615   // A nesting order between GradientTapes and ForwardAccumulators, used to
1616   // ensure that GradientTapes do not watch the products of outer
1617   // ForwardAccumulators.
1618   tensorflow::int64 nesting_id;
1619 } TFE_Py_ForwardAccumulator;
1620 
1621 static void TFE_Py_ForwardAccumulatorDelete(PyObject* accumulator) {
1622   delete reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)->accumulator;
1623   Py_TYPE(accumulator)->tp_free(accumulator);
1624 }
1625 
1626 static PyTypeObject TFE_Py_ForwardAccumulator_Type = {
1627     PyVarObject_HEAD_INIT(nullptr, 0) "ForwardAccumulator", /* tp_name */
1628     sizeof(TFE_Py_ForwardAccumulator),                      /* tp_basicsize */
1629     0,                                                      /* tp_itemsize */
1630     &TFE_Py_ForwardAccumulatorDelete,                       /* tp_dealloc */
1631 #if PY_VERSION_HEX < 0x03080000
1632     nullptr, /* tp_print */
1633 #else
1634     0, /* tp_vectorcall_offset */
1635 #endif
1636     nullptr,                             /* tp_getattr */
1637     nullptr,                             /* tp_setattr */
1638     nullptr,                             /* tp_reserved */
1639     nullptr,                             /* tp_repr */
1640     nullptr,                             /* tp_as_number */
1641     nullptr,                             /* tp_as_sequence */
1642     nullptr,                             /* tp_as_mapping */
1643     nullptr,                             /* tp_hash  */
1644     nullptr,                             /* tp_call */
1645     nullptr,                             /* tp_str */
1646     nullptr,                             /* tp_getattro */
1647     nullptr,                             /* tp_setattro */
1648     nullptr,                             /* tp_as_buffer */
1649     Py_TPFLAGS_DEFAULT,                  /* tp_flags */
1650     "TFE_Py_ForwardAccumulator objects", /* tp_doc */
1651 };
1652 
1653 typedef struct {
1654   PyObject_HEAD
1655       /* Type-specific fields go here. */
1656       VariableWatcher* variable_watcher;
1657 } TFE_Py_VariableWatcher;
1658 
1659 static void TFE_Py_VariableWatcher_Delete(PyObject* variable_watcher) {
1660   delete reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
1661       ->variable_watcher;
1662   Py_TYPE(variable_watcher)->tp_free(variable_watcher);
1663 }
1664 
1665 static PyTypeObject TFE_Py_VariableWatcher_Type = {
1666     PyVarObject_HEAD_INIT(nullptr, 0) "tfe.VariableWatcher", /* tp_name */
1667     sizeof(TFE_Py_VariableWatcher),                          /* tp_basicsize */
1668     0,                                                       /* tp_itemsize */
1669     &TFE_Py_VariableWatcher_Delete,                          /* tp_dealloc */
1670 #if PY_VERSION_HEX < 0x03080000
1671     nullptr, /* tp_print */
1672 #else
1673     0, /* tp_vectorcall_offset */
1674 #endif
1675     nullptr,                          /* tp_getattr */
1676     nullptr,                          /* tp_setattr */
1677     nullptr,                          /* tp_reserved */
1678     nullptr,                          /* tp_repr */
1679     nullptr,                          /* tp_as_number */
1680     nullptr,                          /* tp_as_sequence */
1681     nullptr,                          /* tp_as_mapping */
1682     nullptr,                          /* tp_hash  */
1683     nullptr,                          /* tp_call */
1684     nullptr,                          /* tp_str */
1685     nullptr,                          /* tp_getattro */
1686     nullptr,                          /* tp_setattro */
1687     nullptr,                          /* tp_as_buffer */
1688     Py_TPFLAGS_DEFAULT,               /* tp_flags */
1689     "TFE_Py_VariableWatcher objects", /* tp_doc */
1690 };
1691 
1692 // Note: in the current design no mutex is needed here because of the python
1693 // GIL, which is always held when any TFE_Py_* methods are called. We should
1694 // revisit this if/when we decide to not hold the GIL while manipulating the
1695 // tape stack.
1696 tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
1697   thread_local std::unique_ptr<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>
1698       tape_set = nullptr;
1699   if (tape_set == nullptr) {
1700     tape_set.reset(new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>);
1701   }
1702   return tape_set.get();
1703 }
1704 
1705 tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>*
1706 GetVariableWatcherSet() {
1707   thread_local std::unique_ptr<
1708       tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>
1709       variable_watcher_set = nullptr;
1710   if (variable_watcher_set == nullptr) {
1711     variable_watcher_set.reset(
1712         new tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>);
1713   }
1714   return variable_watcher_set.get();
1715 }
1716 
1717 // A linked hash set, where iteration is in insertion order.
1718 //
1719 // Nested accumulators rely on op recording happening in insertion order, so an
1720 // unordered data structure like CompactPointerSet is not suitable. Outer
1721 // accumulators need to observe operations first so they know to watch the inner
1722 // accumulator's jvp computation (see the usage sketch below the class).
1723 //
1724 // Not thread safe.
1725 class AccumulatorSet {
1726  public:
1727   // Returns true if `element` was newly inserted, false if it already exists.
1728   bool insert(TFE_Py_ForwardAccumulator* element) {
1729     if (map_.find(element) != map_.end()) {
1730       return false;
1731     }
1732     ListType::iterator it = ordered_.insert(ordered_.end(), element);
1733     map_.insert(std::make_pair(element, it));
1734     return true;
1735   }
1736 
1737   void erase(TFE_Py_ForwardAccumulator* element) {
1738     MapType::iterator existing = map_.find(element);
1739     if (existing == map_.end()) {
1740       return;
1741     }
1742     ListType::iterator list_position = existing->second;
1743     map_.erase(existing);
1744     ordered_.erase(list_position);
1745   }
1746 
1747   bool empty() const { return ordered_.empty(); }
1748 
1749   size_t size() const { return ordered_.size(); }
1750 
1751  private:
1752   typedef std::list<TFE_Py_ForwardAccumulator*> ListType;
1753   typedef tensorflow::gtl::FlatMap<TFE_Py_ForwardAccumulator*,
1754                                    ListType::iterator>
1755       MapType;
1756 
1757  public:
1758   typedef ListType::const_iterator const_iterator;
1759   typedef ListType::const_reverse_iterator const_reverse_iterator;
1760 
1761   const_iterator begin() const { return ordered_.begin(); }
1762   const_iterator end() const { return ordered_.end(); }
1763 
1764   const_reverse_iterator rbegin() const { return ordered_.rbegin(); }
1765   const_reverse_iterator rend() const { return ordered_.rend(); }
1766 
1767  private:
1768   MapType map_;
1769   ListType ordered_;
1770 };
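// Usage sketch for the ordering guarantee above (`outer` and `inner` are
// hypothetical TFE_Py_ForwardAccumulator pointers):
//
//   AccumulatorSet accumulators;
//   accumulators.insert(outer);  // registered first
//   accumulators.insert(inner);
//   for (auto* acc : accumulators) { /* visits outer, then inner */ }
//   for (auto it = accumulators.rbegin(); it != accumulators.rend(); ++it) {
//     /* visits inner, then outer */
//   }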
1771 
1772 AccumulatorSet* GetAccumulatorSet() {
1773   thread_local std::unique_ptr<AccumulatorSet> accumulator_set{nullptr};
1774   if (accumulator_set == nullptr) {
1775     accumulator_set.reset(new AccumulatorSet);
1776   }
1777   return accumulator_set.get();
1778 }
1779 
1780 inline bool HasAccumulator() { return !GetAccumulatorSet()->empty(); }
1781 
1782 inline bool HasGradientTape() { return !GetTapeSet()->empty(); }
1783 
1784 inline bool HasAccumulatorOrTape() {
1785   return HasGradientTape() || HasAccumulator();
1786 }
1787 
1788 // A safe copy of a set, used for tapes and accumulators. The copy is not
1789 // affected by other python threads changing the set of active tapes.
1790 template <typename ContainerType>
1791 class SafeSetCopy {
1792  public:
1793   explicit SafeSetCopy(const ContainerType& to_copy) : set_copy_(to_copy) {
1794     for (auto* member : set_copy_) {
1795       Py_INCREF(member);
1796     }
1797   }
1798 
1799   ~SafeSetCopy() {
1800     for (auto* member : set_copy_) {
1801       Py_DECREF(member);
1802     }
1803   }
1804 
1805   typename ContainerType::const_iterator begin() const {
1806     return set_copy_.begin();
1807   }
1808 
1809   typename ContainerType::const_iterator end() const { return set_copy_.end(); }
1810 
1811   bool empty() const { return set_copy_.empty(); }
1812   size_t size() const { return set_copy_.size(); }
1813 
1814  protected:
1815   ContainerType set_copy_;
1816 };
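// Why copy: iterating the live set directly would be unsafe if another Python
// thread added or removed tapes mid-iteration. The copy pins each member with
// Py_INCREF for its lifetime, so the intended pattern is simply:
//
//   for (TFE_Py_Tape* tape : SafeTapeSet()) {  // snapshot of the live set
//     tape->tape->VariableAccessed(variable);  // safe even if the live set
//   }                                          // changes underneath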
1817 
1818 class SafeTapeSet
1819     : public SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>> {
1820  public:
1821   SafeTapeSet()
1822       : SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>(
1823             *GetTapeSet()) {}
1824 };
1825 
1826 class SafeAccumulatorSet : public SafeSetCopy<AccumulatorSet> {
1827  public:
1828   SafeAccumulatorSet() : SafeSetCopy<AccumulatorSet>(*GetAccumulatorSet()) {}
1829 
1830   typename AccumulatorSet::const_reverse_iterator rbegin() const {
1831     return set_copy_.rbegin();
1832   }
1833 
1834   typename AccumulatorSet::const_reverse_iterator rend() const {
1835     return set_copy_.rend();
1836   }
1837 };
1838 
1839 class SafeVariableWatcherSet
1840     : public SafeSetCopy<
1841           tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> {
1842  public:
1843   SafeVariableWatcherSet()
1844       : SafeSetCopy<
1845             tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>(
1846             *GetVariableWatcherSet()) {}
1847 };
1848 
1849 bool* ThreadTapeIsStopped() {
1850   thread_local bool thread_tape_is_stopped{false};
1851   return &thread_tape_is_stopped;
1852 }
1853 
1854 void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }
1855 
1856 void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }
1857 
1858 PyObject* TFE_Py_TapeSetIsStopped() {
1859   if (*ThreadTapeIsStopped()) {
1860     Py_RETURN_TRUE;
1861   }
1862   Py_RETURN_FALSE;
1863 }
1864 
1865 PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
1866                             PyObject* watch_accessed_variables) {
1867   TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
1868   if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
1869   TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
1870   tape->tape = new GradientTape(persistent == Py_True,
1871                                 watch_accessed_variables == Py_True);
1872   Py_INCREF(tape);
1873   tape->nesting_id = tape_nesting_id_counter.fetch_add(1);
1874   GetTapeSet()->insert(tape);
1875   return reinterpret_cast<PyObject*>(tape);
1876 }
1877 
1878 void TFE_Py_TapeSetAdd(PyObject* tape) {
1879   Py_INCREF(tape);
1880   TFE_Py_Tape* tfe_tape = reinterpret_cast<TFE_Py_Tape*>(tape);
1881   if (!GetTapeSet()->insert(tfe_tape).second) {
1882     // Already exists in the tape set.
1883     Py_DECREF(tape);
1884   } else {
1885     tfe_tape->nesting_id = tape_nesting_id_counter.fetch_add(1);
1886   }
1887 }
1888 
1889 PyObject* TFE_Py_TapeSetIsEmpty() {
1890   if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
1891     Py_RETURN_TRUE;
1892   }
1893   Py_RETURN_FALSE;
1894 }
1895 
1896 void TFE_Py_TapeSetRemove(PyObject* tape) {
1897   auto* stack = GetTapeSet();
1898   stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
1899   // We kept a reference to the tape in the set to ensure it wouldn't get
1900   // deleted under us; cleaning it up here.
1901   Py_DECREF(tape);
1902 }
1903 
1904 static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
1905   if (list == Py_None) {
1906     return {};
1907   }
1908   PyObject* seq = PySequence_Fast(list, "expected a sequence");
1909   if (seq == nullptr) {
1910     return {};
1911   }
1912   int len = PySequence_Fast_GET_SIZE(seq);
1913   PyObject** seq_array = PySequence_Fast_ITEMS(seq);
1914   std::vector<tensorflow::int64> tensor_ids;
1915   tensor_ids.reserve(len);
1916   for (int i = 0; i < len; ++i) {
1917     PyObject* item = seq_array[i];
1918 #if PY_MAJOR_VERSION >= 3
1919     if (PyLong_Check(item)) {
1920 #else
1921     if (PyLong_Check(item) || PyInt_Check(item)) {
1922 #endif
1923       tensorflow::int64 id = MakeInt(item);
1924       tensor_ids.push_back(id);
1925     } else {
1926       tensor_ids.push_back(-1);
1927     }
1928   }
1929   Py_DECREF(seq);
1930   return tensor_ids;
1931 }
1932 
1933 // Fill `tensor_ids` and `dtypes` from `tensors`, none of which may be
1934 // null. Returns true on success and false on a Python exception.
1935 bool TensorShapesAndDtypes(PyObject* tensors,
1936                            std::vector<tensorflow::int64>* tensor_ids,
1937                            std::vector<tensorflow::DataType>* dtypes) {
1938   tensorflow::Safe_PyObjectPtr seq(
1939       PySequence_Fast(tensors, "expected a sequence"));
1940   if (seq == nullptr) {
1941     return false;
1942   }
1943   int len = PySequence_Fast_GET_SIZE(seq.get());
1944   PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
1945   tensor_ids->reserve(len);
1946   dtypes->reserve(len);
1947   for (int i = 0; i < len; ++i) {
1948     PyObject* item = seq_array[i];
1949     tensor_ids->push_back(FastTensorId(item));
1950     dtypes->push_back(tensorflow::PyTensor_DataType(item));
1951   }
1952   return true;
1953 }
1954 
1955 bool TapeCouldPossiblyRecord(PyObject* tensors) {
1956   if (tensors == Py_None) {
1957     return false;
1958   }
1959   if (*ThreadTapeIsStopped()) {
1960     return false;
1961   }
1962   if (!HasAccumulatorOrTape()) {
1963     return false;
1964   }
1965   return true;
1966 }
1967 
1968 bool CouldBackprop() { return !*ThreadTapeIsStopped() && HasGradientTape(); }
1969 
1970 bool CouldForwardprop() { return !*ThreadTapeIsStopped() && HasAccumulator(); }
1971 
1972 PyObject* TFE_Py_TapeSetShouldRecordBackprop(PyObject* tensors) {
1973   if (!TapeCouldPossiblyRecord(tensors) || !CouldBackprop()) {
1974     Py_RETURN_FALSE;
1975   }
1976   // TODO(apassos) consider not building a list and changing the API to check
1977   // each tensor individually.
1978   std::vector<tensorflow::int64> tensor_ids;
1979   std::vector<tensorflow::DataType> dtypes;
1980   if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
1981     return nullptr;
1982   }
1983   auto tape_set = *GetTapeSet();
1984   for (TFE_Py_Tape* tape : tape_set) {
1985     if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
1986       Py_RETURN_TRUE;
1987     }
1988   }
1989 
1990   Py_RETURN_FALSE;
1991 }
1992 
1993 PyObject* TFE_Py_ForwardAccumulatorPushState() {
1994   auto forward_accumulators = *GetAccumulatorSet();
1995   for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
1996     accumulator->accumulator->PushState();
1997   }
1998   Py_RETURN_NONE;
1999 }
2000 
2001 PyObject* TFE_Py_ForwardAccumulatorPopState() {
2002   auto forward_accumulators = *GetAccumulatorSet();
2003   for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
2004     accumulator->accumulator->PopState();
2005   }
2006   Py_RETURN_NONE;
2007 }
2008 
2009 PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) {
2010   if (!TapeCouldPossiblyRecord(tensors)) {
2011     return GetPythonObjectFromInt(0);
2012   }
2013   std::vector<tensorflow::int64> tensor_ids;
2014   std::vector<tensorflow::DataType> dtypes;
2015   if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
2016     return nullptr;
2017   }
2018 
2019   // If there is a persistent tape watching, or if there are multiple tapes
2020   // watching, we'll return immediately indicating that higher-order tape
2021   // gradients are possible.
2022   bool some_tape_watching = false;
2023   if (CouldBackprop()) {
2024     auto tape_set = *GetTapeSet();
2025     for (TFE_Py_Tape* tape : tape_set) {
2026       if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
2027         if (tape->tape->IsPersistent() || some_tape_watching) {
2028           // Either this is the second tape watching, or this tape is
2029           // persistent: higher-order gradients are possible.
2030           return GetPythonObjectFromInt(2);
2031         }
2032         some_tape_watching = true;
2033       }
2034     }
2035   }
2036   if (CouldForwardprop()) {
2037     auto forward_accumulators = *GetAccumulatorSet();
2038     for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
2039       if (accumulator->accumulator->ShouldRecord(tensor_ids, dtypes)) {
2040         if (some_tape_watching) {
2041           // This is the second tape watching: higher-order gradients are
2042           // possible. Note that there's no equivalent of persistence for
2043           // forward-mode.
2044           return GetPythonObjectFromInt(2);
2045         }
2046         some_tape_watching = true;
2047       }
2048     }
2049   }
2050   if (some_tape_watching) {
2051     // There's exactly one non-persistent tape or accumulator watching. The user
2052     // can request first-order gradients but won't get higher-order tape gradients.
2053     return GetPythonObjectFromInt(1);
2054   } else {
2055     // There are no tapes. The user can't request tape gradients.
2056     return GetPythonObjectFromInt(0);
2057   }
2058 }
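// Summary of the return protocol above: 0 means nothing would record a
// gradient for `tensors`; 1 means exactly one non-persistent recorder would,
// so only first-order gradients are available; 2 means a persistent tape or
// multiple recorders would, so higher-order gradients are possible.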
2059 
2060 void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
2061   if (!CouldBackprop()) {
2062     return;
2063   }
2064   tensorflow::int64 tensor_id = FastTensorId(tensor);
2065   if (PyErr_Occurred()) {
2066     return;
2067   }
2068   reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
2069 }
2070 
2071 bool ListContainsNone(PyObject* list) {
2072   if (list == Py_None) return true;
2073   tensorflow::Safe_PyObjectPtr seq(
2074       PySequence_Fast(list, "expected a sequence"));
2075   if (seq == nullptr) {
2076     return false;
2077   }
2078 
2079   int len = PySequence_Fast_GET_SIZE(seq.get());
2080   PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
2081   for (int i = 0; i < len; ++i) {
2082     PyObject* item = seq_array[i];
2083     if (item == Py_None) return true;
2084   }
2085 
2086   return false;
2087 }
2088 
2089 static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
2090   if (EagerTensor_CheckExact(tensor)) {
2091     tensorflow::ImmediateExecutionTensorHandle* handle =
2092         tensorflow::unwrap(EagerTensor_Handle(tensor));
2093     tensorflow::int64 id = PyEagerTensor_ID(tensor);
2094     tensorflow::DataType dtype =
2095         static_cast<tensorflow::DataType>(handle->DataType());
2096     if (dtype == tensorflow::DT_VARIANT) {
2097       return PyTapeTensor(id, dtype, tensor);
2098     }
2099 
2100     tensorflow::TensorShape tensor_shape;
2101     int num_dims;
2102     tensorflow::Status status = handle->NumDims(&num_dims);
2103     if (status.ok()) {
2104       for (int i = 0; i < num_dims; ++i) {
2105         tensorflow::int64 dim_size;
2106         status = handle->Dim(i, &dim_size);
2107         if (!status.ok()) break;
2108         tensor_shape.AddDim(dim_size);
2109       }
2110     }
2111 
2112     if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
2113       return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2114                           tensorflow::TensorShape({}));
2115     } else {
2116       return PyTapeTensor(id, dtype, tensor_shape);
2117     }
2118   }
2119   tensorflow::int64 id = FastTensorId(tensor);
2120   if (PyErr_Occurred()) {
2121     return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2122                         tensorflow::TensorShape({}));
2123   }
2124   PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
2125   PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
2126   Py_DECREF(dtype_object);
2127   tensorflow::DataType dtype =
2128       static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
2129   Py_DECREF(dtype_enum);
2130   if (PyErr_Occurred()) {
2131     return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2132                         tensorflow::TensorShape({}));
2133   }
2134   static char _shape_tuple[] = "_shape_tuple";
2135   tensorflow::Safe_PyObjectPtr shape_tuple(
2136       PyObject_CallMethod(tensor, _shape_tuple, nullptr));
2137   if (PyErr_Occurred()) {
2138     return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2139                         tensorflow::TensorShape({}));
2140   }
2141 
2142   if (ListContainsNone(shape_tuple.get()) || dtype == tensorflow::DT_VARIANT) {
2143     return PyTapeTensor(id, dtype, tensor);
2144   }
2145 
2146   auto l = MakeIntList(shape_tuple.get());
2147   // Replace any -1s (accidental Nones, which can occur in graph mode and can
2148   // cause errors in shape construction) with 0s.
2149   for (auto& c : l) {
2150     if (c < 0) {
2151       c = 0;
2152     }
2153   }
2154   tensorflow::TensorShape shape(l);
2155   return PyTapeTensor(id, dtype, shape);
2156 }
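// Note on the two PyTapeTensor representations built above: when a static
// shape is recoverable, the tape tensor stores a concrete TensorShape (index 0
// of shape_); for DT_VARIANT tensors, or graph tensors whose shape tuple
// contains None, it stores the tensor object itself (index 1), and GetShape /
// OnesLike / ZerosLike above fall back to the Python vspace helpers.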
2157 
2158 // Populates output_info from output_seq, which must come from PySequence_Fast.
2159 //
2160 // Does not take ownership of output_seq. Returns true on success and false if a
2161 // Python exception has been set.
2162 bool TapeTensorsFromTensorSequence(PyObject* output_seq,
2163                                    std::vector<PyTapeTensor>* output_info) {
2164   Py_ssize_t output_len = PySequence_Fast_GET_SIZE(output_seq);
2165   PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq);
2166   output_info->reserve(output_len);
2167   for (Py_ssize_t i = 0; i < output_len; ++i) {
2168     output_info->push_back(TapeTensorFromTensor(output_seq_array[i]));
2169     if (PyErr_Occurred() != nullptr) {
2170       return false;
2171     }
2172   }
2173   return true;
2174 }
2175 
2176 std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
2177   PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
2178   if (seq == nullptr) {
2179     return {};
2180   }
2181   int len = PySequence_Fast_GET_SIZE(seq);
2182   PyObject** seq_array = PySequence_Fast_ITEMS(seq);
2183   std::vector<tensorflow::int64> list;
2184   list.reserve(len);
2185   for (int i = 0; i < len; ++i) {
2186     PyObject* tensor = seq_array[i];
2187     list.push_back(FastTensorId(tensor));
2188     if (PyErr_Occurred()) {
2189       Py_DECREF(seq);
2190       return list;
2191     }
2192   }
2193   Py_DECREF(seq);
2194   return list;
2195 }
2196 
2197 void TFE_Py_TapeVariableAccessed(PyObject* variable) {
2198   if (!CouldBackprop()) {
2199     return;
2200   }
2201   for (TFE_Py_Tape* tape : SafeTapeSet()) {
2202     tape->tape->VariableAccessed(variable);
2203   }
2204 }
2205 
2206 void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) {
2207   if (!CouldBackprop()) {
2208     return;
2209   }
2210   reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable);
2211 }
2212 
2213 PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
2214   return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
2215 }
2216 
2217 PyObject* TFE_Py_VariableWatcherNew() {
2218   TFE_Py_VariableWatcher_Type.tp_new = PyType_GenericNew;
2219   if (PyType_Ready(&TFE_Py_VariableWatcher_Type) < 0) return nullptr;
2220   TFE_Py_VariableWatcher* variable_watcher =
2221       PyObject_NEW(TFE_Py_VariableWatcher, &TFE_Py_VariableWatcher_Type);
2222   variable_watcher->variable_watcher = new VariableWatcher();
2223   Py_INCREF(variable_watcher);
2224   GetVariableWatcherSet()->insert(variable_watcher);
2225   return reinterpret_cast<PyObject*>(variable_watcher);
2226 }
2227 
2228 void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher) {
2229   auto* stack = GetVariableWatcherSet();
2230   stack->erase(reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher));
2231   // We kept a reference to the variable watcher in the set to ensure it
2232   // wouldn't get deleted under us; cleaning it up here.
2233   Py_DECREF(variable_watcher);
2234 }
2235 
2236 void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable) {
2237   for (TFE_Py_VariableWatcher* variable_watcher : SafeVariableWatcherSet()) {
2238     variable_watcher->variable_watcher->WatchVariable(variable);
2239   }
2240 }
2241 
2242 PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher) {
2243   return reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
2244       ->variable_watcher->GetVariablesAsPyTuple();
2245 }
2246 
2247 namespace {
2248 std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
2249   PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
2250   if (seq == nullptr) {
2251     return {};
2252   }
2253   int len = PySequence_Fast_GET_SIZE(seq);
2254   PyObject** seq_array = PySequence_Fast_ITEMS(seq);
2255   std::vector<tensorflow::DataType> list;
2256   list.reserve(len);
2257   for (int i = 0; i < len; ++i) {
2258     PyObject* tensor = seq_array[i];
2259     list.push_back(tensorflow::PyTensor_DataType(tensor));
2260   }
2261   Py_DECREF(seq);
2262   return list;
2263 }
2264 
2265 PyObject* ForwardAccumulatorDeleteGradient(PyObject* tensor_id,
2266                                            PyObject* weak_tensor_ref) {
2267   tensorflow::int64 parsed_tensor_id = MakeInt(tensor_id);
2268   for (TFE_Py_ForwardAccumulator* accumulator : *GetAccumulatorSet()) {
2269     accumulator->accumulator->DeleteGradient(parsed_tensor_id);
2270   }
2271   Py_DECREF(weak_tensor_ref);
2272   Py_DECREF(tensor_id);
2273   Py_INCREF(Py_None);
2274   return Py_None;
2275 }
2276 
2277 static PyMethodDef forward_accumulator_delete_gradient_method_def = {
2278     "ForwardAccumulatorDeleteGradient", ForwardAccumulatorDeleteGradient,
2279     METH_O, "ForwardAccumulatorDeleteGradient"};
2280 
2281 void RegisterForwardAccumulatorCleanup(PyObject* tensor,
2282                                        tensorflow::int64 tensor_id) {
2283   tensorflow::Safe_PyObjectPtr callback(
2284       PyCFunction_New(&forward_accumulator_delete_gradient_method_def,
2285                       PyLong_FromLong(tensor_id)));
2286   // We need to keep a reference to the weakref active if we want our callback
2287   // to be called. The callback itself now owns the weakref object and the
2288   // tensor ID object.
2289   PyWeakref_NewRef(tensor, callback.get());
2290 }
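// Lifecycle sketch: once `tensor` is garbage collected, the weakref created
// above invokes ForwardAccumulatorDeleteGradient, which drops any jvp stored
// under that tensor id in every active accumulator and then releases both the
// weakref and the tensor-id object that the callback owns.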
2291 
2292 void TapeSetRecordBackprop(
2293     const string& op_type, const std::vector<PyTapeTensor>& output_info,
2294     const std::vector<tensorflow::int64>& input_ids,
2295     const std::vector<tensorflow::DataType>& input_dtypes,
2296     const std::function<PyBackwardFunction*()>& backward_function_getter,
2297     const std::function<void(PyBackwardFunction*)>& backward_function_killer,
2298     tensorflow::uint64 max_gradient_tape_id) {
2299   if (!CouldBackprop()) {
2300     return;
2301   }
2302   for (TFE_Py_Tape* tape : SafeTapeSet()) {
2303     if (tape->nesting_id < max_gradient_tape_id) {
2304       tape->tape->RecordOperation(op_type, output_info, input_ids, input_dtypes,
2305                                   backward_function_getter,
2306                                   backward_function_killer);
2307     }
2308   }
2309 }
2310 
2311 bool TapeSetRecordForwardprop(
2312     const string& op_type, PyObject* output_seq,
2313     const std::vector<PyTapeTensor>& output_info, PyObject* input_tensors,
2314     const std::vector<tensorflow::int64>& input_ids,
2315     const std::vector<tensorflow::DataType>& input_dtypes,
2316     const std::function<PyBackwardFunction*()>& backward_function_getter,
2317     const std::function<void(PyBackwardFunction*)>& backward_function_killer,
2318     const tensorflow::eager::ForwardFunction<PyObject>* forward_function,
2319     PyObject* forwardprop_output_indices,
2320     tensorflow::uint64* max_gradient_tape_id) {
2321   *max_gradient_tape_id = std::numeric_limits<tensorflow::uint64>::max();
2322   if (!CouldForwardprop()) {
2323     return true;
2324   }
2325   auto accumulator_set = SafeAccumulatorSet();
2326   tensorflow::Safe_PyObjectPtr input_seq(
2327       PySequence_Fast(input_tensors, "expected a sequence of tensors"));
2328   if (input_seq == nullptr || PyErr_Occurred()) return false;
2329   Py_ssize_t input_len = PySequence_Fast_GET_SIZE(input_seq.get());
2330   PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq);
2331   for (int i = 0; i < output_info.size(); ++i) {
2332     RegisterForwardAccumulatorCleanup(output_seq_array[i],
2333                                       output_info[i].GetID());
2334   }
2335   if (forwardprop_output_indices != nullptr &&
2336       forwardprop_output_indices != Py_None) {
2337     tensorflow::Safe_PyObjectPtr indices_fast(PySequence_Fast(
2338         forwardprop_output_indices, "Expected a sequence of indices"));
2339     if (indices_fast == nullptr || PyErr_Occurred()) {
2340       return false;
2341     }
2342     if (PySequence_Fast_GET_SIZE(indices_fast.get()) !=
2343         accumulator_set.size()) {
2344       MaybeRaiseExceptionFromStatus(
2345           tensorflow::errors::Internal(
2346               "Accumulators were added or removed from the active set "
2347               "between packing and unpacking."),
2348           nullptr);
      return false;
2349     }
2350     PyObject** indices_fast_array = PySequence_Fast_ITEMS(indices_fast.get());
2351     Py_ssize_t accumulator_index = 0;
2352     for (AccumulatorSet::const_reverse_iterator it = accumulator_set.rbegin();
2353          it != accumulator_set.rend(); ++it, ++accumulator_index) {
2354       tensorflow::Safe_PyObjectPtr jvp_index_seq(
2355           PySequence_Fast(indices_fast_array[accumulator_index],
2356                           "Expected a sequence of jvp indices."));
2357       if (jvp_index_seq == nullptr || PyErr_Occurred()) {
2358         return false;
2359       }
2360       Py_ssize_t num_jvps = PySequence_Fast_GET_SIZE(jvp_index_seq.get());
2361       PyObject** jvp_index_seq_array =
2362           PySequence_Fast_ITEMS(jvp_index_seq.get());
2363       for (Py_ssize_t jvp_index = 0; jvp_index < num_jvps; ++jvp_index) {
2364         PyObject* tuple = jvp_index_seq_array[jvp_index];
2365         tensorflow::int64 primal_tensor_id =
2366             output_info[MakeInt(PyTuple_GetItem(tuple, 0))].GetID();
2367         (*it)->accumulator->Watch(
2368             primal_tensor_id,
2369             output_seq_array[MakeInt(PyTuple_GetItem(tuple, 1))]);
2370       }
2371     }
2372   } else {
2373     std::vector<PyTapeTensor> input_info;
2374     input_info.reserve(input_len);
2375     PyObject** input_seq_array = PySequence_Fast_ITEMS(input_seq.get());
2376     for (Py_ssize_t i = 0; i < input_len; ++i) {
2377       input_info.push_back(TapeTensorFromTensor(input_seq_array[i]));
2378     }
2379     for (TFE_Py_ForwardAccumulator* accumulator : accumulator_set) {
2380       tensorflow::Status status = accumulator->accumulator->Accumulate(
2381           op_type, input_info, output_info, input_ids, input_dtypes,
2382           forward_function, backward_function_getter, backward_function_killer);
2383       if (PyErr_Occurred()) return false;  // Don't swallow Python exceptions.
2384       if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
2385         return false;
2386       }
2387       if (accumulator->accumulator->BusyAccumulating()) {
2388         // Ensure inner accumulators don't see outer accumulators' jvps. This
2389         // mostly happens on its own, with some potentially surprising
2390         // exceptions, so the blanket policy is for consistency.
2391         *max_gradient_tape_id = accumulator->nesting_id;
2392         break;
2393       }
2394     }
2395   }
2396   return true;
2397 }
2398 
2399 PyObject* TangentsAsPyTuple(const std::vector<PyObject*>& input_tangents) {
2400   PyObject* py_input_tangents = PyTuple_New(input_tangents.size());
2401   for (int i = 0; i < input_tangents.size(); ++i) {
2402     PyObject* element;
2403     if (input_tangents[i] == nullptr) {
2404       element = Py_None;
2405     } else {
2406       element = input_tangents[i];
2407     }
2408     Py_INCREF(element);
2409     PyTuple_SET_ITEM(py_input_tangents, i, element);
2410   }
2411   return py_input_tangents;
2412 }
2413 
2414 tensorflow::Status ParseTangentOutputs(
2415     PyObject* user_output, std::vector<PyObject*>* output_tangents) {
2416   if (user_output == Py_None) {
2417     // No connected gradients.
2418     return tensorflow::Status::OK();
2419   }
2420   tensorflow::Safe_PyObjectPtr fast_result(
2421       PySequence_Fast(user_output, "expected a sequence of forward gradients"));
2422   if (fast_result == nullptr) {
2423     return tensorflow::errors::InvalidArgument(
2424         "forward gradient function did not return a sequence.");
2425   }
2426   int len = PySequence_Fast_GET_SIZE(fast_result.get());
2427   PyObject** fast_result_array = PySequence_Fast_ITEMS(fast_result.get());
2428   output_tangents->reserve(len);
2429   for (int i = 0; i < len; ++i) {
2430     PyObject* item = fast_result_array[i];
2431     if (item == Py_None) {
2432       output_tangents->push_back(nullptr);
2433     } else {
2434       Py_INCREF(item);
2435       output_tangents->push_back(item);
2436     }
2437   }
2438   return tensorflow::Status::OK();
2439 }
2440 
2441 // Calls the registered forward_gradient_function, computing `output_tangents`
2442 // from `input_tangents`. `output_tangents` must not be null.
2443 //
2444 // `op_name`, `attrs`, `inputs`, and `results` describe the operation for which
2445 // the forward function is being called.
2446 tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs,
2447                                    PyObject* inputs, PyObject* results,
2448                                    const std::vector<PyObject*>& input_tangents,
2449                                    std::vector<PyObject*>* output_tangents,
2450                                    bool use_batch) {
2451   if (forward_gradient_function == nullptr) {
2452     return tensorflow::errors::Internal(
2453         "No forward gradient function registered.");
2454   }
2455   tensorflow::Safe_PyObjectPtr py_input_tangents(
2456       TangentsAsPyTuple(input_tangents));
2457 
2458   // Normalize the input sequence to a tuple so it works with function
2459   // caching; otherwise it may be an opaque _InputList object.
2460   tensorflow::Safe_PyObjectPtr input_tuple(PySequence_Tuple(inputs));
2461   PyObject* to_batch = (use_batch) ? Py_True : Py_False;
2462   tensorflow::Safe_PyObjectPtr callback_args(
2463       Py_BuildValue("OOOOOO", op_name, attrs, input_tuple.get(), results,
2464                     py_input_tangents.get(), to_batch));
2465   tensorflow::Safe_PyObjectPtr py_result(
2466       PyObject_CallObject(forward_gradient_function, callback_args.get()));
2467   if (py_result == nullptr || PyErr_Occurred()) {
2468     return tensorflow::errors::Internal(
2469         "forward gradient function threw exceptions");
2470   }
2471   return ParseTangentOutputs(py_result.get(), output_tangents);
2472 }
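// Judging from the argument tuple built above, the registered Python callback
// is expected to look roughly like this (names hypothetical):
//
//   def forward_gradient_function(op_name, attrs, inputs, outputs,
//                                 tangents, use_batch):
//     ...
//     return output_tangents  # a sequence, or None for no connected gradients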
2473 
2474 // Like CallJVPFunction, but calls a pre-bound forward function.
2475 // These are passed in from a record_gradient argument.
2476 tensorflow::Status CallOpSpecificJVPFunction(
2477     PyObject* op_specific_forward_function,
2478     const std::vector<PyObject*>& input_tangents,
2479     std::vector<PyObject*>* output_tangents) {
2480   tensorflow::Safe_PyObjectPtr py_input_tangents(
2481       TangentsAsPyTuple(input_tangents));
2482 
2483   tensorflow::Safe_PyObjectPtr py_result(PyObject_CallObject(
2484       op_specific_forward_function, py_input_tangents.get()));
2485   if (py_result == nullptr || PyErr_Occurred()) {
2486     return tensorflow::errors::Internal(
2487         "forward gradient function threw exceptions");
2488   }
2489   return ParseTangentOutputs(py_result.get(), output_tangents);
2490 }
2491 
2492 bool ParseOpTypeString(PyObject* op_type, string* op_type_string) {
2493   if (PyBytes_Check(op_type)) {
2494     *op_type_string = PyBytes_AsString(op_type);
2495   } else if (PyUnicode_Check(op_type)) {
2496 #if PY_MAJOR_VERSION >= 3
2497     *op_type_string = PyUnicode_AsUTF8(op_type);
2498 #else
2499     PyObject* py_str = PyUnicode_AsUTF8String(op_type);
2500     if (py_str == nullptr) {
2501       return false;
2502     }
2503     *op_type_string = PyBytes_AS_STRING(py_str);
2504     Py_DECREF(py_str);
2505 #endif
2506   } else {
2507     PyErr_SetString(PyExc_RuntimeError, "op_type should be a string.");
2508     return false;
2509   }
2510   return true;
2511 }
2512 
2513 bool TapeSetRecordOperation(
2514     PyObject* op_type, PyObject* input_tensors, PyObject* output_tensors,
2515     const std::vector<tensorflow::int64>& input_ids,
2516     const std::vector<tensorflow::DataType>& input_dtypes,
2517     const std::function<PyBackwardFunction*()>& backward_function_getter,
2518     const std::function<void(PyBackwardFunction*)>& backward_function_killer,
2519     const tensorflow::eager::ForwardFunction<PyObject>* forward_function) {
2520   std::vector<PyTapeTensor> output_info;
2521   tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2522       output_tensors, "expected a sequence of output tensors"));
2523   if (PyErr_Occurred() ||
2524       !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2525     return false;
2526   }
2527   string op_type_str;
2528   if (!ParseOpTypeString(op_type, &op_type_str)) {
2529     return false;
2530   }
2531   tensorflow::uint64 max_gradient_tape_id;
2532   if (!TapeSetRecordForwardprop(
2533           op_type_str, output_seq.get(), output_info, input_tensors, input_ids,
2534           input_dtypes, backward_function_getter, backward_function_killer,
2535           forward_function, nullptr /* No special-cased jvps. */,
2536           &max_gradient_tape_id)) {
2537     return false;
2538   }
2539   TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes,
2540                         backward_function_getter, backward_function_killer,
2541                         max_gradient_tape_id);
2542   return true;
2543 }
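// Recording pipeline summary: TapeSetRecordOperation first runs forwardprop,
// which may lower max_gradient_tape_id to a busy accumulator's nesting id,
// then backprop, where TapeSetRecordBackprop only records onto tapes whose
// nesting_id is below that cap. Tapes entered inside an outer accumulator
// therefore never observe that accumulator's jvp computations.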
2544 }  // namespace
2545 
2546 PyObject* TFE_Py_TapeSetRecordOperation(PyObject* op_type,
2547                                         PyObject* output_tensors,
2548                                         PyObject* input_tensors,
2549                                         PyObject* backward_function,
2550                                         PyObject* forward_function) {
2551   if (!HasAccumulatorOrTape() || *ThreadTapeIsStopped()) {
2552     Py_RETURN_NONE;
2553   }
2554   std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
2555   if (PyErr_Occurred()) return nullptr;
2556 
2557   std::vector<tensorflow::DataType> input_dtypes =
2558       MakeTensorDtypeList(input_tensors);
2559   if (PyErr_Occurred()) return nullptr;
2560 
2561   std::function<PyBackwardFunction*()> backward_function_getter(
2562       [backward_function]() {
2563         Py_INCREF(backward_function);
2564         PyBackwardFunction* function = new PyBackwardFunction(
2565             [backward_function](PyObject* out_grads,
2566                                 const std::vector<tensorflow::int64>& unused) {
2567               return PyObject_CallObject(backward_function, out_grads);
2568             });
2569         return function;
2570       });
2571   std::function<void(PyBackwardFunction*)> backward_function_killer(
2572       [backward_function](PyBackwardFunction* py_backward_function) {
2573         Py_DECREF(backward_function);
2574         delete py_backward_function;
2575       });
2576 
2577   if (forward_function == Py_None) {
2578     if (!TapeSetRecordOperation(
2579             op_type, input_tensors, output_tensors, input_ids, input_dtypes,
2580             backward_function_getter, backward_function_killer,
2581             nullptr /* No special-cased forward function */)) {
2582       return nullptr;
2583     }
2584   } else {
2585     tensorflow::eager::ForwardFunction<PyObject> wrapped_forward_function(
2586         [forward_function](const std::vector<PyObject*>& input_tangents,
2587                            std::vector<PyObject*>* output_tangents,
2588                            bool use_batch = false) {
2589           return CallOpSpecificJVPFunction(forward_function, input_tangents,
2590                                            output_tangents);
2591         });
2592     if (!TapeSetRecordOperation(
2593             op_type, input_tensors, output_tensors, input_ids, input_dtypes,
2594             backward_function_getter, backward_function_killer,
2595             &wrapped_forward_function)) {
2596       return nullptr;
2597     }
2598   }
2599   Py_RETURN_NONE;
2600 }
2601 
2602 PyObject* TFE_Py_TapeSetRecordOperationForwardprop(
2603     PyObject* op_type, PyObject* output_tensors, PyObject* input_tensors,
2604     PyObject* backward_function, PyObject* forwardprop_output_indices) {
2605   if (!HasAccumulator() || *ThreadTapeIsStopped()) {
2606     Py_RETURN_NONE;
2607   }
2608   std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
2609   if (PyErr_Occurred()) return nullptr;
2610 
2611   std::vector<tensorflow::DataType> input_dtypes =
2612       MakeTensorDtypeList(input_tensors);
2613   if (PyErr_Occurred()) return nullptr;
2614 
2615   std::function<PyBackwardFunction*()> backward_function_getter(
2616       [backward_function]() {
2617         Py_INCREF(backward_function);
2618         PyBackwardFunction* function = new PyBackwardFunction(
2619             [backward_function](PyObject* out_grads,
2620                                 const std::vector<tensorflow::int64>& unused) {
2621               return PyObject_CallObject(backward_function, out_grads);
2622             });
2623         return function;
2624       });
2625   std::function<void(PyBackwardFunction*)> backward_function_killer(
2626       [backward_function](PyBackwardFunction* py_backward_function) {
2627         Py_DECREF(backward_function);
2628         delete py_backward_function;
2629       });
2630   std::vector<PyTapeTensor> output_info;
2631   tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2632       output_tensors, "expected a sequence of output tensors"));
2633   if (PyErr_Occurred() ||
2634       !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2635     return nullptr;
2636   }
2637   string op_type_str;
2638   if (!ParseOpTypeString(op_type, &op_type_str)) {
2639     return nullptr;
2640   }
2641   tensorflow::uint64 max_gradient_tape_id;
2642   if (!TapeSetRecordForwardprop(
2643           op_type_str, output_seq.get(), output_info, input_tensors, input_ids,
2644           input_dtypes, backward_function_getter, backward_function_killer,
2645           nullptr /* no special-cased forward function */,
2646           forwardprop_output_indices, &max_gradient_tape_id)) {
2647     return nullptr;
2648   }
2649   Py_RETURN_NONE;
2650 }
2651 
2652 PyObject* TFE_Py_TapeSetRecordOperationBackprop(PyObject* op_type,
2653                                                 PyObject* output_tensors,
2654                                                 PyObject* input_tensors,
2655                                                 PyObject* backward_function) {
2656   if (!CouldBackprop()) {
2657     Py_RETURN_NONE;
2658   }
2659   std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
2660   if (PyErr_Occurred()) return nullptr;
2661 
2662   std::vector<tensorflow::DataType> input_dtypes =
2663       MakeTensorDtypeList(input_tensors);
2664   if (PyErr_Occurred()) return nullptr;
2665 
2666   std::function<PyBackwardFunction*()> backward_function_getter(
2667       [backward_function]() {
2668         Py_INCREF(backward_function);
2669         PyBackwardFunction* function = new PyBackwardFunction(
2670             [backward_function](PyObject* out_grads,
2671                                 const std::vector<tensorflow::int64>& unused) {
2672               return PyObject_CallObject(backward_function, out_grads);
2673             });
2674         return function;
2675       });
2676   std::function<void(PyBackwardFunction*)> backward_function_killer(
2677       [backward_function](PyBackwardFunction* py_backward_function) {
2678         Py_DECREF(backward_function);
2679         delete py_backward_function;
2680       });
2681   std::vector<PyTapeTensor> output_info;
2682   tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2683       output_tensors, "expected a sequence of output tensors"));
2684   if (PyErr_Occurred() ||
2685       !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2686     return nullptr;
2687   }
2688   string op_type_str;
2689   if (!ParseOpTypeString(op_type, &op_type_str)) {
2690     return nullptr;
2691   }
2692   TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes,
2693                         backward_function_getter, backward_function_killer,
2694                         // No filtering based on relative ordering with forward
2695                         // accumulators.
2696                         std::numeric_limits<tensorflow::uint64>::max());
2697   Py_RETURN_NONE;
2698 }
2699 
2700 void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
2701   for (TFE_Py_Tape* tape : *GetTapeSet()) {
2702     tape->tape->DeleteTrace(tensor_id);
2703   }
2704 }
2705 
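// Copies the items of `tensors` into a vector of borrowed references. The
// caller is assumed to keep the underlying tensors alive while the returned
// pointers are used, and must check PyErr_Occurred(), since an empty vector
// is also the error return.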
2706 std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
2707   PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
2708   if (seq == nullptr) {
2709     return {};
2710   }
2711   int len = PySequence_Fast_GET_SIZE(seq);
2712   PyObject** seq_array = PySequence_Fast_ITEMS(seq);
2713   std::vector<PyObject*> list(seq_array, seq_array + len);
2714   Py_DECREF(seq);
2715   return list;
2716 }
2717 
2718 PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
2719                               PyObject* sources, PyObject* output_gradients,
2720                               PyObject* sources_raw,
2721                               PyObject* unconnected_gradients,
2722                               TF_Status* status) {
2723   TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
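  // A non-persistent tape can only be queried after it has stopped recording.
  // Illustrative Python that trips the RuntimeError below:
  //   with tf.GradientTape() as t:
  //     y = x * x
  //     g = t.gradient(y, x)  # gradient() called inside the tape context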
2724   if (!tape_obj->tape->IsPersistent()) {
2725     auto* tape_set = GetTapeSet();
2726     if (tape_set->find(tape_obj) != tape_set->end()) {
2727       PyErr_SetString(PyExc_RuntimeError,
2728                       "gradient() cannot be invoked within the "
2729                       "GradientTape context (i.e., while operations are being "
2730                       "recorded). Either move the call to gradient() to be "
2731                       "outside the 'with tf.GradientTape' block, or "
2732                       "use a persistent tape: "
2733                       "'with tf.GradientTape(persistent=True)'");
2734       return nullptr;
2735     }
2736   }
2737 
2738   std::vector<tensorflow::int64> target_vec = MakeTensorIDList(target);
2739   if (PyErr_Occurred()) {
2740     return nullptr;
2741   }
2742   std::vector<tensorflow::int64> sources_vec = MakeTensorIDList(sources);
2743   if (PyErr_Occurred()) {
2744     return nullptr;
2745   }
2746   tensorflow::gtl::FlatSet<tensorflow::int64> sources_set(sources_vec.begin(),
2747                                                           sources_vec.end());
2748 
2749   tensorflow::Safe_PyObjectPtr seq =
2750       tensorflow::make_safe(PySequence_Fast(target, "expected a sequence"));
2751   int len = PySequence_Fast_GET_SIZE(seq.get());
2752   PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
2753   std::unordered_map<tensorflow::int64, PyTapeTensor>
2754       source_tensors_that_are_targets;
2755   for (int i = 0; i < len; ++i) {
2756     tensorflow::int64 target_id = target_vec[i];
2757     if (sources_set.find(target_id) != sources_set.end()) {
2758       auto tensor = seq_array[i];
2759       source_tensors_that_are_targets.insert(
2760           std::make_pair(target_id, TapeTensorFromTensor(tensor)));
2761     }
2762     if (PyErr_Occurred()) {
2763       return nullptr;
2764     }
2765   }
2766   if (PyErr_Occurred()) {
2767     return nullptr;
2768   }
2769 
2770   std::vector<PyObject*> outgrad_vec;
2771   if (output_gradients != Py_None) {
2772     outgrad_vec = MakeTensorList(output_gradients);
2773     if (PyErr_Occurred()) {
2774       return nullptr;
2775     }
2776     for (PyObject* tensor : outgrad_vec) {
2777       // Calling the backward function will eat a reference to the tensors in
2778       // outgrad_vec, so we need to increase their reference count.
2779       Py_INCREF(tensor);
2780     }
2781   }
2782   std::vector<PyObject*> result(sources_vec.size());
2783   status->status = tape_obj->tape->ComputeGradient(
2784       *py_vspace, target_vec, sources_vec, source_tensors_that_are_targets,
2785       outgrad_vec, absl::MakeSpan(result));
2786   if (!status->status.ok()) {
2787     if (PyErr_Occurred()) {
2788       // Do not propagate the erroneous status as that would swallow the
2789       // exception which caused the problem.
2790       status->status = tensorflow::Status::OK();
2791     }
2792     return nullptr;
2793   }
2794 
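  // unconnected_gradients selects what to return for a source the target does
  // not depend on: "none" (the default) yields Py_None, "zero" yields a zeros
  // tensor shaped like the source (see the nullptr handling below).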
2795   bool unconnected_gradients_zero =
2796       strcmp(TFE_GetPythonString(unconnected_gradients), "zero") == 0;
2797   std::vector<PyObject*> sources_obj;
2798   if (unconnected_gradients_zero) {
2799     // Uses the "raw" sources here so it can properly make a zeros tensor even
2800     // if there are resource variables as sources.
2801     sources_obj = MakeTensorList(sources_raw);
2802   }
2803 
2804   if (!result.empty()) {
2805     PyObject* py_result = PyList_New(result.size());
2806     tensorflow::gtl::FlatSet<PyObject*> seen_results(result.size());
2807     for (int i = 0; i < result.size(); ++i) {
2808       if (result[i] == nullptr) {
2809         if (unconnected_gradients_zero) {
2810           // generate a zeros tensor in the shape of sources[i]
2811           tensorflow::DataType dtype =
2812               tensorflow::PyTensor_DataType(sources_obj[i]);
2813           PyTapeTensor tensor =
2814               PyTapeTensor(sources_vec[i], dtype, sources_obj[i]);
2815           result[i] = tensor.ZerosLike();
2816         } else {
2817           Py_INCREF(Py_None);
2818           result[i] = Py_None;
2819         }
2820       } else if (seen_results.find(result[i]) != seen_results.end()) {
2821         Py_INCREF(result[i]);
2822       }
2823       seen_results.insert(result[i]);
2824       PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]));
2825     }
2826     return py_result;
2827   }
2828   return PyList_New(0);
2829 }
2830 
2831 PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch) {
2832   TFE_Py_ForwardAccumulator_Type.tp_new = PyType_GenericNew;
2833   if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr;
2834   TFE_Py_ForwardAccumulator* accumulator =
2835       PyObject_NEW(TFE_Py_ForwardAccumulator, &TFE_Py_ForwardAccumulator_Type);
2836   if (py_vspace == nullptr) {
2837     MaybeRaiseExceptionFromStatus(
2838         tensorflow::errors::Internal(
2839             "ForwardAccumulator requires a PyVSpace to be registered."),
2840         nullptr);
    return nullptr;  // Avoid dereferencing the null py_vspace below.
2841   }
2842   accumulator->accumulator = new ForwardAccumulator(*py_vspace, use_batch);
2843   return reinterpret_cast<PyObject*>(accumulator);
2844 }
2845 
2846 PyObject* TFE_Py_ForwardAccumulatorSetAdd(PyObject* accumulator) {
2847   TFE_Py_ForwardAccumulator* c_accumulator(
2848       reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
2849   c_accumulator->nesting_id = tape_nesting_id_counter.fetch_add(1);
2850   if (GetAccumulatorSet()->insert(c_accumulator)) {
2851     Py_INCREF(accumulator);
2852     Py_RETURN_NONE;
2853   } else {
2854     MaybeRaiseExceptionFromStatus(
2855         tensorflow::errors::Internal(
2856             "A ForwardAccumulator was added to the active set twice."),
2857         nullptr);
2858     return nullptr;
2859   }
2860 }
2861 
2862 void TFE_Py_ForwardAccumulatorSetRemove(PyObject* accumulator) {
2863   GetAccumulatorSet()->erase(
2864       reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
2865   Py_DECREF(accumulator);
2866 }
2867 
2868 void TFE_Py_ForwardAccumulatorWatch(PyObject* accumulator, PyObject* tensor,
2869                                     PyObject* tangent) {
2870   tensorflow::int64 tensor_id = FastTensorId(tensor);
2871   reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
2872       ->accumulator->Watch(tensor_id, tangent);
2873   RegisterForwardAccumulatorCleanup(tensor, tensor_id);
2874 }
2875 
2876 // Returns a new reference to the JVP Tensor.
2877 PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator,
2878                                        PyObject* tensor) {
2879   PyObject* jvp = reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
2880                       ->accumulator->FetchJVP(FastTensorId(tensor));
2881   if (jvp == nullptr) {
2882     jvp = Py_None;
2883   }
2884   Py_INCREF(jvp);
2885   return jvp;
2886 }
2887 
2888 PyObject* TFE_Py_PackJVPs(PyObject* tensors) {
2889   if (!TapeCouldPossiblyRecord(tensors)) {
2890     tensorflow::Safe_PyObjectPtr empty_tuple(PyTuple_New(0));
2891     tensorflow::Safe_PyObjectPtr empty_list(PyList_New(0));
2892     return PyTuple_Pack(2, empty_tuple.get(), empty_list.get());
2893   }
2894   auto accumulators = *GetAccumulatorSet();
2895   tensorflow::Safe_PyObjectPtr tensors_fast(
2896       PySequence_Fast(tensors, "Expected a sequence of input Tensors."));
2897   if (tensors_fast == nullptr || PyErr_Occurred()) {
2898     return nullptr;
2899   }
2900   std::vector<tensorflow::int64> augmented_input_ids;
2901   int len = PySequence_Fast_GET_SIZE(tensors_fast.get());
2902   PyObject** tensors_fast_array = PySequence_Fast_ITEMS(tensors_fast.get());
2903   for (Py_ssize_t position = 0; position < len; ++position) {
2904     PyObject* input = tensors_fast_array[position];
2905     if (input == Py_None) {
2906       continue;
2907     }
2908     tensorflow::DataType input_dtype(tensorflow::PyTensor_DataType(input));
2909     if (input_dtype == tensorflow::DT_INVALID) {
2910       return nullptr;
2911     }
2912     augmented_input_ids.push_back(FastTensorId(input));
2913   }
2914   if (PyErr_Occurred()) {
2915     return nullptr;
2916   }
2917   // Find the innermost accumulator such that all outer accumulators are
2918   // recording. Any more deeply nested accumulators will not have their JVPs
2919   // saved.
2920   AccumulatorSet::const_iterator innermost_all_recording = accumulators.begin();
2921   for (; innermost_all_recording != accumulators.end();
2922        ++innermost_all_recording) {
2923     if ((*innermost_all_recording)->accumulator->BusyAccumulating()) {
2924       break;
2925     }
2926   }
2927   AccumulatorSet::const_reverse_iterator reverse_innermost_all_recording(
2928       innermost_all_recording);
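  // The accumulator set is ordered from outermost to innermost, so the
  // reverse iterator below visits accumulators innermost-first.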
2929 
2930   bool saving_jvps = false;
2931   tensorflow::Safe_PyObjectPtr all_indices(PyTuple_New(accumulators.size()));
2932   std::vector<PyObject*> new_tensors;
2933   Py_ssize_t accumulator_index = 0;
2934   // Start with the innermost accumulators to give outer accumulators a chance
2935   // to find their higher-order JVPs.
2936   for (AccumulatorSet::const_reverse_iterator it = accumulators.rbegin();
2937        it != accumulators.rend(); ++it, ++accumulator_index) {
2938     std::vector<tensorflow::int64> new_input_ids;
2939     std::vector<std::pair<tensorflow::int64, tensorflow::int64>>
2940         accumulator_indices;
2941     if (it == reverse_innermost_all_recording) {
2942       saving_jvps = true;
2943     }
2944     if (saving_jvps) {
2945       for (int input_index = 0; input_index < augmented_input_ids.size();
2946            ++input_index) {
2947         tensorflow::int64 existing_input = augmented_input_ids[input_index];
2948         PyObject* jvp = (*it)->accumulator->FetchJVP(existing_input);
2949         if (jvp != nullptr) {
2950           new_tensors.push_back(jvp);
2951           new_input_ids.push_back(FastTensorId(jvp));
2952           accumulator_indices.emplace_back(
2953               input_index,
2954               augmented_input_ids.size() + new_input_ids.size() - 1);
2955         }
2956       }
2957     }
2958     tensorflow::Safe_PyObjectPtr accumulator_indices_py(
2959         PyTuple_New(accumulator_indices.size()));
2960     for (int i = 0; i < accumulator_indices.size(); ++i) {
2961       tensorflow::Safe_PyObjectPtr from_index(
2962           GetPythonObjectFromInt(accumulator_indices[i].first));
2963       tensorflow::Safe_PyObjectPtr to_index(
2964           GetPythonObjectFromInt(accumulator_indices[i].second));
2965       PyTuple_SetItem(accumulator_indices_py.get(), i,
2966                       PyTuple_Pack(2, from_index.get(), to_index.get()));
2967     }
2968     PyTuple_SetItem(all_indices.get(), accumulator_index,
2969                     accumulator_indices_py.release());
2970     augmented_input_ids.insert(augmented_input_ids.end(), new_input_ids.begin(),
2971                                new_input_ids.end());
2972   }
2973 
2974   tensorflow::Safe_PyObjectPtr new_tensors_py(PyList_New(new_tensors.size()));
2975   for (int i = 0; i < new_tensors.size(); ++i) {
2976     PyObject* jvp = new_tensors[i];
2977     Py_INCREF(jvp);
2978     PyList_SET_ITEM(new_tensors_py.get(), i, jvp);
2979   }
2980   return PyTuple_Pack(2, all_indices.get(), new_tensors_py.get());
2981 }
2982 
2983 namespace {
2984 
2985 // Indices for the "args" tuple that's passed to TFE_Py_FastPathExecute_C.
2986 enum FastPathExecuteArgIndex {
2987   FAST_PATH_EXECUTE_ARG_CONTEXT = 0,
2988   FAST_PATH_EXECUTE_ARG_OP_NAME = 1,
2989   FAST_PATH_EXECUTE_ARG_NAME = 2,
2990   FAST_PATH_EXECUTE_ARG_INPUT_START = 3
2991 };
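// Illustrative call shape (assuming the generated op bindings): a call such
// as TFE_Py_FastPathExecute(ctx, "Add", name, x, y) arrives with the eager
// context at index 0, the op type name at index 1, the op's name argument at
// index 2, and inputs (followed by any non-inferred attr name/value pairs)
// from index 3 onward.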
2992 
2993 PyObject* GetPythonObjectFromString(tensorflow::StringPiece s) {
2994 #if PY_MAJOR_VERSION >= 3
2995   return PyUnicode_FromStringAndSize(s.data(), s.size());
2996 #else
2997   return PyBytes_FromStringAndSize(s.data(), s.size());
2998 #endif
2999 }
3000 
3001 bool CheckResourceVariable(PyObject* item) {
3002   if (tensorflow::swig::IsResourceVariable(item)) {
3003     tensorflow::Safe_PyObjectPtr handle(
3004         PyObject_GetAttrString(item, "_handle"));
3005     return EagerTensor_CheckExact(handle.get());
3006   }
3007 
3008   return false;
3009 }
3010 
3011 bool IsNumberType(PyObject* item) {
3012 #if PY_MAJOR_VERSION >= 3
3013   return PyFloat_Check(item) || PyLong_Check(item);
3014 #else
3015   return PyFloat_Check(item) || PyInt_Check(item) || PyLong_Check(item);
3016 #endif
3017 }
3018 
3019 bool CheckOneInput(PyObject* item) {
3020   if (EagerTensor_CheckExact(item) || CheckResourceVariable(item) ||
3021       PyArray_Check(item) || IsNumberType(item)) {
3022     return true;
3023   }
3024 
3025   // Sequences are not properly handled. Sequences with purely python numeric
3026   // types work, but sequences with mixes of EagerTensors and python numeric
3027   // types don't work.
3028   // TODO(nareshmodi): fix
3029   return false;
3030 }
3031 
3032 bool CheckInputsOk(PyObject* seq, int start_index,
3033                    const tensorflow::OpDef& op_def) {
3034   for (int i = 0; i < op_def.input_arg_size(); i++) {
3035     PyObject* item = PyTuple_GET_ITEM(seq, i + start_index);
3036     if (!op_def.input_arg(i).number_attr().empty() ||
3037         !op_def.input_arg(i).type_list_attr().empty()) {
3038       // This item should be a seq input.
3039       if (!PySequence_Check(item)) {
3040         VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
3041                 << "\", Input \"" << op_def.input_arg(i).name()
3042                 << "\" since we expected a sequence, but got "
3043                 << item->ob_type->tp_name;
3044         return false;
3045       }
3046       tensorflow::Safe_PyObjectPtr fast_item(
3047           PySequence_Fast(item, "Could not parse sequence."));
3048       if (fast_item.get() == nullptr) {
3049         return false;
3050       }
3051       int len = PySequence_Fast_GET_SIZE(fast_item.get());
3052       PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get());
3053       for (Py_ssize_t j = 0; j < len; j++) {
3054         PyObject* inner_item = fast_item_array[j];
3055         if (!CheckOneInput(inner_item)) {
3056           VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
3057                   << "\", Input \"" << op_def.input_arg(i).name()
3058                   << "\", Index " << j
3059                   << " since we expected an EagerTensor/ResourceVariable, "
3060                      "but got "
3061                   << inner_item->ob_type->tp_name;
3062           return false;
3063         }
3064       }
3065     } else if (!CheckOneInput(item)) {
3066       VLOG(1)
3067           << "Falling back to slow path for Op \"" << op_def.name()
3068           << "\", Input \"" << op_def.input_arg(i).name()
3069           << "\" since we expected an EagerTensor/ResourceVariable, but got "
3070           << item->ob_type->tp_name;
3071       return false;
3072     }
3073   }
3074 
3075   return true;
3076 }
3077 
3078 tensorflow::DataType MaybeGetDType(PyObject* item) {
3079   if (EagerTensor_CheckExact(item) || CheckResourceVariable(item)) {
3080     return tensorflow::PyTensor_DataType(item);
3081   }
3082 
3083   return tensorflow::DT_INVALID;
3084 }
3085 
3086 tensorflow::DataType MaybeGetDTypeForAttr(const string& attr,
3087                                           FastPathOpExecInfo* op_exec_info) {
3088   auto cached_it = op_exec_info->cached_dtypes.find(attr);
3089   if (cached_it != op_exec_info->cached_dtypes.end()) {
3090     return cached_it->second;
3091   }
3092 
3093   auto it = op_exec_info->attr_to_inputs_map->find(attr);
3094   if (it == op_exec_info->attr_to_inputs_map->end()) {
3095     // No other inputs - this should never happen.
3096     return tensorflow::DT_INVALID;
3097   }
3098 
3099   for (const auto& input_info : it->second) {
3100     PyObject* item = PyTuple_GET_ITEM(
3101         op_exec_info->args, FAST_PATH_EXECUTE_ARG_INPUT_START + input_info.i);
3102     if (input_info.is_list) {
3103       tensorflow::Safe_PyObjectPtr fast_item(
3104           PySequence_Fast(item, "expected a sequence"));
3105       int len = PySequence_Fast_GET_SIZE(fast_item.get());
3106       PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get());
3107       for (int i = 0; i < len; i++) {
3108         auto dtype = MaybeGetDType(fast_item_array[i]);
3109         if (dtype != tensorflow::DT_INVALID) return dtype;
3110       }
3111     } else {
3112       auto dtype = MaybeGetDType(item);
3113       if (dtype != tensorflow::DT_INVALID) return dtype;
3114     }
3115   }
3116 
3117   auto default_it = op_exec_info->default_dtypes->find(attr);
3118   if (default_it != op_exec_info->default_dtypes->end()) {
3119     return default_it->second;
3120   }
3121 
3122   return tensorflow::DT_INVALID;
3123 }
3124 
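// Returns a new tuple that copies `seq`, replacing the items at `indices`
// with Py_None. RecordGradient uses this to drop inputs/outputs that the
// gradient function never reads, so their memory can be released sooner.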
3125 PyObject* CopySequenceSettingIndicesToNull(
3126     PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) {
3127   tensorflow::Safe_PyObjectPtr fast_seq(
3128       PySequence_Fast(seq, "expected a sequence"));
3129   int len = PySequence_Fast_GET_SIZE(fast_seq.get());
3130   PyObject** fast_seq_array = PySequence_Fast_ITEMS(fast_seq.get());
3131   PyObject* result = PyTuple_New(len);
3132   for (int i = 0; i < len; i++) {
3133     PyObject* item;
3134     if (indices.find(i) != indices.end()) {
3135       item = Py_None;
3136     } else {
3137       item = fast_seq_array[i];
3138     }
3139     Py_INCREF(item);
3140     PyTuple_SET_ITEM(result, i, item);
3141   }
3142   return result;
3143 }
3144 
3145 PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
3146                          PyObject* results,
3147                          PyObject* forward_pass_name_scope = nullptr) {
3148   std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
3149   if (PyErr_Occurred()) return nullptr;
3150   std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
3151   if (PyErr_Occurred()) return nullptr;
3152 
3153   bool should_record = false;
3154   for (TFE_Py_Tape* tape : SafeTapeSet()) {
3155     if (tape->tape->ShouldRecord(input_ids, input_dtypes)) {
3156       should_record = true;
3157       break;
3158     }
3159   }
3160   if (!should_record) {
3161     for (TFE_Py_ForwardAccumulator* accumulator : SafeAccumulatorSet()) {
3162       if (accumulator->accumulator->ShouldRecord(input_ids, input_dtypes)) {
3163         should_record = true;
3164         break;
3165       }
3166     }
3167   }
3168   if (!should_record) Py_RETURN_NONE;
3169 
3170   string c_op_name = TFE_GetPythonString(op_name);
3171 
3172   PyObject* op_outputs;
3173   bool op_outputs_tuple_created = false;
3174 
3175   if (const auto unused_output_indices =
3176           OpGradientUnusedOutputIndices(c_op_name)) {
3177     if (unused_output_indices->empty()) {
3178       op_outputs = Py_None;
3179     } else {
3180       op_outputs_tuple_created = true;
3181       op_outputs =
3182           CopySequenceSettingIndicesToNull(results, *unused_output_indices);
3183     }
3184   } else {
3185     op_outputs = results;
3186   }
3187 
3188   PyObject* op_inputs;
3189   bool op_inputs_tuple_created = false;
3190 
3191   if (const auto unused_input_indices =
3192           OpGradientUnusedInputIndices(c_op_name)) {
3193     if (unused_input_indices->empty()) {
3194       op_inputs = Py_None;
3195     } else {
3196       op_inputs_tuple_created = true;
3197       op_inputs =
3198           CopySequenceSettingIndicesToNull(inputs, *unused_input_indices);
3199     }
3200   } else {
3201     op_inputs = inputs;
3202   }
3203 
3204   tensorflow::eager::ForwardFunction<PyObject> py_forward_function(
3205       [op_name, attrs, inputs, results](
3206           const std::vector<PyObject*>& input_tangents,
3207           std::vector<PyObject*>* output_tangents, bool use_batch) {
3208         return CallJVPFunction(op_name, attrs, inputs, results, input_tangents,
3209                                output_tangents, use_batch);
3210       });
3211   tensorflow::eager::ForwardFunction<PyObject>* forward_function;
3212   if (c_op_name == "While" || c_op_name == "StatelessWhile" ||
3213       c_op_name == "If" || c_op_name == "StatelessIf") {
3214     // Control flow contains non-hashable attributes. Handling them in Python is
3215     // a headache, so instead we'll stay as close to GradientTape's handling as
3216     // possible (a null forward function means the accumulator forwards to a
3217     // tape).
3218     //
3219     // This is safe to do since we'll only see control flow when graph building,
3220     // in which case we can rely on pruning.
3221     forward_function = nullptr;
3222   } else {
3223     forward_function = &py_forward_function;
3224   }
3225 
3226   PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));
3227 
3228   if (!forward_pass_name_scope) forward_pass_name_scope = Py_None;
3229 
3230   TapeSetRecordOperation(
3231       op_name, inputs, results, input_ids, input_dtypes,
3232       [op_name, attrs, num_inputs, op_inputs, op_outputs,
3233        forward_pass_name_scope]() {
3234         Py_INCREF(op_name);
3235         Py_INCREF(attrs);
3236         Py_INCREF(num_inputs);
3237         Py_INCREF(op_inputs);
3238         Py_INCREF(op_outputs);
3239         Py_INCREF(forward_pass_name_scope);
3240         PyBackwardFunction* function = new PyBackwardFunction(
3241             [op_name, attrs, num_inputs, op_inputs, op_outputs,
3242              forward_pass_name_scope](
3243                 PyObject* output_grads,
3244                 const std::vector<tensorflow::int64>& unneeded_gradients) {
3245               if (PyErr_Occurred()) {
3246                 return static_cast<PyObject*>(nullptr);
3247               }
3248               tensorflow::Safe_PyObjectPtr skip_input_indices;
3249               if (!unneeded_gradients.empty()) {
3250                 skip_input_indices.reset(
3251                     PyTuple_New(unneeded_gradients.size()));
3252                 for (int i = 0; i < unneeded_gradients.size(); i++) {
3253                   PyTuple_SET_ITEM(
3254                       skip_input_indices.get(), i,
3255                       GetPythonObjectFromInt(unneeded_gradients[i]));
3256                 }
3257               } else {
3258                 Py_INCREF(Py_None);
3259                 skip_input_indices.reset(Py_None);
3260               }
3261               tensorflow::Safe_PyObjectPtr callback_args(Py_BuildValue(
3262                   "OOOOOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs,
3263                   output_grads, skip_input_indices.get(),
3264                   forward_pass_name_scope));
3265 
3266               tensorflow::Safe_PyObjectPtr result(
3267                   PyObject_CallObject(gradient_function, callback_args.get()));
3268 
3269               if (PyErr_Occurred()) return static_cast<PyObject*>(nullptr);
3270 
3271               return tensorflow::swig::Flatten(result.get());
3272             });
3273         return function;
3274       },
3275       [op_name, attrs, num_inputs, op_inputs, op_outputs,
3276        forward_pass_name_scope](PyBackwardFunction* backward_function) {
3277         Py_DECREF(op_name);
3278         Py_DECREF(attrs);
3279         Py_DECREF(num_inputs);
3280         Py_DECREF(op_inputs);
3281         Py_DECREF(op_outputs);
3282         Py_DECREF(forward_pass_name_scope);
3283 
3284         delete backward_function;
3285       },
3286       forward_function);
3287 
3288   Py_DECREF(num_inputs);
3289   if (op_outputs_tuple_created) Py_DECREF(op_outputs);
3290   if (op_inputs_tuple_created) Py_DECREF(op_inputs);
3291 
3292   if (PyErr_Occurred()) {
3293     return nullptr;
3294   }
3295 
3296   Py_RETURN_NONE;
3297 }
3298 
3299 void MaybeNotifyVariableAccessed(PyObject* input) {
3300   DCHECK(CheckResourceVariable(input));
3301   DCHECK(PyObject_HasAttrString(input, "_trainable"));
3302 
3303   tensorflow::Safe_PyObjectPtr trainable(
3304       PyObject_GetAttrString(input, "_trainable"));
3305   if (trainable.get() == Py_False) return;
3306   TFE_Py_TapeVariableAccessed(input);
3307   TFE_Py_VariableWatcherVariableAccessed(input);
3308 }
3309 
3310 bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
3311                     PyObject* input, tensorflow::Safe_PyObjectPtr* output,
3312                     TF_Status* status) {
3313   MaybeNotifyVariableAccessed(input);
3314 
3315   TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status);
3316   auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); });
3317   if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
3318 
3319   TFE_OpSetDevice(op, parent_op_exec_info.device_name, status);
3320   if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
3321 
3322   // Set dtype
3323   DCHECK(PyObject_HasAttrString(input, "_dtype"));
3324   tensorflow::Safe_PyObjectPtr dtype(PyObject_GetAttrString(input, "_dtype"));
3325   int value;
3326   if (!ParseTypeValue("_dtype", dtype.get(), status, &value)) {
3327     return false;
3328   }
3329   TFE_OpSetAttrType(op, "dtype", static_cast<TF_DataType>(value));
3330 
3331   // Get handle
3332   tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(input, "_handle"));
3333   if (!EagerTensor_CheckExact(handle.get())) return false;
3334   TFE_OpAddInput(op, EagerTensor_Handle(handle.get()), status);
3335   if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
3336 
3337   int num_retvals = 1;
3338   TFE_TensorHandle* output_handle;
3339   TFE_Execute(op, &output_handle, &num_retvals, status);
3340   if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
3341 
3342   // Always create the py object (and correctly DECREF it) from the returned
3343   // value, else the data will leak.
3344   output->reset(EagerTensorFromHandle(output_handle));
3345 
3346   // TODO(nareshmodi): Should we run post exec callbacks here?
3347   if (parent_op_exec_info.run_gradient_callback) {
3348     tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(1));
3349     PyTuple_SET_ITEM(inputs.get(), 0, handle.release());
3350 
3351     tensorflow::Safe_PyObjectPtr outputs(PyTuple_New(1));
3352     Py_INCREF(output->get());  // stay alive after since tuple steals.
3353     PyTuple_SET_ITEM(outputs.get(), 0, output->get());
3354 
3355     tensorflow::Safe_PyObjectPtr op_string(
3356         GetPythonObjectFromString("ReadVariableOp"));
3357     if (!RecordGradient(op_string.get(), inputs.get(), Py_None,
3358                         outputs.get())) {
3359       return false;
3360     }
3361   }
3362 
3363   return true;
3364 }
3365 
3366 // Supports 3 cases at the moment:
3367 //  i) input is an EagerTensor.
3368 //  ii) input is a ResourceVariable - in this case the variable's value is
3369 //  read into a new EagerTensor via ReadVariableOp.
3370 //  iii) input is an arbitrary python list/tuple (note, this handling doesn't
3371 //  support packing).
3372 //
3373 //  NOTE: dtype_hint_getter returns the dtype hint to use for the
3374 //  conversion, or DT_INVALID when no hint is available (see
3375 //  MaybeGetDTypeForAttr).
3376 //
3377 //  NOTE: This function sets a python error directly, and returns false.
3378 //  TF_Status is only passed since we don't want to have to reallocate it.
3379 bool ConvertToTensor(
3380     const FastPathOpExecInfo& op_exec_info, PyObject* input,
3381     tensorflow::Safe_PyObjectPtr* output_handle,
3382     // This gets a hint for this particular input.
3383     const std::function<tensorflow::DataType()>& dtype_hint_getter,
3384     // This sets the dtype after conversion is complete.
3385     const std::function<void(const tensorflow::DataType dtype)>& dtype_setter,
3386     TF_Status* status) {
3387   if (EagerTensor_CheckExact(input)) {
3388     Py_INCREF(input);
3389     output_handle->reset(input);
3390     return true;
3391   } else if (CheckResourceVariable(input)) {
3392     return ReadVariableOp(op_exec_info, input, output_handle, status);
3393   }
3394 
3395   // The hint comes from a supposedly similarly typed tensor.
3396   tensorflow::DataType dtype_hint = dtype_hint_getter();
3397 
3398   TFE_TensorHandle* handle = tensorflow::ConvertToEagerTensor(
3399       op_exec_info.ctx, input, dtype_hint, op_exec_info.device_name);
3400   if (handle == nullptr) {
3401     return MaybeRaiseExceptionFromTFStatus(status, nullptr);
3402   }
3403 
3404   output_handle->reset(EagerTensorFromHandle(handle));
3405   dtype_setter(
3406       static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(handle)));
3407 
3408   return true;
3409 }
3410 
3411 // Adds input and type attr to the op, and to the list of flattened
3412 // inputs/attrs.
3413 bool AddInputToOp(FastPathOpExecInfo* op_exec_info, PyObject* input,
3414                   const bool add_type_attr,
3415                   const tensorflow::OpDef::ArgDef& input_arg,
3416                   std::vector<tensorflow::Safe_PyObjectPtr>* flattened_attrs,
3417                   std::vector<tensorflow::Safe_PyObjectPtr>* flattened_inputs,
3418                   TFE_Op* op, TF_Status* status) {
3419   // py_eager_tensor's ownership is transferred to flattened_inputs if it is
3420   // required, else the object is destroyed and DECREF'd when the object goes
3421   // out of scope in this function.
3422   tensorflow::Safe_PyObjectPtr py_eager_tensor = nullptr;
3423 
3424   if (!ConvertToTensor(
3425           *op_exec_info, input, &py_eager_tensor,
3426           [&]() {
3427             if (input_arg.type() != tensorflow::DataType::DT_INVALID) {
3428               return input_arg.type();
3429             }
3430             return MaybeGetDTypeForAttr(input_arg.type_attr(), op_exec_info);
3431           },
3432           [&](const tensorflow::DataType dtype) {
3433             op_exec_info->cached_dtypes[input_arg.type_attr()] = dtype;
3434           },
3435           status)) {
3436     return false;
3437   }
3438 
3439   TFE_TensorHandle* input_handle = EagerTensor_Handle(py_eager_tensor.get());
3440 
3441   if (add_type_attr && !input_arg.type_attr().empty()) {
3442     auto dtype = TFE_TensorHandleDataType(input_handle);
3443     TFE_OpSetAttrType(op, input_arg.type_attr().data(), dtype);
3444     if (flattened_attrs != nullptr) {
3445       flattened_attrs->emplace_back(
3446           GetPythonObjectFromString(input_arg.type_attr()));
3447       flattened_attrs->emplace_back(PyLong_FromLong(dtype));
3448     }
3449   }
3450 
3451   if (flattened_inputs != nullptr) {
3452     flattened_inputs->emplace_back(std::move(py_eager_tensor));
3453   }
3454 
3455   TFE_OpAddInput(op, input_handle, status);
3456   if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3457     return false;
3458   }
3459 
3460   return true;
3461 }
3462 
3463 const char* GetDeviceName(PyObject* py_device_name) {
3464   if (py_device_name != Py_None) {
3465     return TFE_GetPythonString(py_device_name);
3466   }
3467   return nullptr;
3468 }
3469 
3470 bool RaiseIfNotPySequence(PyObject* seq, const string& attr_name) {
3471   if (!PySequence_Check(seq)) {
3472     PyErr_SetString(PyExc_TypeError,
3473                     Printf("expected a sequence for attr %s, got %s instead",
3474                            attr_name.data(), seq->ob_type->tp_name)
3475                         .data());
3476 
3477     return false;
3478   }
3479   if (PyArray_Check(seq) &&
3480       PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)) != 1) {
3481     PyErr_SetString(PyExc_ValueError,
3482                     Printf("expected a sequence for attr %s, got an ndarray "
3483                            "with rank %d instead",
3484                            attr_name.data(),
3485                            PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)))
3486                         .data());
3487     return false;
3488   }
3489   return true;
3490 }
3491 
3492 bool RunCallbacks(
3493     const FastPathOpExecInfo& op_exec_info, PyObject* args,
3494     int num_inferred_attrs,
3495     const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_inputs,
3496     const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_attrs,
3497     PyObject* flattened_result) {
3498   DCHECK(op_exec_info.run_callbacks);
3499 
3500   tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(flattened_inputs.size()));
3501   for (int i = 0; i < flattened_inputs.size(); i++) {
3502     PyObject* input = flattened_inputs[i].get();
3503     Py_INCREF(input);
3504     PyTuple_SET_ITEM(inputs.get(), i, input);
3505   }
3506 
3507   int num_non_inferred_attrs = PyTuple_GET_SIZE(args) - num_inferred_attrs;
3508   int num_attrs = flattened_attrs.size() + num_non_inferred_attrs;
3509   tensorflow::Safe_PyObjectPtr attrs(PyTuple_New(num_attrs));
3510 
3511   for (int i = 0; i < num_non_inferred_attrs; i++) {
3512     auto* attr = PyTuple_GET_ITEM(args, num_inferred_attrs + i);
3513     Py_INCREF(attr);
3514     PyTuple_SET_ITEM(attrs.get(), i, attr);
3515   }
3516 
3517   for (int i = num_non_inferred_attrs; i < num_attrs; i++) {
3518     PyObject* attr_or_name =
3519         flattened_attrs.at(i - num_non_inferred_attrs).get();
3520     Py_INCREF(attr_or_name);
3521     PyTuple_SET_ITEM(attrs.get(), i, attr_or_name);
3522   }
3523 
3524   if (op_exec_info.run_gradient_callback) {
3525     if (!RecordGradient(op_exec_info.op_name, inputs.get(), attrs.get(),
3526                         flattened_result)) {
3527       return false;
3528     }
3529   }
3530 
3531   if (op_exec_info.run_post_exec_callbacks) {
3532     tensorflow::Safe_PyObjectPtr callback_args(
3533         Py_BuildValue("OOOOO", op_exec_info.op_name, inputs.get(), attrs.get(),
3534                       flattened_result, op_exec_info.name));
3535     for (Py_ssize_t i = 0; i < PyList_Size(op_exec_info.callbacks); i++) {
3536       PyObject* callback_fn = PyList_GET_ITEM(op_exec_info.callbacks, i);
3537       if (!PyCallable_Check(callback_fn)) {
3538         PyErr_SetString(
3539             PyExc_TypeError,
3540             Printf("expected a function for "
3541                    "post execution callback at index %ld, got %s instead",
3542                    i, callback_fn->ob_type->tp_name)
3543                 .c_str());
3544         return false;
3545       }
3546       PyObject* callback_result =
3547           PyObject_CallObject(callback_fn, callback_args.get());
3548       if (!callback_result) {
3549         return false;
3550       }
3551       Py_DECREF(callback_result);
3552     }
3553   }
3554 
3555   return true;
3556 }
3557 
3558 }  // namespace
3559 
3560 PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
3561   tensorflow::profiler::TraceMe activity(
3562       "TFE_Py_FastPathExecute_C", tensorflow::profiler::TraceMeLevel::kInfo);
3563   Py_ssize_t args_size = PyTuple_GET_SIZE(args);
3564   if (args_size < FAST_PATH_EXECUTE_ARG_INPUT_START) {
3565     PyErr_SetString(
3566         PyExc_ValueError,
3567         Printf("There must be at least %d items in the input tuple.",
3568                FAST_PATH_EXECUTE_ARG_INPUT_START)
3569             .c_str());
3570     return nullptr;
3571   }
3572 
3573   FastPathOpExecInfo op_exec_info;
3574 
3575   PyObject* py_eager_context =
3576       PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_CONTEXT);
3577 
3578   // TODO(edoper): Use interned string here
3579   PyObject* eager_context_handle =
3580       PyObject_GetAttrString(py_eager_context, "_context_handle");
3581 
3582   TFE_Context* ctx = reinterpret_cast<TFE_Context*>(
3583       PyCapsule_GetPointer(eager_context_handle, nullptr));
3584   op_exec_info.ctx = ctx;
3585   op_exec_info.args = args;
3586 
3587   if (ctx == nullptr) {
3588     // The context hasn't been initialized yet; the slow path will do it.
3589     RaiseFallbackException(
3590         "This function does not handle the case where not all inputs "
3591         "are already EagerTensors.");
3592     return nullptr;
3593   }
3594 
3595   auto* tld = tensorflow::GetEagerContextThreadLocalData(py_eager_context);
3596   if (tld == nullptr) {
3597     return nullptr;
3598   }
3599   op_exec_info.device_name = GetDeviceName(tld->device_name.get());
3600   op_exec_info.callbacks = tld->op_callbacks.get();
3601 
3602   op_exec_info.op_name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_OP_NAME);
3603   op_exec_info.name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_NAME);
3604 
3605   // TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks
3606   // (similar to benchmark_tf_gradient_function_*). Also consider using an
3607   // InlinedVector for flattened_attrs and flattened_inputs if the benchmarks
3608   // point out problems with heap allocs.
3609   op_exec_info.run_gradient_callback =
3610       !*ThreadTapeIsStopped() && HasAccumulatorOrTape();
3611   op_exec_info.run_post_exec_callbacks =
3612       op_exec_info.callbacks != Py_None &&
3613       PyList_Size(op_exec_info.callbacks) > 0;
3614   op_exec_info.run_callbacks = op_exec_info.run_gradient_callback ||
3615                                op_exec_info.run_post_exec_callbacks;
3616 
3617   TF_Status* status = GetStatus();
3618   const char* op_name = TFE_GetPythonString(op_exec_info.op_name);
3619   if (op_name == nullptr) {
3620     PyErr_SetString(PyExc_TypeError,
3621                     Printf("expected a string for op_name, got %s instead",
3622                            op_exec_info.op_name->ob_type->tp_name)
3623                         .c_str());
3624     return nullptr;
3625   }
3626 
3627   TFE_Op* op = GetOp(ctx, op_name, op_exec_info.device_name, status);
3628 
3629   auto cleaner = tensorflow::gtl::MakeCleanup([status, ctx, op] {
3630     ReturnStatus(status);
3631     ReturnOp(ctx, op);
3632   });
3633 
3634   if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3635     return nullptr;
3636   }
3637 
3638   tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace(
3639       tensorflow::StackTrace::kStackTraceInitialSize));
3640 
3641   const tensorflow::OpDef* op_def = tensorflow::unwrap(op)->OpDef();
3642   if (op_def == nullptr) return nullptr;
3643 
3644   if (args_size <
3645       FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size()) {
3646     PyErr_SetString(
3647         PyExc_ValueError,
3648         Printf("Tuple size smaller than intended. Expected to be at least %d, "
3649                "was %ld",
3650                FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
3651                args_size)
3652             .c_str());
3653     return nullptr;
3654   }
3655 
3656   if (!CheckInputsOk(args, FAST_PATH_EXECUTE_ARG_INPUT_START, *op_def)) {
3657     RaiseFallbackException(
3658         "This function does not handle the case where not all inputs "
3659         "are already EagerTensors.");
3660     return nullptr;
3661   }
3662 
3663   op_exec_info.attr_to_inputs_map = GetAttrToInputsMapHoldingGIL(*op_def);
3664   op_exec_info.default_dtypes = GetAttrToDefaultsMapHoldingGIL(*op_def);
3665 
3666   // Mapping of attr name to size - used to calculate the number of values
3667   // to be expected by the TFE_Execute run.
3668   tensorflow::gtl::FlatMap<string, tensorflow::int64> attr_list_sizes;
3669 
3670   // Set non-inferred attrs, including setting defaults if the attr is passed in
3671   // as None.
3672   for (int i = FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size();
3673        i < args_size; i += 2) {
3674     PyObject* py_attr_name = PyTuple_GET_ITEM(args, i);
3675     const char* attr_name = TFE_GetPythonString(py_attr_name);
3676     PyObject* py_attr_value = PyTuple_GET_ITEM(args, i + 1);
3677 
3678     // Not creating an index since most of the time there are not more than a
3679     // few attrs.
3680     // TODO(nareshmodi): Maybe include the index as part of the
3681     // OpRegistrationData.
3682     for (const auto& attr : op_def->attr()) {
3683       if (tensorflow::StringPiece(attr_name) == attr.name()) {
3684         SetOpAttrWithDefaults(ctx, op, attr, attr_name, py_attr_value,
3685                               &attr_list_sizes, status);
3686 
3687         if (!status->status.ok()) {
3688           VLOG(1) << "Falling back to slow path for Op \"" << op_def->name()
3689                   << "\" since we are unable to set the value for attr \""
3690                   << attr.name() << "\" due to: " << TF_Message(status);
3691           RaiseFallbackException(TF_Message(status));
3692           return nullptr;
3693         }
3694 
3695         break;
3696       }
3697     }
3698   }
3699 
3700   // Flat attrs and inputs as required by the record_gradient call. The attrs
3701   // here only contain inferred attrs (non-inferred attrs are added directly
3702   // from the input args).
3703   // All items in flattened_attrs and flattened_inputs contain
3704   // Safe_PyObjectPtr - any time something steals a reference to this, it must
3705   // INCREF.
3706   // TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work
3707   // directly.
3708   std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_attrs =
3709       nullptr;
3710   std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_inputs =
3711       nullptr;
3712 
3713   // TODO(nareshmodi): Encapsulate callbacks information into a struct.
3714   if (op_exec_info.run_callbacks) {
3715     flattened_attrs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
3716     flattened_inputs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
3717   }
3718 
3719   // Add inferred attrs and inputs.
3720   // The following code might set duplicate type attrs. This will result in
3721   // the CacheKey for the generated AttrBuilder possibly differing from
3722   // those where the type attrs are correctly set. Inconsistent CacheKeys
3723   // for ops means that there might be unnecessarily duplicated kernels.
3724   // TODO(nareshmodi): Fix this.
3725   for (int i = 0; i < op_def->input_arg_size(); i++) {
3726     const auto& input_arg = op_def->input_arg(i);
3727 
3728     PyObject* input =
3729         PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_INPUT_START + i);
3730     if (!input_arg.number_attr().empty()) {
3731       // The item is a homogeneous list.
3732       if (!RaiseIfNotPySequence(input, input_arg.number_attr())) return nullptr;
3733       tensorflow::Safe_PyObjectPtr fast_input(
3734           PySequence_Fast(input, "Could not parse sequence."));
3735       if (fast_input.get() == nullptr) {
3736         return nullptr;
3737       }
3738       Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
3739       PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get());
3740 
3741       TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len);
3742       if (op_exec_info.run_callbacks) {
3743         flattened_attrs->emplace_back(
3744             GetPythonObjectFromString(input_arg.number_attr()));
3745         flattened_attrs->emplace_back(PyLong_FromLong(len));
3746       }
3747       attr_list_sizes[input_arg.number_attr()] = len;
3748 
3749       if (len > 0) {
3750         // First item adds the type attr.
3751         if (!AddInputToOp(&op_exec_info, fast_input_array[0], true, input_arg,
3752                           flattened_attrs.get(), flattened_inputs.get(), op,
3753                           status)) {
3754           return nullptr;
3755         }
3756 
3757         for (Py_ssize_t j = 1; j < len; j++) {
3758           // Since the list is homogeneous, we don't need to re-add the attr.
3759           if (!AddInputToOp(&op_exec_info, fast_input_array[j], false,
3760                             input_arg, nullptr /* flattened_attrs */,
3761                             flattened_inputs.get(), op, status)) {
3762             return nullptr;
3763           }
3764         }
3765       }
3766     } else if (!input_arg.type_list_attr().empty()) {
3767       // The item is a heterogeneous list.
3768       if (!RaiseIfNotPySequence(input, input_arg.type_list_attr())) {
3769         return nullptr;
3770       }
3771       tensorflow::Safe_PyObjectPtr fast_input(
3772           PySequence_Fast(input, "Could not parse sequence."));
3773       if (fast_input.get() == nullptr) {
3774         return nullptr;
3775       }
3776       const string& attr_name = input_arg.type_list_attr();
3777       Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
3778       PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get());
3779       tensorflow::gtl::InlinedVector<TF_DataType, 4> attr_value(len);
3780       PyObject* py_attr_value = nullptr;
3781       if (op_exec_info.run_callbacks) {
3782         py_attr_value = PyTuple_New(len);
3783       }
3784       for (Py_ssize_t j = 0; j < len; j++) {
3785         PyObject* py_input = fast_input_array[j];
3786         tensorflow::Safe_PyObjectPtr py_eager_tensor;
3787         if (!ConvertToTensor(
3788                 op_exec_info, py_input, &py_eager_tensor,
3789                 []() { return tensorflow::DT_INVALID; },
3790                 [](const tensorflow::DataType dtype) {}, status)) {
3791           return nullptr;
3792         }
3793 
3794         TFE_TensorHandle* input_handle =
3795             EagerTensor_Handle(py_eager_tensor.get());
3796 
3797         attr_value[j] = TFE_TensorHandleDataType(input_handle);
3798 
3799         TFE_OpAddInput(op, input_handle, status);
3800         if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3801           return nullptr;
3802         }
3803 
3804         if (op_exec_info.run_callbacks) {
3805           flattened_inputs->emplace_back(std::move(py_eager_tensor));
3806 
3807           PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j]));
3808         }
3809       }
3810       if (op_exec_info.run_callbacks) {
3811         flattened_attrs->emplace_back(GetPythonObjectFromString(attr_name));
3812         flattened_attrs->emplace_back(py_attr_value);
3813       }
3814       TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(),
3815                             attr_value.size());
3816       attr_list_sizes[attr_name] = len;
3817     } else {
3818       // The item is a single item.
3819       if (!AddInputToOp(&op_exec_info, input, true, input_arg,
3820                         flattened_attrs.get(), flattened_inputs.get(), op,
3821                         status)) {
3822         return nullptr;
3823       }
3824     }
3825   }
3826 
3827   int64_t num_outputs = 0;
3828   for (int i = 0; i < op_def->output_arg_size(); i++) {
3829     const auto& output_arg = op_def->output_arg(i);
3830     int64_t delta = 1;
3831     if (!output_arg.number_attr().empty()) {
3832       delta = attr_list_sizes[output_arg.number_attr()];
3833     } else if (!output_arg.type_list_attr().empty()) {
3834       delta = attr_list_sizes[output_arg.type_list_attr()];
3835     }
3836     if (delta < 0) {
3837       RaiseFallbackException(
3838           "Attributes suggest that the size of an output list is less than 0");
3839       return nullptr;
3840     }
3841     num_outputs += delta;
3842   }
3843 
3844   // If number of retvals is larger than int32, we error out.
3845   if (static_cast<int64_t>(static_cast<int32_t>(num_outputs)) != num_outputs) {
3846     PyErr_SetString(
3847         PyExc_ValueError,
3848         Printf("Number of outputs is too big: %ld", num_outputs).c_str());
3849     return nullptr;
3850   }
3851   int num_retvals = num_outputs;
3852 
3853   tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);
3854 
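  // Release the GIL while the kernel executes. TFE_Execute treats num_retvals
  // as in/out: on input, the capacity of `retvals`; on output, the number of
  // handles actually produced.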
3855   Py_BEGIN_ALLOW_THREADS;
3856   TFE_Execute(op, retvals.data(), &num_retvals, status);
3857   Py_END_ALLOW_THREADS;
3858 
3859   if (!status->status.ok()) {
3860     // Augment the status with the op_name for easier debugging similar to
3861     // TFE_Py_Execute.
3862     std::vector<tensorflow::StackFrame> stack_trace =
3863         status->status.stack_trace();
3864     status->status = tensorflow::Status(
3865         status->status.code(),
3866         tensorflow::strings::StrCat(
3867             TF_Message(status),
3868             " [Op:", TFE_GetPythonString(op_exec_info.op_name), "]"),
3869         std::move(stack_trace));
3870 
3871     MaybeRaiseExceptionFromTFStatus(status, nullptr);
3872     return nullptr;
3873   }
3874 
3875   tensorflow::Safe_PyObjectPtr flat_result(PyList_New(num_retvals));
3876   for (int i = 0; i < num_retvals; ++i) {
3877     PyList_SET_ITEM(flat_result.get(), i, EagerTensorFromHandle(retvals[i]));
3878   }
3879 
3880   if (op_exec_info.run_callbacks) {
3881     if (!RunCallbacks(
3882             op_exec_info, args,
3883             FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
3884             *flattened_inputs, *flattened_attrs, flat_result.get())) {
3885       return nullptr;
3886     }
3887   }
3888 
3889   // Unflatten results.
3890   if (op_def->output_arg_size() == 0) {
3891     Py_RETURN_NONE;
3892   }
3893 
3894   if (op_def->output_arg_size() == 1) {
3895     if (!op_def->output_arg(0).number_attr().empty() ||
3896         !op_def->output_arg(0).type_list_attr().empty()) {
3897       return flat_result.release();
3898     } else {
3899       auto* result = PyList_GET_ITEM(flat_result.get(), 0);
3900       Py_INCREF(result);
3901       return result;
3902     }
3903   }
3904 
3905   // Unflatten the results into one entry per output arg; the generated
  // Python wrapper may repackage this list as a namedtuple.
3906   PyObject* result = PyList_New(op_def->output_arg_size());
3907   int flat_result_index = 0;
3908   for (int i = 0; i < op_def->output_arg_size(); i++) {
3909     if (!op_def->output_arg(i).number_attr().empty()) {
3910       int list_length = attr_list_sizes[op_def->output_arg(i).number_attr()];
3911       PyObject* inner_list = PyList_New(list_length);
3912       for (int j = 0; j < list_length; j++) {
3913         PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
3914         Py_INCREF(obj);
3915         PyList_SET_ITEM(inner_list, j, obj);
3916       }
3917       PyList_SET_ITEM(result, i, inner_list);
3918     } else if (!op_def->output_arg(i).type_list_attr().empty()) {
3919       int list_length = attr_list_sizes[op_def->output_arg(i).type_list_attr()];
3920       PyObject* inner_list = PyList_New(list_length);
3921       for (int j = 0; j < list_length; j++) {
3922         PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
3923         Py_INCREF(obj);
3924         PyList_SET_ITEM(inner_list, j, obj);
3925       }
3926       PyList_SET_ITEM(result, i, inner_list);
3927     } else {
3928       PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
3929       Py_INCREF(obj);
3930       PyList_SET_ITEM(result, i, obj);
3931     }
3932   }
3933   return result;
3934 }
3935 
3936 PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
3937                                 PyObject* attrs, PyObject* results,
3938                                 PyObject* forward_pass_name_scope) {
3939   if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
3940     Py_RETURN_NONE;
3941   }
3942 
3943   return RecordGradient(op_name, inputs, attrs, results,
3944                         forward_pass_name_scope);
3945 }
3946 
3947 namespace {
3948 const char kTensor[] = "T";
3949 const char kList[] = "L";
3950 const char kListEnd[] = "l";
3951 const char kTuple[] = "U";
3952 const char kTupleEnd[] = "u";
3953 const char kDict[] = "D";
3954 const char kRaw[] = "R";
3955 const char kShape[] = "s";
3956 const char kShapeDelim[] = "-";
3957 const char kDType[] = "d";
3958 const char kNone[] = "n";
3959 const char kCompositeTensor[] = "C";
3960 const char kAttrs[] = "A";
3961 const char kAttrsEnd[] = "a";
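// Illustrative encoding: a float32 EagerTensor of shape [2, 3] appends
// "Td1s2-3-" -- kTensor, kDType plus the DT_FLOAT enum value (1), kShape,
// then each dimension followed by kShapeDelim.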
3962 
3963 struct EncodeResult {
3964   string str;
3965   std::vector<PyObject*> objects;
3966 
3967   PyObject* ToPyTuple() {
3968     PyObject* result = PyTuple_New(2);
3969 
3970     PyTuple_SET_ITEM(result, 0, GetPythonObjectFromString(str));
3971 
3972     if (objects.empty()) {
3973       Py_INCREF(Py_None);
3974       PyTuple_SET_ITEM(result, 1, Py_None);
3975     } else {
3976       PyObject* objects_tuple = PyTuple_New(objects.size());
3977 
3978       for (int i = 0; i < objects.size(); i++) {
3979         PyTuple_SET_ITEM(objects_tuple, i, objects[i]);
3980       }
3981 
3982       PyTuple_SET_ITEM(result, 1, objects_tuple);
3983     }
3984 
3985     return result;
3986   }
3987 };
3988 
3989 tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg,
3990                                        bool include_tensor_ranks_only,
3991                                        EncodeResult* result) {
3992   if (EagerTensor_CheckExact(arg)) {
3993     tensorflow::ImmediateExecutionTensorHandle* handle =
3994         tensorflow::unwrap(EagerTensor_Handle(arg));
3995 
3996     absl::StrAppend(&result->str, kDType,
3997                     static_cast<tensorflow::DataType>(handle->DataType()));
3998     absl::StrAppend(&result->str, kShape);
3999 
4000     int num_dims;
4001     tensorflow::Status status = handle->NumDims(&num_dims);
4002     if (!status.ok()) return status;
4003 
4004     if (include_tensor_ranks_only) {
4005       absl::StrAppend(&result->str, num_dims);
4006     } else {
4007       for (int i = 0; i < num_dims; ++i) {
4008         tensorflow::int64 dim_size;
4009         status = handle->Dim(i, &dim_size);
4010         if (!status.ok()) return status;
4011         absl::StrAppend(&result->str, dim_size, kShapeDelim);
4012       }
4013     }
4014     return tensorflow::Status::OK();
4015   }
4016 
4017   tensorflow::Safe_PyObjectPtr dtype_object(
4018       PyObject_GetAttrString(arg, "dtype"));
4019 
4020   if (dtype_object == nullptr) {
4021     return tensorflow::errors::InvalidArgument(
4022         "ops.Tensor object doesn't have a dtype attr.");
4023   }
4024 
4025   tensorflow::Safe_PyObjectPtr dtype_enum(
4026       PyObject_GetAttrString(dtype_object.get(), "_type_enum"));
4027 
4028   if (dtype_enum == nullptr) {
4029     return tensorflow::errors::InvalidArgument(
4030         "ops.Tensor's dtype object doesn't have a _type_enum attr.");
4031   }
4032 
4033   tensorflow::DataType dtype =
4034       static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get()));
4035 
4036   absl::StrAppend(&result->str, kDType, dtype);
4037 
4038   static char _shape_tuple[] = "_shape_tuple";
4039   tensorflow::Safe_PyObjectPtr shape_tuple(
4040       PyObject_CallMethod(arg, _shape_tuple, nullptr));
4041 
4042   if (shape_tuple == nullptr) {
4043     return tensorflow::errors::InvalidArgument(
4044         "ops.Tensor object doesn't have _shape_tuple() method.");
4045   }
4046 
4047   if (shape_tuple.get() == Py_None) {
4048     // Unknown shape, encode that directly.
4049     absl::StrAppend(&result->str, kNone);
4050     return tensorflow::Status::OK();
4051   }
4052 
4053   absl::StrAppend(&result->str, kShape);
4054   tensorflow::Safe_PyObjectPtr shape_seq(PySequence_Fast(
4055       shape_tuple.get(), "shape_tuple didn't return a sequence"));
4056 
4057   int len = PySequence_Fast_GET_SIZE(shape_seq.get());
4058   PyObject** shape_seq_array = PySequence_Fast_ITEMS(shape_seq.get());
4059 
4060   if (include_tensor_ranks_only) {
4061     absl::StrAppend(&result->str, len);
4062   } else {
4063     for (int i = 0; i < len; ++i) {
4064       PyObject* item = shape_seq_array[i];
4065       if (item == Py_None) {
4066         absl::StrAppend(&result->str, kNone);
4067       } else {
4068         absl::StrAppend(&result->str, MakeInt(item));
4069       }
4070     }
4071   }
4072   return tensorflow::Status::OK();
4073 }
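
// Example (illustrative, assuming kShape is the "s" marker defined earlier in
// this file): a DT_FLOAT (enum value 1) EagerTensor of shape [2, 3] appends
// "d1s2-3-" to result->str; with include_tensor_ranks_only=true only the rank
// is recorded, giving "d1s2".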

tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
                                          bool include_tensor_ranks_only,
                                          EncodeResult* result);

// Encodes a list- or tuple-like sequence: appends the opening `type` marker,
// then the encoding of each item (kNone for None entries), and finally the
// closing `end_type` marker.
tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type,
                                         const char* end_type,
                                         bool include_tensor_ranks_only,
                                         EncodeResult* result) {
  tensorflow::Safe_PyObjectPtr arg_seq(
      PySequence_Fast(arg, "unable to create seq from list/tuple"));

  absl::StrAppend(&result->str, type);
  int len = PySequence_Fast_GET_SIZE(arg_seq.get());
  PyObject** arg_seq_array = PySequence_Fast_ITEMS(arg_seq.get());
  for (int i = 0; i < len; ++i) {
    PyObject* item = arg_seq_array[i];
    if (item == Py_None) {
      absl::StrAppend(&result->str, kNone);
    } else {
      TF_RETURN_IF_ERROR(
          TFE_Py_EncodeArgHelper(item, include_tensor_ranks_only, result));
    }
  }
  absl::StrAppend(&result->str, end_type);

  return tensorflow::Status::OK();
}
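
// Example (illustrative, assuming the kList/kListEnd markers defined earlier
// in this file are "l"/"L"): encoding the Python list [None, None] with
// type=kList and end_type=kListEnd appends "lnnL" to result->str.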

tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
                                          bool include_tensor_ranks_only,
                                          EncodeResult* result) {
  if (tensorflow::swig::IsTensor(arg)) {
    absl::StrAppend(&result->str, kTensor);
    TF_RETURN_IF_ERROR(
        TFE_Py_EncodeTensor(arg, include_tensor_ranks_only, result));
  } else if (PyList_Check(arg)) {
    TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(
        arg, kList, kListEnd, include_tensor_ranks_only, result));
  } else if (tensorflow::swig::IsTuple(arg)) {
    TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(
        arg, kTuple, kTupleEnd, include_tensor_ranks_only, result));
  } else if (tensorflow::swig::IsMapping(arg)) {
    tensorflow::Safe_PyObjectPtr keys(tensorflow::swig::MappingKeys(arg));
    if (keys == nullptr || PyList_Sort(keys.get()) == -1) {
      return tensorflow::errors::Internal("Unable to sort keys");
    }

    absl::StrAppend(&result->str, kDict);
    Py_ssize_t len = PyList_Size(keys.get());

    for (Py_ssize_t i = 0; i < len; i++) {
      PyObject* key = PyList_GetItem(keys.get(), i);
      TF_RETURN_IF_ERROR(
          TFE_Py_EncodeArgHelper(key, include_tensor_ranks_only, result));
      tensorflow::Safe_PyObjectPtr value(PyObject_GetItem(arg, key));
      TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(
          value.get(), include_tensor_ranks_only, result));
    }
  } else if (tensorflow::swig::IsCompositeTensor(arg)) {
    absl::StrAppend(&result->str, kCompositeTensor);

    // Add the typespec to the list of objects.  (Do *not* use a weakref,
    // since the type spec is often a temporary object.)
    PyObject* type_spec = PyObject_GetAttrString(arg, "_type_spec");
    if (type_spec == nullptr) {
      return tensorflow::errors::InvalidArgument(
          "Error while reading CompositeTensor._type_spec.");
    }
    result->objects.push_back(type_spec);
  } else if (tensorflow::swig::IsTypeSpec(arg)) {
    // Add the typespec itself (not a weakref), in case it's a temporary
    // object.
    absl::StrAppend(&result->str, kRaw);
    Py_INCREF(arg);
    result->objects.push_back(arg);
  } else if (tensorflow::swig::IsAttrs(arg)) {
    absl::StrAppend(&result->str, kAttrs);
    tensorflow::Safe_PyObjectPtr attrs(
        PyObject_GetAttrString(arg, "__attrs_attrs__"));
    tensorflow::Safe_PyObjectPtr iter(PyObject_GetIter(attrs.get()));
    for (tensorflow::Safe_PyObjectPtr item(PyIter_Next(iter.get())); item;
         item.reset(PyIter_Next(iter.get()))) {
      tensorflow::Safe_PyObjectPtr name(
          PyObject_GetAttrString(item.get(), "name"));
      tensorflow::Safe_PyObjectPtr attr_arg(PyObject_GetAttr(arg, name.get()));
      TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(
          attr_arg.get(), include_tensor_ranks_only, result));
    }
    absl::StrAppend(&result->str, kAttrsEnd);
  } else {
    // Fall back to a weak reference to the object; if the object doesn't
    // support weak references, hold a strong reference instead.
    PyObject* object = PyWeakref_NewRef(arg, nullptr);

    if (object == nullptr) {
      PyErr_Clear();

      object = arg;
      Py_INCREF(object);
    }

    absl::StrAppend(&result->str, kRaw);
    result->objects.push_back(object);
  }

  return tensorflow::Status::OK();
}
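
// Example (illustrative): for arg = {"a": <DT_FLOAT [2] EagerTensor>}, the
// helper appends kDict, then the encoding of the key "a" (a kRaw entry whose
// weak reference is stored in result->objects), then kTensor followed by the
// tensor's dtype/shape encoding from TFE_Py_EncodeTensor above.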

}  // namespace

// `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes
// are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes
// are used both for performance, since much TensorFlow code specializes on
// known shapes to produce slimmer graphs, and for correctness, since some
// high-level APIs require shapes to be fully known.
//
// `include_tensor_ranks_only` allows caching on arguments excluding shape
// info, so that a slow path using relaxed shapes can rely on a cache key that
// excludes shapes.
PyObject* TFE_Py_EncodeArg(PyObject* arg, bool include_tensor_ranks_only) {
  EncodeResult result;
  const auto status =
      TFE_Py_EncodeArgHelper(arg, include_tensor_ranks_only, &result);
  if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
    return nullptr;
  }

  return result.ToPyTuple();
}
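
// A minimal caller-side sketch (hypothetical caller): the result is a new
// reference to a 2-tuple (encoded_string, objects_or_None) suitable for use
// as a cache key.
//
//   PyObject* key = TFE_Py_EncodeArg(arg, /*include_tensor_ranks_only=*/false);
//   if (key == nullptr) return nullptr;  // Python exception already set.
//   ...hash/compare `key`...
//   Py_DECREF(key);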

// Prints incoming messages directly to Python's stdout using Python's C API.
// This is necessary in Jupyter notebooks and Colabs, where messages to the C
// stdout don't go to the notebook cell outputs, but calls to Python's stdout
// do.
void PrintToPythonStdout(const char* msg) {
  if (Py_IsInitialized()) {
    PyGILState_STATE py_threadstate;
    py_threadstate = PyGILState_Ensure();

    string string_msg = msg;
    // PySys_WriteStdout truncates strings over 1000 bytes, so
    // we write the message in chunks small enough to not be truncated.
    constexpr size_t kChunkSize = 900;
    auto len = string_msg.length();
    for (size_t i = 0; i < len; i += kChunkSize) {
      PySys_WriteStdout("%s", string_msg.substr(i, kChunkSize).c_str());
    }

    // Force flushing to make sure print newlines aren't interleaved in
    // some colab environments.
    PyRun_SimpleString("import sys; sys.stdout.flush()");

    PyGILState_Release(py_threadstate);
  }
}
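
// Example (illustrative): a 2000-byte message is emitted as three chunks of
// 900, 900 and 200 bytes (offsets 0, 900 and 1800), each safely below the
// ~1000-byte PySys_WriteStdout truncation limit.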

// Registers PrintToPythonStdout as a log listener so that log messages show
// up in Colabs and Jupyter notebooks.
void TFE_Py_EnableInteractivePythonLogging() {
  static bool enabled_interactive_logging = false;
  if (!enabled_interactive_logging) {
    enabled_interactive_logging = true;
    TF_RegisterLogListener(PrintToPythonStdout);
  }
}

namespace {
// Weak reference to the currently active Python eager Context object.
PyObject* weak_eager_context = nullptr;
}  // namespace

PyObject* TFE_Py_SetEagerContext(PyObject* py_context) {
  Py_XDECREF(weak_eager_context);
  weak_eager_context = PyWeakref_NewRef(py_context, nullptr);
  if (weak_eager_context == nullptr) {
    return nullptr;
  }
  Py_RETURN_NONE;
}

PyObject* GetPyEagerContext() {
  if (weak_eager_context == nullptr) {
    PyErr_SetString(PyExc_RuntimeError, "Python eager context is not set");
    return nullptr;
  }
  PyObject* py_context = PyWeakref_GET_OBJECT(weak_eager_context);
  if (py_context == Py_None) {
    PyErr_SetString(PyExc_RuntimeError, "Eager context has been destroyed");
    return nullptr;
  }
  Py_INCREF(py_context);
  return py_context;
}
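
// A minimal caller-side sketch (hypothetical caller): GetPyEagerContext
// returns a new reference, or nullptr with a Python error set.
//
//   PyObject* ctx = GetPyEagerContext();
//   if (ctx == nullptr) return nullptr;  // Propagate the Python error.
//   ...use ctx...
//   Py_DECREF(ctx);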

namespace {

// Default values for thread_local_data fields.
struct EagerContextThreadLocalDataDefaults {
  tensorflow::Safe_PyObjectPtr is_eager;
  tensorflow::Safe_PyObjectPtr device_spec;
};

// Maps each py_eager_context object to its thread_local_data.
//
// Note: we need to use the python Context object as the key here (and not
// its handle object), because the handle object isn't created until the
// context is initialized; but thread_local_data is potentially accessed
// before then.
using EagerContextThreadLocalDataMap = absl::flat_hash_map<
    PyObject*, std::unique_ptr<tensorflow::EagerContextThreadLocalData>>;
thread_local EagerContextThreadLocalDataMap*
    eager_context_thread_local_data_map = nullptr;

// Maps each py_eager_context object to default values.
using EagerContextThreadLocalDataDefaultsMap =
    absl::flat_hash_map<PyObject*, EagerContextThreadLocalDataDefaults>;
EagerContextThreadLocalDataDefaultsMap*
    eager_context_thread_local_data_defaults = nullptr;

}  // namespace

namespace tensorflow {

void MakeEagerContextThreadLocalData(PyObject* py_eager_context,
                                     PyObject* is_eager,
                                     PyObject* device_spec) {
  DCheckPyGilState();
  if (eager_context_thread_local_data_defaults == nullptr) {
    absl::LeakCheckDisabler disabler;
    eager_context_thread_local_data_defaults =
        new EagerContextThreadLocalDataDefaultsMap();
  }
  if (eager_context_thread_local_data_defaults->count(py_eager_context) > 0) {
    PyErr_SetString(PyExc_AssertionError,
                    "MakeEagerContextThreadLocalData may not be called "
                    "twice on the same eager Context object.");
    return;
  }

  auto& defaults =
      (*eager_context_thread_local_data_defaults)[py_eager_context];
  Py_INCREF(is_eager);
  defaults.is_eager.reset(is_eager);
  Py_INCREF(device_spec);
  defaults.device_spec.reset(device_spec);
}
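
// A minimal usage sketch (hypothetical Python objects): called once per eager
// Context, with a callable that reports eager mode and a default device spec.
//
//   tensorflow::MakeEagerContextThreadLocalData(py_ctx, py_is_eager_fn,
//                                               py_device_spec);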

EagerContextThreadLocalData* GetEagerContextThreadLocalData(
    PyObject* py_eager_context) {
  if (eager_context_thread_local_data_defaults == nullptr) {
    PyErr_SetString(PyExc_AssertionError,
                    "MakeEagerContextThreadLocalData must be called "
                    "before GetEagerContextThreadLocalData.");
    return nullptr;
  }
  auto defaults =
      eager_context_thread_local_data_defaults->find(py_eager_context);
  if (defaults == eager_context_thread_local_data_defaults->end()) {
    PyErr_SetString(PyExc_AssertionError,
                    "MakeEagerContextThreadLocalData must be called "
                    "before GetEagerContextThreadLocalData.");
    return nullptr;
  }

  if (eager_context_thread_local_data_map == nullptr) {
    absl::LeakCheckDisabler disabler;
    eager_context_thread_local_data_map = new EagerContextThreadLocalDataMap();
  }
  auto& thread_local_data =
      (*eager_context_thread_local_data_map)[py_eager_context];

  if (!thread_local_data) {
    thread_local_data.reset(new EagerContextThreadLocalData());

    Safe_PyObjectPtr is_eager(PyObject_CallFunctionObjArgs(
        defaults->second.is_eager.get(), nullptr));
    if (!is_eager) {
      // Drop the partially-initialized entry so a later call can retry.
      eager_context_thread_local_data_map->erase(py_eager_context);
      return nullptr;
    }
    thread_local_data->is_eager = PyObject_IsTrue(is_eager.get());

#if PY_MAJOR_VERSION >= 3
    PyObject* scope_name = PyUnicode_FromString("");
#else
    PyObject* scope_name = PyString_FromString("");
#endif
    thread_local_data->scope_name.reset(scope_name);

#if PY_MAJOR_VERSION >= 3
    PyObject* device_name = PyUnicode_FromString("");
#else
    PyObject* device_name = PyString_FromString("");
#endif
    thread_local_data->device_name.reset(device_name);

    Py_INCREF(defaults->second.device_spec.get());
    thread_local_data->device_spec.reset(defaults->second.device_spec.get());

    Py_INCREF(Py_None);
    thread_local_data->function_call_options.reset(Py_None);

    Py_INCREF(Py_None);
    thread_local_data->executor.reset(Py_None);

    thread_local_data->op_callbacks.reset(PyList_New(0));
  }
  return thread_local_data.get();
}
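
// A minimal lifecycle sketch (hypothetical Python objects): defaults are
// registered once, per-thread state is then materialized lazily, and
// DestroyEagerContextThreadLocalData (below) drops both when the Context dies.
//
//   tensorflow::MakeEagerContextThreadLocalData(py_ctx, py_is_eager_fn,
//                                               py_device_spec);
//   auto* tld = tensorflow::GetEagerContextThreadLocalData(py_ctx);
//   if (tld != nullptr && tld->is_eager) { ... }
//   tensorflow::DestroyEagerContextThreadLocalData(py_ctx);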

void DestroyEagerContextThreadLocalData(PyObject* py_eager_context) {
  DCheckPyGilState();
  if (eager_context_thread_local_data_defaults) {
    eager_context_thread_local_data_defaults->erase(py_eager_context);
  }
  if (eager_context_thread_local_data_map) {
    eager_context_thread_local_data_map->erase(py_eager_context);
  }
}

}  // namespace tensorflow