• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <atomic>
#include <cstring>
#include <memory>
#include <unordered_map>
#include <vector>

#include "absl/debugging/leak_check.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tape.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/managed_stack_trace.h"
#include "tensorflow/python/eager/pywrap_gradient_exclusions.h"
#include "tensorflow/python/eager/pywrap_tensor.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/py_util.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
#include "tensorflow/python/util/stack_trace.h"
#include "tensorflow/python/util/util.h"
57 
58 using tensorflow::Status;
59 using tensorflow::string;
60 using tensorflow::strings::Printf;
61 
62 namespace {
// NOTE: Ops are taken from and returned to the unique_ptrs in this map, which
// act as per-thread arenas. This matters when the same thread requests two ops
// without returning the first one in between. The following sequence of events
// on the same thread still succeeds:
// - GetOp    <- Returns the cached op.
// - GetOp    <- Allocates and returns a new op.
// - ReturnOp <- Stores the op in the unique_ptr.
// - ReturnOp <- Stores the op in the unique_ptr, deleting the previous one.
// This occurs when a PyFunc kernel is run. This behavior makes the cache safe
// in that case, and also when Python reuses the same underlying C++ thread for
// two Python threads.
74 struct OpDeleter {
operator ()__anon0e39441a0111::OpDeleter75   void operator()(TFE_Op* op) const { TFE_DeleteOp(op); }
76 };
77 thread_local std::unordered_map<TFE_Context*,
78                                 std::unique_ptr<TFE_Op, OpDeleter>>
79     thread_local_eager_operation_map;                             // NOLINT
80 thread_local std::unique_ptr<TF_Status> thread_local_tf_status =  // NOLINT
81     nullptr;
82 
ReleaseThreadLocalOp(TFE_Context * ctx)83 std::unique_ptr<TFE_Op, OpDeleter> ReleaseThreadLocalOp(TFE_Context* ctx) {
84   auto it = thread_local_eager_operation_map.find(ctx);
85   if (it == thread_local_eager_operation_map.end()) {
86     return nullptr;
87   }
88   return std::move(it->second);
89 }
90 
GetOp(TFE_Context * ctx,const char * op_or_function_name,const char * raw_device_name,TF_Status * status)91 TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
92               const char* raw_device_name, TF_Status* status) {
93   auto op = ReleaseThreadLocalOp(ctx);
94   if (!op) {
95     op.reset(tensorflow::wrap(tensorflow::unwrap(ctx)->CreateOperation()));
96   }
97   status->status =
98       tensorflow::unwrap(op.get())->Reset(op_or_function_name, raw_device_name);
99   if (!status->status.ok()) {
100     op.reset();
101   }
102   return op.release();
103 }
104 
ReturnOp(TFE_Context * ctx,TFE_Op * op)105 void ReturnOp(TFE_Context* ctx, TFE_Op* op) {
106   if (op) {
107     tensorflow::unwrap(op)->Clear();
108     thread_local_eager_operation_map[ctx].reset(op);
109   }
110 }
111 
ReleaseThreadLocalStatus()112 TF_Status* ReleaseThreadLocalStatus() {
113   if (thread_local_tf_status == nullptr) {
114     return nullptr;
115   }
116   return thread_local_tf_status.release();
117 }
118 
// One op input that uses a particular type attr. `i` is the input's index in
// the OpDef. `is_list` marks an input that has a number_attr, i.e. a list
// input.
struct InputInfo {
  InputInfo(int index, bool list_input) : i(index), is_list(list_input) {}

  int i;
  bool is_list = false;
};
125 
126 // Takes in output gradients, returns input gradients.
127 typedef std::function<PyObject*(PyObject*, const std::vector<int64_t>&)>
128     PyBackwardFunction;
129 
130 using AttrToInputsMap =
131     tensorflow::gtl::FlatMap<string,
132                              tensorflow::gtl::InlinedVector<InputInfo, 4>>;
133 
GetAllAttrToInputsMaps()134 tensorflow::gtl::FlatMap<string, AttrToInputsMap*>* GetAllAttrToInputsMaps() {
135   static auto* all_attr_to_input_maps =
136       new tensorflow::gtl::FlatMap<string, AttrToInputsMap*>;
137   return all_attr_to_input_maps;
138 }
139 
140 // This function doesn't use a lock, since we depend on the GIL directly.
GetAttrToInputsMapHoldingGIL(const tensorflow::OpDef & op_def)141 AttrToInputsMap* GetAttrToInputsMapHoldingGIL(const tensorflow::OpDef& op_def) {
142 #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4
143   DCHECK(PyGILState_Check())
144       << "This function needs to hold the GIL when called.";
145 #endif
146   auto* all_attr_to_input_maps = GetAllAttrToInputsMaps();
147   auto* output =
148       tensorflow::gtl::FindPtrOrNull(*all_attr_to_input_maps, op_def.name());
149   if (output != nullptr) {
150     return output;
151   }
152 
153   std::unique_ptr<AttrToInputsMap> m(new AttrToInputsMap);
154 
155   // Store a list of InputIndex -> List of corresponding inputs.
156   for (int i = 0; i < op_def.input_arg_size(); i++) {
157     if (!op_def.input_arg(i).type_attr().empty()) {
158       auto it = m->find(op_def.input_arg(i).type_attr());
159       if (it == m->end()) {
160         it = m->insert({op_def.input_arg(i).type_attr(), {}}).first;
161       }
162       it->second.emplace_back(i, !op_def.input_arg(i).number_attr().empty());
163     }
164   }
165 
166   auto* retval = m.get();
167   (*all_attr_to_input_maps)[op_def.name()] = m.release();
168 
169   return retval;
170 }
171 
172 // This function doesn't use a lock, since we depend on the GIL directly.
173 tensorflow::gtl::FlatMap<
174     string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>*
GetAllAttrToDefaultsMaps()175 GetAllAttrToDefaultsMaps() {
176   static auto* all_attr_to_defaults_maps = new tensorflow::gtl::FlatMap<
177       string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>;
178   return all_attr_to_defaults_maps;
179 }
180 
181 tensorflow::gtl::FlatMap<string, tensorflow::DataType>*
GetAttrToDefaultsMapHoldingGIL(const tensorflow::OpDef & op_def)182 GetAttrToDefaultsMapHoldingGIL(const tensorflow::OpDef& op_def) {
183 #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4
184   DCHECK(PyGILState_Check())
185       << "This function needs to hold the GIL when called.";
186 #endif
187   auto* all_attr_to_defaults_maps = GetAllAttrToDefaultsMaps();
188   auto* output =
189       tensorflow::gtl::FindPtrOrNull(*all_attr_to_defaults_maps, op_def.name());
190   if (output != nullptr) {
191     return output;
192   }
193 
194   auto* new_map = new tensorflow::gtl::FlatMap<string, tensorflow::DataType>;
195 
196   for (const auto& attr : op_def.attr()) {
197     if (attr.type() == "type" && attr.has_default_value()) {
198       new_map->insert({attr.name(), attr.default_value().type()});
199     }
200   }
201 
202   (*all_attr_to_defaults_maps)[op_def.name()] = new_map;
203 
204   return new_map;
205 }
206 
207 struct FastPathOpExecInfo {
208   TFE_Context* ctx;
209   const char* device_name;
210 
211   bool run_callbacks;
212   bool run_post_exec_callbacks;
213   bool run_gradient_callback;
214 
215   // The op name of the main op being executed.
216   PyObject* name;
217   // The op type name of the main op being executed.
218   PyObject* op_name;
219   PyObject* callbacks;
220 
221   // All the args passed into the FastPathOpExecInfo.
222   PyObject* args;
223 
224   // DTypes can come from another input that has the same attr. So build that
225   // map.
226   const AttrToInputsMap* attr_to_inputs_map;
227   const tensorflow::gtl::FlatMap<string, tensorflow::DataType>* default_dtypes;
228   tensorflow::gtl::FlatMap<string, tensorflow::DataType> cached_dtypes;
229 };
230 
231 #define PARSE_VALUE(fn_name, type, check_fn, parse_fn)                       \
232   bool fn_name(const string& key, PyObject* py_value, TF_Status* status,     \
233                type* value) {                                                \
234     if (check_fn(py_value)) {                                                \
235       *value = static_cast<type>(parse_fn(py_value));                        \
236       return true;                                                           \
237     } else {                                                                 \
238       TF_SetStatus(status, TF_INVALID_ARGUMENT,                              \
239                    tensorflow::strings::StrCat(                              \
240                        "Expecting " #type " value for attr ", key, ", got ", \
241                        py_value->ob_type->tp_name)                           \
242                        .c_str());                                            \
243       return false;                                                          \
244     }                                                                        \
245   }
246 
247 #if PY_MAJOR_VERSION >= 3
PARSE_VALUE(ParseIntValue,int,PyLong_Check,PyLong_AsLong)248 PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
249 PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLongLong)
250 #else
251 PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
252 #endif
253 PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)
254 #undef PARSE_VALUE
255 
256 #if PY_MAJOR_VERSION < 3
257 bool ParseInt64Value(const string& key, PyObject* py_value, TF_Status* status,
258                      int64_t* value) {
259   if (PyInt_Check(py_value)) {
260     *value = static_cast<int64_t>(PyInt_AsLong(py_value));
261     return true;
262   } else if (PyLong_Check(py_value)) {
263     *value = static_cast<int64_t>(PyLong_AsLong(py_value));
264     return true;
265   }
266   TF_SetStatus(
267       status, TF_INVALID_ARGUMENT,
268       tensorflow::strings::StrCat("Expecting int or long value for attr ", key,
269                                   ", got ", py_value->ob_type->tp_name)
270           .c_str());
271   return false;
272 }
273 #endif
274 
TensorShapeNumDims(PyObject * value)275 Py_ssize_t TensorShapeNumDims(PyObject* value) {
276   const auto size = PySequence_Size(value);
277   if (size == -1) {
278     // TensorShape.__len__ raises an error in the scenario where the shape is an
279     // unknown, which needs to be cleared.
280     // TODO(nareshmodi): ensure that this is actually a TensorShape.
281     PyErr_Clear();
282   }
283   return size;
284 }
285 
IsInteger(PyObject * py_value)286 bool IsInteger(PyObject* py_value) {
287 #if PY_MAJOR_VERSION >= 3
288   return PyLong_Check(py_value);
289 #else
290   return PyInt_Check(py_value) || PyLong_Check(py_value);
291 #endif
292 }
293 
294 // This function considers a Dimension._value of None to be valid, and sets the
295 // value to be -1 in that case.
ParseDimensionValue(const string & key,PyObject * py_value,TF_Status * status,int64_t * value)296 bool ParseDimensionValue(const string& key, PyObject* py_value,
297                          TF_Status* status, int64_t* value) {
298   if (IsInteger(py_value)) {
299     return ParseInt64Value(key, py_value, status, value);
300   }
301 
302   tensorflow::Safe_PyObjectPtr dimension_value(
303       PyObject_GetAttrString(py_value, "_value"));
304   if (dimension_value == nullptr) {
305     PyErr_Clear();
306     TF_SetStatus(
307         status, TF_INVALID_ARGUMENT,
308         tensorflow::strings::StrCat("Expecting a Dimension for attr ", key,
309                                     ", got ", py_value->ob_type->tp_name)
310             .c_str());
311     return false;
312   }
313 
314   if (dimension_value.get() == Py_None) {
315     *value = -1;
316     return true;
317   }
318 
319   return ParseInt64Value(key, dimension_value.get(), status, value);
320 }
321 
ParseStringValue(const string & key,PyObject * py_value,TF_Status * status,tensorflow::StringPiece * value)322 bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status,
323                       tensorflow::StringPiece* value) {
324   if (PyBytes_Check(py_value)) {
325     Py_ssize_t size = 0;
326     char* buf = nullptr;
327     if (PyBytes_AsStringAndSize(py_value, &buf, &size) < 0) return false;
328     *value = tensorflow::StringPiece(buf, size);
329     return true;
330   }
331 #if PY_MAJOR_VERSION >= 3
332   if (PyUnicode_Check(py_value)) {
333     Py_ssize_t size = 0;
334     const char* buf = PyUnicode_AsUTF8AndSize(py_value, &size);
335     if (buf == nullptr) return false;
336     *value = tensorflow::StringPiece(buf, size);
337     return true;
338   }
339 #endif
340   TF_SetStatus(
341       status, TF_INVALID_ARGUMENT,
342       tensorflow::strings::StrCat("Expecting a string value for attr ", key,
343                                   ", got ", py_value->ob_type->tp_name)
344           .c_str());
345   return false;
346 }
347 
ParseBoolValue(const string & key,PyObject * py_value,TF_Status * status,unsigned char * value)348 bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status,
349                     unsigned char* value) {
350   if (PyBool_Check(py_value)) {
351     *value = PyObject_IsTrue(py_value);
352     return true;
353   }
354   TF_SetStatus(
355       status, TF_INVALID_ARGUMENT,
356       tensorflow::strings::StrCat("Expecting bool value for attr ", key,
357                                   ", got ", py_value->ob_type->tp_name)
358           .c_str());
359   return false;
360 }
361 
362 // The passed in py_value is expected to be an object of the python type
363 // dtypes.DType or an int.
ParseTypeValue(const string & key,PyObject * py_value,TF_Status * status,int * value)364 bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status,
365                     int* value) {
366   if (IsInteger(py_value)) {
367     return ParseIntValue(key, py_value, status, value);
368   }
369 
370   tensorflow::Safe_PyObjectPtr py_type_enum(
371       PyObject_GetAttrString(py_value, "_type_enum"));
372   if (py_type_enum == nullptr) {
373     PyErr_Clear();
374     TF_SetStatus(
375         status, TF_INVALID_ARGUMENT,
376         tensorflow::strings::StrCat("Expecting a DType.dtype for attr ", key,
377                                     ", got ", py_value->ob_type->tp_name)
378             .c_str());
379     return false;
380   }
381 
382   return ParseIntValue(key, py_type_enum.get(), status, value);
383 }
384 
SetOpAttrList(TFE_Context * ctx,TFE_Op * op,const char * key,PyObject * py_list,TF_AttrType type,tensorflow::gtl::FlatMap<string,int64_t> * attr_list_sizes,TF_Status * status)385 bool SetOpAttrList(TFE_Context* ctx, TFE_Op* op, const char* key,
386                    PyObject* py_list, TF_AttrType type,
387                    tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
388                    TF_Status* status) {
389   if (!PySequence_Check(py_list)) {
390     TF_SetStatus(
391         status, TF_INVALID_ARGUMENT,
392         tensorflow::strings::StrCat("Expecting sequence value for attr ", key,
393                                     ", got ", py_list->ob_type->tp_name)
394             .c_str());
395     return false;
396   }
397   const int num_values = PySequence_Size(py_list);
398   if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values;
399 
400 #define PARSE_LIST(c_type, parse_fn)                                      \
401   std::unique_ptr<c_type[]> values(new c_type[num_values]);               \
402   for (int i = 0; i < num_values; ++i) {                                  \
403     tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));   \
404     if (!parse_fn(key, py_value.get(), status, &values[i])) return false; \
405   }
406 
407   if (type == TF_ATTR_STRING) {
408     std::unique_ptr<const void*[]> values(new const void*[num_values]);
409     std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
410     for (int i = 0; i < num_values; ++i) {
411       tensorflow::StringPiece value;
412       tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
413       if (!ParseStringValue(key, py_value.get(), status, &value)) return false;
414       values[i] = value.data();
415       lengths[i] = value.size();
416     }
417     TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
418   } else if (type == TF_ATTR_INT) {
419     PARSE_LIST(int64_t, ParseInt64Value);
420     TFE_OpSetAttrIntList(op, key, values.get(), num_values);
421   } else if (type == TF_ATTR_FLOAT) {
422     PARSE_LIST(float, ParseFloatValue);
423     TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
424   } else if (type == TF_ATTR_BOOL) {
425     PARSE_LIST(unsigned char, ParseBoolValue);
426     TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
427   } else if (type == TF_ATTR_TYPE) {
428     PARSE_LIST(int, ParseTypeValue);
429     TFE_OpSetAttrTypeList(op, key,
430                           reinterpret_cast<const TF_DataType*>(values.get()),
431                           num_values);
432   } else if (type == TF_ATTR_SHAPE) {
433     // Make one pass through the input counting the total number of
434     // dims across all the input lists.
435     int total_dims = 0;
436     for (int i = 0; i < num_values; ++i) {
437       tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
438       if (py_value.get() != Py_None) {
439         if (!PySequence_Check(py_value.get())) {
440           TF_SetStatus(
441               status, TF_INVALID_ARGUMENT,
442               tensorflow::strings::StrCat(
443                   "Expecting None or sequence value for element", i,
444                   " of attr ", key, ", got ", py_value->ob_type->tp_name)
445                   .c_str());
446           return false;
447         }
448         const auto size = TensorShapeNumDims(py_value.get());
449         if (size >= 0) {
450           total_dims += size;
451         }
452       }
453     }
454     // Allocate a buffer that can fit all of the dims together.
455     std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
456     // Copy the input dims into the buffer and set dims to point to
457     // the start of each list's dims.
458     std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
459     std::unique_ptr<int[]> num_dims(new int[num_values]);
460     int64_t* offset = buffer.get();
461     for (int i = 0; i < num_values; ++i) {
462       tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
463       if (py_value.get() == Py_None) {
464         dims[i] = nullptr;
465         num_dims[i] = -1;
466       } else {
467         const auto size = TensorShapeNumDims(py_value.get());
468         if (size == -1) {
469           dims[i] = nullptr;
470           num_dims[i] = -1;
471           continue;
472         }
473         dims[i] = offset;
474         num_dims[i] = size;
475         for (int j = 0; j < size; ++j) {
476           tensorflow::Safe_PyObjectPtr inner_py_value(
477               PySequence_ITEM(py_value.get(), j));
478           if (inner_py_value.get() == Py_None) {
479             *offset = -1;
480           } else if (!ParseDimensionValue(key, inner_py_value.get(), status,
481                                           offset)) {
482             return false;
483           }
484           ++offset;
485         }
486       }
487     }
488     TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
489                            status);
490     if (!status->status.ok()) return false;
491   } else if (type == TF_ATTR_FUNC) {
492     std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
493     for (int i = 0; i < num_values; ++i) {
494       tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
495       // Allow:
496       // (1) String function name, OR
497       // (2) A Python object with a .name attribute
498       //     (A crude test for being a
499       //     tensorflow.python.framework.function._DefinedFunction)
500       //     (which is what the various "defun" or "Defun" decorators do).
501       // And in the future also allow an object that can encapsulate
502       // the function name and its attribute values.
503       tensorflow::StringPiece func_name;
504       if (!ParseStringValue(key, py_value.get(), status, &func_name)) {
505         PyObject* name_attr = PyObject_GetAttrString(py_value.get(), "name");
506         if (name_attr == nullptr ||
507             !ParseStringValue(key, name_attr, status, &func_name)) {
508           TF_SetStatus(
509               status, TF_INVALID_ARGUMENT,
510               tensorflow::strings::StrCat(
511                   "unable to set function value attribute from a ",
512                   py_value.get()->ob_type->tp_name,
513                   " object. If you think this is an error, please file an "
514                   "issue at "
515                   "https://github.com/tensorflow/tensorflow/issues/new")
516                   .c_str());
517           return false;
518         }
519       }
520       funcs[i] = TFE_NewOp(ctx, func_name.data(), status);
521       if (!status->status.ok()) return false;
522     }
523     TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
524     if (!status->status.ok()) return false;
525   } else {
526     TF_SetStatus(status, TF_UNIMPLEMENTED,
527                  tensorflow::strings::StrCat("Attr ", key,
528                                              " has unhandled list type ", type)
529                      .c_str());
530     return false;
531   }
532 #undef PARSE_LIST
533   return true;
534 }
535 
GetFunc(TFE_Context * ctx,const tensorflow::NameAttrList & func,TF_Status * status)536 TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
537                 TF_Status* status) {
538   TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
539   for (const auto& attr : func.attr()) {
540     if (!status->status.ok()) return nullptr;
541     SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
542     if (!status->status.ok()) return nullptr;
543   }
544   return func_op;
545 }
546 
SetOpAttrListDefault(TFE_Context * ctx,TFE_Op * op,const tensorflow::OpDef::AttrDef & attr,const char * key,TF_AttrType type,tensorflow::gtl::FlatMap<string,int64_t> * attr_list_sizes,TF_Status * status)547 void SetOpAttrListDefault(
548     TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
549     const char* key, TF_AttrType type,
550     tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
551     TF_Status* status) {
552   if (type == TF_ATTR_STRING) {
553     int num_values = attr.default_value().list().s_size();
554     std::unique_ptr<const void*[]> values(new const void*[num_values]);
555     std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
556     (*attr_list_sizes)[key] = num_values;
557     for (int i = 0; i < num_values; i++) {
558       const string& v = attr.default_value().list().s(i);
559       values[i] = v.data();
560       lengths[i] = v.size();
561     }
562     TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
563   } else if (type == TF_ATTR_INT) {
564     int num_values = attr.default_value().list().i_size();
565     std::unique_ptr<int64_t[]> values(new int64_t[num_values]);
566     (*attr_list_sizes)[key] = num_values;
567     for (int i = 0; i < num_values; i++) {
568       values[i] = attr.default_value().list().i(i);
569     }
570     TFE_OpSetAttrIntList(op, key, values.get(), num_values);
571   } else if (type == TF_ATTR_FLOAT) {
572     int num_values = attr.default_value().list().f_size();
573     std::unique_ptr<float[]> values(new float[num_values]);
574     (*attr_list_sizes)[key] = num_values;
575     for (int i = 0; i < num_values; i++) {
576       values[i] = attr.default_value().list().f(i);
577     }
578     TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
579   } else if (type == TF_ATTR_BOOL) {
580     int num_values = attr.default_value().list().b_size();
581     std::unique_ptr<unsigned char[]> values(new unsigned char[num_values]);
582     (*attr_list_sizes)[key] = num_values;
583     for (int i = 0; i < num_values; i++) {
584       values[i] = attr.default_value().list().b(i);
585     }
586     TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
587   } else if (type == TF_ATTR_TYPE) {
588     int num_values = attr.default_value().list().type_size();
589     std::unique_ptr<int[]> values(new int[num_values]);
590     (*attr_list_sizes)[key] = num_values;
591     for (int i = 0; i < num_values; i++) {
592       values[i] = attr.default_value().list().type(i);
593     }
594     TFE_OpSetAttrTypeList(op, key,
595                           reinterpret_cast<const TF_DataType*>(values.get()),
596                           attr.default_value().list().type_size());
597   } else if (type == TF_ATTR_SHAPE) {
598     int num_values = attr.default_value().list().shape_size();
599     (*attr_list_sizes)[key] = num_values;
600     int total_dims = 0;
601     for (int i = 0; i < num_values; ++i) {
602       if (!attr.default_value().list().shape(i).unknown_rank()) {
603         total_dims += attr.default_value().list().shape(i).dim_size();
604       }
605     }
606     // Allocate a buffer that can fit all of the dims together.
607     std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
608     // Copy the input dims into the buffer and set dims to point to
609     // the start of each list's dims.
610     std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
611     std::unique_ptr<int[]> num_dims(new int[num_values]);
612     int64_t* offset = buffer.get();
613     for (int i = 0; i < num_values; ++i) {
614       const auto& shape = attr.default_value().list().shape(i);
615       if (shape.unknown_rank()) {
616         dims[i] = nullptr;
617         num_dims[i] = -1;
618       } else {
619         for (int j = 0; j < shape.dim_size(); j++) {
620           *offset = shape.dim(j).size();
621           ++offset;
622         }
623       }
624     }
625     TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
626                            status);
627   } else if (type == TF_ATTR_FUNC) {
628     int num_values = attr.default_value().list().func_size();
629     (*attr_list_sizes)[key] = num_values;
630     std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
631     for (int i = 0; i < num_values; i++) {
632       funcs[i] = GetFunc(ctx, attr.default_value().list().func(i), status);
633     }
634     TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
635   } else {
636     TF_SetStatus(status, TF_UNIMPLEMENTED,
637                  "Lists of tensors are not yet implemented for default valued "
638                  "attributes for an operation.");
639   }
640 }
641 
SetOpAttrScalar(TFE_Context * ctx,TFE_Op * op,const char * key,PyObject * py_value,TF_AttrType type,tensorflow::gtl::FlatMap<string,int64_t> * attr_list_sizes,TF_Status * status)642 bool SetOpAttrScalar(TFE_Context* ctx, TFE_Op* op, const char* key,
643                      PyObject* py_value, TF_AttrType type,
644                      tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
645                      TF_Status* status) {
646   if (type == TF_ATTR_STRING) {
647     tensorflow::StringPiece value;
648     if (!ParseStringValue(key, py_value, status, &value)) return false;
649     TFE_OpSetAttrString(op, key, value.data(), value.size());
650   } else if (type == TF_ATTR_INT) {
651     int64_t value;
652     if (!ParseInt64Value(key, py_value, status, &value)) return false;
653     TFE_OpSetAttrInt(op, key, value);
654     // attr_list_sizes is set for all int attributes (since at this point we are
655     // not aware if that attribute might be used to calculate the size of an
656     // output list or not).
657     if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = value;
658   } else if (type == TF_ATTR_FLOAT) {
659     float value;
660     if (!ParseFloatValue(key, py_value, status, &value)) return false;
661     TFE_OpSetAttrFloat(op, key, value);
662   } else if (type == TF_ATTR_BOOL) {
663     unsigned char value;
664     if (!ParseBoolValue(key, py_value, status, &value)) return false;
665     TFE_OpSetAttrBool(op, key, value);
666   } else if (type == TF_ATTR_TYPE) {
667     int value;
668     if (!ParseTypeValue(key, py_value, status, &value)) return false;
669     TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value));
670   } else if (type == TF_ATTR_SHAPE) {
671     if (py_value == Py_None) {
672       TFE_OpSetAttrShape(op, key, nullptr, -1, status);
673     } else {
674       if (!PySequence_Check(py_value)) {
675         TF_SetStatus(status, TF_INVALID_ARGUMENT,
676                      tensorflow::strings::StrCat(
677                          "Expecting None or sequence value for attr", key,
678                          ", got ", py_value->ob_type->tp_name)
679                          .c_str());
680         return false;
681       }
682       const auto num_dims = TensorShapeNumDims(py_value);
683       if (num_dims == -1) {
684         TFE_OpSetAttrShape(op, key, nullptr, -1, status);
685         return true;
686       }
687       std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
688       for (int i = 0; i < num_dims; ++i) {
689         tensorflow::Safe_PyObjectPtr inner_py_value(
690             PySequence_ITEM(py_value, i));
691         // If an error is generated when iterating through object, we can
692         // sometimes get a nullptr.
693         if (inner_py_value.get() == Py_None) {
694           dims[i] = -1;
695         } else if (inner_py_value.get() == nullptr ||
696                    !ParseDimensionValue(key, inner_py_value.get(), status,
697                                         &dims[i])) {
698           return false;
699         }
700       }
701       TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status);
702     }
703     if (!status->status.ok()) return false;
704   } else if (type == TF_ATTR_FUNC) {
705     // Allow:
706     // (1) String function name, OR
707     // (2) A Python object with a .name attribute
708     //     (A crude test for being a
709     //     tensorflow.python.framework.function._DefinedFunction)
710     //     (which is what the various "defun" or "Defun" decorators do).
711     // And in the future also allow an object that can encapsulate
712     // the function name and its attribute values.
713     tensorflow::StringPiece func_name;
714     if (!ParseStringValue(key, py_value, status, &func_name)) {
715       PyObject* name_attr = PyObject_GetAttrString(py_value, "name");
716       if (name_attr == nullptr ||
717           !ParseStringValue(key, name_attr, status, &func_name)) {
718         TF_SetStatus(
719             status, TF_INVALID_ARGUMENT,
720             tensorflow::strings::StrCat(
721                 "unable to set function value attribute from a ",
722                 py_value->ob_type->tp_name,
723                 " object. If you think this is an error, please file an issue "
724                 "at https://github.com/tensorflow/tensorflow/issues/new")
725                 .c_str());
726         return false;
727       }
728     }
729     TF_SetStatus(status, TF_OK, "");
730     TFE_OpSetAttrFunctionName(op, key, func_name.data(), func_name.size());
731   } else {
732     TF_SetStatus(
733         status, TF_UNIMPLEMENTED,
734         tensorflow::strings::StrCat("Attr ", key, " has unhandled type ", type)
735             .c_str());
736     return false;
737   }
738   return true;
739 }
740 
SetOpAttrScalarDefault(TFE_Context * ctx,TFE_Op * op,const tensorflow::AttrValue & default_value,const char * attr_name,tensorflow::gtl::FlatMap<string,int64_t> * attr_list_sizes,TF_Status * status)741 void SetOpAttrScalarDefault(
742     TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value,
743     const char* attr_name,
744     tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
745     TF_Status* status) {
746   SetOpAttrValueScalar(ctx, op, default_value, attr_name, status);
747   if (default_value.value_case() == tensorflow::AttrValue::kI) {
748     (*attr_list_sizes)[attr_name] = default_value.i();
749   }
750 }
751 
752 // start_index is the index at which the Tuple/List attrs will start getting
753 // processed.
SetOpAttrs(TFE_Context * ctx,TFE_Op * op,PyObject * attrs,int start_index,TF_Status * out_status)754 void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index,
755                 TF_Status* out_status) {
756   if (attrs == Py_None) return;
757   Py_ssize_t len = PyTuple_GET_SIZE(attrs) - start_index;
758   if ((len & 1) != 0) {
759     TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
760                  "Expecting attrs tuple to have even length.");
761     return;
762   }
763   // Parse attrs
764   for (Py_ssize_t i = 0; i < len; i += 2) {
765     PyObject* py_key = PyTuple_GET_ITEM(attrs, start_index + i);
766     PyObject* py_value = PyTuple_GET_ITEM(attrs, start_index + i + 1);
767 #if PY_MAJOR_VERSION >= 3
768     const char* key = PyBytes_Check(py_key) ? PyBytes_AsString(py_key)
769                                             : PyUnicode_AsUTF8(py_key);
770 #else
771     const char* key = PyBytes_AsString(py_key);
772 #endif
773     unsigned char is_list = 0;
774     const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status);
775     if (!out_status->status.ok()) return;
776     if (is_list != 0) {
777       if (!SetOpAttrList(ctx, op, key, py_value, type, nullptr, out_status))
778         return;
779     } else {
780       if (!SetOpAttrScalar(ctx, op, key, py_value, type, nullptr, out_status))
781         return;
782     }
783   }
784 }
785 
786 // This function will set the op attrs required. If an attr has the value of
787 // None, then it will read the AttrDef to get the default value and set that
788 // instead. Any failure in this function will simply fall back to the slow
789 // path.
SetOpAttrWithDefaults(TFE_Context * ctx,TFE_Op * op,const tensorflow::OpDef::AttrDef & attr,const char * attr_name,PyObject * attr_value,tensorflow::gtl::FlatMap<string,int64_t> * attr_list_sizes,TF_Status * status)790 void SetOpAttrWithDefaults(
791     TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
792     const char* attr_name, PyObject* attr_value,
793     tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
794     TF_Status* status) {
795   unsigned char is_list = 0;
796   const TF_AttrType type = TFE_OpGetAttrType(op, attr_name, &is_list, status);
797   if (!status->status.ok()) return;
798   if (attr_value == Py_None) {
799     if (is_list != 0) {
800       SetOpAttrListDefault(ctx, op, attr, attr_name, type, attr_list_sizes,
801                            status);
802     } else {
803       SetOpAttrScalarDefault(ctx, op, attr.default_value(), attr_name,
804                              attr_list_sizes, status);
805     }
806   } else {
807     if (is_list != 0) {
808       SetOpAttrList(ctx, op, attr_name, attr_value, type, attr_list_sizes,
809                     status);
810     } else {
811       SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes,
812                       status);
813     }
814   }
815 }
816 
GetPythonObjectFromInt(int num)817 PyObject* GetPythonObjectFromInt(int num) {
818 #if PY_MAJOR_VERSION >= 3
819   return PyLong_FromLong(num);
820 #else
821   return PyInt_FromLong(num);
822 #endif
823 }
824 
// Python subclass of Exception that is raised for a non-OK Status by
// MaybeRaiseExceptionFromTFStatus / MaybeRaiseExceptionFromStatus. Holds an
// owned reference installed by TFE_Py_RegisterExceptionClass; nullptr when no
// class is registered. Guarded by exception_class_mutex.
tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
PyObject* exception_class TF_GUARDED_BY(exception_class_mutex) = nullptr;

// Python subclass of Exception that is created to signal fallback; raised by
// RaiseFallbackException. Holds an owned reference installed by
// TFE_Py_RegisterFallbackExceptionClass; nullptr when unregistered.
// NOTE(review): not mutex-guarded; presumably protected by the GIL — confirm.
PyObject* fallback_exception_class = nullptr;

// Python function that returns input gradients given output gradients.
// Holds an owned reference installed by TFE_Py_RegisterGradientFunction.
PyObject* gradient_function = nullptr;

// Python function that returns output gradients given input gradients.
// Holds an owned reference installed by TFE_Py_RegisterJVPFunction.
PyObject* forward_gradient_function = nullptr;

// Counter behind get_uid(). Static storage is zero-initialized, so the first
// id handed out is 0.
static std::atomic<int64_t> _uid;
839 
// Marks thread_local storage as destroyed. `alive` is true while the marker
// exists, and the destructor flips it to false.
//
// Reading `alive` from an already-destroyed marker is tolerated because the
// marker's only state is a single bool. This holds only as long as no new
// thread_local has since been constructed in the space the marker occupied.
// Creating thread_locals while a thread is being torn down is expected to be
// rare.
struct ThreadLocalDestructionMarker {
  bool alive{true};

  ~ThreadLocalDestructionMarker() { alive = false; }
};
849 
850 }  // namespace
851 
GetStatus()852 TF_Status* GetStatus() {
853   TF_Status* maybe_status = ReleaseThreadLocalStatus();
854   if (maybe_status) {
855     TF_SetStatus(maybe_status, TF_OK, "");
856     return maybe_status;
857   } else {
858     return TF_NewStatus();
859   }
860 }
861 
ReturnStatus(TF_Status * status)862 void ReturnStatus(TF_Status* status) {
863   TF_SetStatus(status, TF_OK, "");
864   thread_local_tf_status.reset(status);
865 }
866 
TFE_Py_Execute(TFE_Context * ctx,const char * device_name,const char * op_name,TFE_InputTensorHandles * inputs,PyObject * attrs,TFE_OutputTensorHandles * outputs,TF_Status * out_status)867 void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
868                     const char* op_name, TFE_InputTensorHandles* inputs,
869                     PyObject* attrs, TFE_OutputTensorHandles* outputs,
870                     TF_Status* out_status) {
871   TFE_Py_ExecuteCancelable(ctx, device_name, op_name, inputs, attrs,
872                            /*cancellation_manager=*/nullptr, outputs,
873                            out_status);
874 }
875 
TFE_Py_ExecuteCancelable(TFE_Context * ctx,const char * device_name,const char * op_name,TFE_InputTensorHandles * inputs,PyObject * attrs,TFE_CancellationManager * cancellation_manager,TFE_OutputTensorHandles * outputs,TF_Status * out_status)876 void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name,
877                               const char* op_name,
878                               TFE_InputTensorHandles* inputs, PyObject* attrs,
879                               TFE_CancellationManager* cancellation_manager,
880                               TFE_OutputTensorHandles* outputs,
881                               TF_Status* out_status) {
882   tensorflow::profiler::TraceMe activity(
883       "TFE_Py_ExecuteCancelable", tensorflow::profiler::TraceMeLevel::kInfo);
884 
885   TFE_Op* op = GetOp(ctx, op_name, device_name, out_status);
886 
887   auto cleaner = tensorflow::gtl::MakeCleanup([ctx, op] { ReturnOp(ctx, op); });
888   if (!out_status->status.ok()) return;
889 
890   tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace(
891       tensorflow::StackTrace::kStackTraceInitialSize));
892 
893   for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) {
894     TFE_OpAddInput(op, inputs->at(i), out_status);
895   }
896   if (cancellation_manager && out_status->status.ok()) {
897     TFE_OpSetCancellationManager(op, cancellation_manager, out_status);
898   }
899   if (out_status->status.ok()) {
900     SetOpAttrs(ctx, op, attrs, 0, out_status);
901   }
902   Py_BEGIN_ALLOW_THREADS;
903 
904   int num_outputs = outputs->size();
905 
906   if (out_status->status.ok()) {
907     TFE_Execute(op, outputs->data(), &num_outputs, out_status);
908   }
909 
910   if (out_status->status.ok()) {
911     outputs->resize(num_outputs);
912   } else {
913     TF_SetStatus(out_status, TF_GetCode(out_status),
914                  tensorflow::strings::StrCat(TF_Message(out_status),
915                                              " [Op:", op_name, "]")
916                      .c_str());
917   }
918 
919   Py_END_ALLOW_THREADS;
920 }
921 
TFE_Py_RegisterExceptionClass(PyObject * e)922 PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
923   tensorflow::mutex_lock l(exception_class_mutex);
924   if (exception_class != nullptr) {
925     Py_DECREF(exception_class);
926   }
927   if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
928     exception_class = nullptr;
929     PyErr_SetString(PyExc_TypeError,
930                     "TFE_Py_RegisterExceptionClass: "
931                     "Registered class should be subclass of Exception.");
932     return nullptr;
933   }
934 
935   Py_INCREF(e);
936   exception_class = e;
937   Py_RETURN_NONE;
938 }
939 
TFE_Py_RegisterFallbackExceptionClass(PyObject * e)940 PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) {
941   if (fallback_exception_class != nullptr) {
942     Py_DECREF(fallback_exception_class);
943   }
944   if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
945     fallback_exception_class = nullptr;
946     PyErr_SetString(PyExc_TypeError,
947                     "TFE_Py_RegisterFallbackExceptionClass: "
948                     "Registered class should be subclass of Exception.");
949     return nullptr;
950   } else {
951     Py_INCREF(e);
952     fallback_exception_class = e;
953     Py_RETURN_NONE;
954   }
955 }
956 
TFE_Py_RegisterGradientFunction(PyObject * e)957 PyObject* TFE_Py_RegisterGradientFunction(PyObject* e) {
958   if (gradient_function != nullptr) {
959     Py_DECREF(gradient_function);
960   }
961   if (!PyCallable_Check(e)) {
962     gradient_function = nullptr;
963     PyErr_SetString(PyExc_TypeError,
964                     "TFE_Py_RegisterGradientFunction: "
965                     "Registered object should be function.");
966     return nullptr;
967   } else {
968     Py_INCREF(e);
969     gradient_function = e;
970     Py_RETURN_NONE;
971   }
972 }
973 
TFE_Py_RegisterJVPFunction(PyObject * e)974 PyObject* TFE_Py_RegisterJVPFunction(PyObject* e) {
975   if (forward_gradient_function != nullptr) {
976     Py_DECREF(forward_gradient_function);
977   }
978   if (!PyCallable_Check(e)) {
979     forward_gradient_function = nullptr;
980     PyErr_SetString(PyExc_TypeError,
981                     "TFE_Py_RegisterJVPFunction: "
982                     "Registered object should be function.");
983     return nullptr;
984   } else {
985     Py_INCREF(e);
986     forward_gradient_function = e;
987     Py_RETURN_NONE;
988   }
989 }
990 
RaiseFallbackException(const char * message)991 void RaiseFallbackException(const char* message) {
992   if (fallback_exception_class != nullptr) {
993     PyErr_SetString(fallback_exception_class, message);
994     return;
995   }
996 
997   PyErr_SetString(
998       PyExc_RuntimeError,
999       tensorflow::strings::StrCat(
1000           "Fallback exception type not set, attempting to fallback due to ",
1001           message)
1002           .data());
1003 }
1004 
1005 // Format and return `status`' error message with the attached stack trace if
1006 // available. `status` must have an error.
FormatErrorStatusStackTrace(const tensorflow::Status & status)1007 std::string FormatErrorStatusStackTrace(const tensorflow::Status& status) {
1008   tensorflow::DCheckPyGilState();
1009   DCHECK(!status.ok());
1010 
1011   std::vector<tensorflow::StackFrame> stack_trace =
1012       tensorflow::errors::GetStackTrace(status);
1013 
1014   if (stack_trace.empty()) return status.error_message();
1015 
1016   PyObject* linecache = PyImport_ImportModule("linecache");
1017   PyObject* getline =
1018       PyObject_GetAttr(linecache, PyUnicode_FromString("getline"));
1019   DCHECK(getline);
1020 
1021   std::ostringstream result;
1022   result << "Exception originated from\n\n";
1023 
1024   for (const tensorflow::StackFrame& stack_frame : stack_trace) {
1025     PyObject* line_str_obj = PyObject_CallFunction(
1026         getline, const_cast<char*>("si"), stack_frame.file_name.c_str(),
1027         stack_frame.line_number);
1028     tensorflow::StringPiece line_str = TFE_GetPythonString(line_str_obj);
1029     tensorflow::str_util::RemoveWhitespaceContext(&line_str);
1030     result << "  File \"" << stack_frame.file_name << "\", line "
1031            << stack_frame.line_number << ", in " << stack_frame.function_name
1032            << '\n';
1033 
1034     if (!line_str.empty()) result << "    " << line_str << '\n';
1035     Py_XDECREF(line_str_obj);
1036   }
1037 
1038   Py_DecRef(getline);
1039   Py_DecRef(linecache);
1040 
1041   result << '\n' << status.error_message();
1042   return result.str();
1043 }
1044 
1045 namespace tensorflow {
1046 
MaybeRaiseExceptionFromTFStatus(TF_Status * status,PyObject * exception)1047 int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
1048   if (status->status.ok()) return 0;
1049   const char* msg = TF_Message(status);
1050   if (exception == nullptr) {
1051     tensorflow::mutex_lock l(exception_class_mutex);
1052     if (exception_class != nullptr) {
1053       tensorflow::Safe_PyObjectPtr payloads(PyDict_New());
1054       for (const auto& payload :
1055            tensorflow::errors::GetPayloads(status->status)) {
1056         PyDict_SetItem(payloads.get(),
1057                        PyBytes_FromString(payload.first.c_str()),
1058                        PyBytes_FromString(payload.second.c_str()));
1059       }
1060       tensorflow::Safe_PyObjectPtr val(Py_BuildValue(
1061           "siO", FormatErrorStatusStackTrace(status->status).c_str(),
1062           TF_GetCode(status), payloads.get()));
1063       if (PyErr_Occurred()) {
1064         // NOTE: This hides the actual error (i.e. the reason `status` was not
1065         // TF_OK), but there is nothing we can do at this point since we can't
1066         // generate a reasonable error from the status.
1067         // Consider adding a message explaining this.
1068         return -1;
1069       }
1070       PyErr_SetObject(exception_class, val.get());
1071       return -1;
1072     } else {
1073       exception = PyExc_RuntimeError;
1074     }
1075   }
1076   // May be update already set exception.
1077   PyErr_SetString(exception, msg);
1078   return -1;
1079 }
1080 
1081 }  // namespace tensorflow
1082 
MaybeRaiseExceptionFromStatus(const tensorflow::Status & status,PyObject * exception)1083 int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
1084                                   PyObject* exception) {
1085   if (status.ok()) return 0;
1086   const char* msg = status.error_message().c_str();
1087   if (exception == nullptr) {
1088     tensorflow::mutex_lock l(exception_class_mutex);
1089     if (exception_class != nullptr) {
1090       tensorflow::Safe_PyObjectPtr payloads(PyDict_New());
1091       for (const auto& element : tensorflow::errors::GetPayloads(status)) {
1092         PyDict_SetItem(payloads.get(),
1093                        PyBytes_FromString(element.first.c_str()),
1094                        PyBytes_FromString(element.second.c_str()));
1095       }
1096       tensorflow::Safe_PyObjectPtr val(
1097           Py_BuildValue("siO", FormatErrorStatusStackTrace(status).c_str(),
1098                         status.code(), payloads.get()));
1099       PyErr_SetObject(exception_class, val.get());
1100       return -1;
1101     } else {
1102       exception = PyExc_RuntimeError;
1103     }
1104   }
1105   // May be update already set exception.
1106   PyErr_SetString(exception, msg);
1107   return -1;
1108 }
1109 
TFE_GetPythonString(PyObject * o)1110 const char* TFE_GetPythonString(PyObject* o) {
1111 #if PY_MAJOR_VERSION >= 3
1112   if (PyBytes_Check(o)) {
1113     return PyBytes_AsString(o);
1114   } else {
1115     return PyUnicode_AsUTF8(o);
1116   }
1117 #else
1118   return PyBytes_AsString(o);
1119 #endif
1120 }
1121 
get_uid()1122 int64_t get_uid() { return _uid++; }
1123 
TFE_Py_UID()1124 PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }
1125 
TFE_DeleteContextCapsule(PyObject * context)1126 void TFE_DeleteContextCapsule(PyObject* context) {
1127   TFE_Context* ctx =
1128       reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr));
1129   auto op = ReleaseThreadLocalOp(ctx);
1130   op.reset();
1131   TFE_DeleteContext(ctx);
1132 }
1133 
MakeInt(PyObject * integer)1134 static int64_t MakeInt(PyObject* integer) {
1135 #if PY_MAJOR_VERSION >= 3
1136   return PyLong_AsLong(integer);
1137 #else
1138   return PyInt_AsLong(integer);
1139 #endif
1140 }
1141 
FastTensorId(PyObject * tensor)1142 static int64_t FastTensorId(PyObject* tensor) {
1143   if (EagerTensor_CheckExact(tensor)) {
1144     return PyEagerTensor_ID(tensor);
1145   }
1146   PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
1147   if (id_field == nullptr) {
1148     return -1;
1149   }
1150   int64_t id = MakeInt(id_field);
1151   Py_DECREF(id_field);
1152   return id;
1153 }
1154 
1155 namespace tensorflow {
PyTensor_DataType(PyObject * tensor)1156 DataType PyTensor_DataType(PyObject* tensor) {
1157   if (EagerTensor_CheckExact(tensor)) {
1158     return PyEagerTensor_Dtype(tensor);
1159   } else {
1160 #if PY_MAJOR_VERSION < 3
1161     // Python 2.x:
1162     static PyObject* dtype_attr = PyString_InternFromString("dtype");
1163     static PyObject* type_enum_attr = PyString_InternFromString("_type_enum");
1164 #else
1165     // Python 3.x:
1166     static PyObject* dtype_attr = PyUnicode_InternFromString("dtype");
1167     static PyObject* type_enum_attr = PyUnicode_InternFromString("_type_enum");
1168 #endif
1169     Safe_PyObjectPtr dtype_field(PyObject_GetAttr(tensor, dtype_attr));
1170     if (!dtype_field) {
1171       return DT_INVALID;
1172     }
1173 
1174     Safe_PyObjectPtr enum_field(
1175         PyObject_GetAttr(dtype_field.get(), type_enum_attr));
1176     if (!enum_field) {
1177       return DT_INVALID;
1178     }
1179 
1180     return static_cast<DataType>(MakeInt(enum_field.get()));
1181   }
1182 }
1183 }  // namespace tensorflow
1184 
1185 class PyTapeTensor {
1186  public:
PyTapeTensor(int64_t id,tensorflow::DataType dtype,const tensorflow::TensorShape & shape)1187   PyTapeTensor(int64_t id, tensorflow::DataType dtype,
1188                const tensorflow::TensorShape& shape)
1189       : id_(id), dtype_(dtype), shape_(shape) {}
PyTapeTensor(int64_t id,tensorflow::DataType dtype,PyObject * shape)1190   PyTapeTensor(int64_t id, tensorflow::DataType dtype, PyObject* shape)
1191       : id_(id), dtype_(dtype), shape_(shape) {
1192     Py_INCREF(absl::get<1>(shape_));
1193   }
PyTapeTensor(const PyTapeTensor & other)1194   PyTapeTensor(const PyTapeTensor& other) {
1195     id_ = other.id_;
1196     dtype_ = other.dtype_;
1197     shape_ = other.shape_;
1198     if (shape_.index() == 1) {
1199       Py_INCREF(absl::get<1>(shape_));
1200     }
1201   }
1202 
~PyTapeTensor()1203   ~PyTapeTensor() {
1204     if (shape_.index() == 1) {
1205       Py_DECREF(absl::get<1>(shape_));
1206     }
1207   }
1208   PyObject* GetShape() const;
GetPyDType() const1209   PyObject* GetPyDType() const { return PyLong_FromLong(dtype_); }
GetID() const1210   int64_t GetID() const { return id_; }
GetDType() const1211   tensorflow::DataType GetDType() const { return dtype_; }
1212 
1213   PyObject* OnesLike() const;
1214   PyObject* ZerosLike() const;
1215 
1216  private:
1217   int64_t id_;
1218   tensorflow::DataType dtype_;
1219 
1220   // Note that if shape_.index() == 1, meaning shape_ contains a PyObject, that
1221   // PyObject is the tensor itself. This is used to support tf.shape(tensor) for
1222   // partially-defined shapes and tf.zeros_like(tensor) for variant-dtype
1223   // tensors.
1224   absl::variant<tensorflow::TensorShape, PyObject*> shape_;
1225 };
1226 
1227 static PyTapeTensor TapeTensorFromTensor(PyObject* tensor);
1228 
1229 class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
1230                                                   PyTapeTensor> {
1231  public:
PyVSpace(PyObject * py_vspace)1232   explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
1233     Py_INCREF(py_vspace_);
1234   }
1235 
Initialize()1236   tensorflow::Status Initialize() {
1237     num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
1238     if (num_elements_ == nullptr) {
1239       return tensorflow::errors::InvalidArgument("invalid vspace");
1240     }
1241     aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
1242     if (aggregate_fn_ == nullptr) {
1243       return tensorflow::errors::InvalidArgument("invalid vspace");
1244     }
1245     zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn");
1246     if (zeros_fn_ == nullptr) {
1247       return tensorflow::errors::InvalidArgument("invalid vspace");
1248     }
1249     zeros_like_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_like_fn");
1250     if (zeros_like_fn_ == nullptr) {
1251       return tensorflow::errors::InvalidArgument("invalid vspace");
1252     }
1253     ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn");
1254     if (ones_fn_ == nullptr) {
1255       return tensorflow::errors::InvalidArgument("invalid vspace");
1256     }
1257     ones_like_fn_ = PyObject_GetAttrString(py_vspace_, "ones_like_fn");
1258     if (ones_like_fn_ == nullptr) {
1259       return tensorflow::errors::InvalidArgument("invalid vspace");
1260     }
1261     graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn");
1262     if (graph_shape_fn_ == nullptr) {
1263       return tensorflow::errors::InvalidArgument("invalid vspace");
1264     }
1265     return ::tensorflow::OkStatus();
1266   }
1267 
~PyVSpace()1268   ~PyVSpace() override {
1269     Py_XDECREF(num_elements_);
1270     Py_XDECREF(aggregate_fn_);
1271     Py_XDECREF(zeros_fn_);
1272     Py_XDECREF(zeros_like_fn_);
1273     Py_XDECREF(ones_fn_);
1274     Py_XDECREF(ones_like_fn_);
1275     Py_XDECREF(graph_shape_fn_);
1276 
1277     Py_DECREF(py_vspace_);
1278   }
1279 
NumElements(PyObject * tensor) const1280   int64_t NumElements(PyObject* tensor) const final {
1281     if (EagerTensor_CheckExact(tensor)) {
1282       return PyEagerTensor_NumElements(tensor);
1283     }
1284     PyObject* arglist =
1285         Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
1286     PyObject* result = PyEval_CallObject(num_elements_, arglist);
1287     Py_DECREF(arglist);
1288     if (result == nullptr) {
1289       // The caller detects whether a python exception has been raised.
1290       return -1;
1291     }
1292     int64_t r = MakeInt(result);
1293     Py_DECREF(result);
1294     return r;
1295   }
1296 
AggregateGradients(tensorflow::gtl::ArraySlice<PyObject * > gradient_tensors) const1297   PyObject* AggregateGradients(
1298       tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
1299     PyObject* list = PyList_New(gradient_tensors.size());
1300     for (int i = 0; i < gradient_tensors.size(); ++i) {
1301       // Note: stealing a reference to the gradient tensors.
1302       CHECK(gradient_tensors[i] != nullptr);
1303       CHECK(gradient_tensors[i] != Py_None);
1304       PyList_SET_ITEM(list, i,
1305                       reinterpret_cast<PyObject*>(gradient_tensors[i]));
1306     }
1307     PyObject* arglist = Py_BuildValue("(O)", list);
1308     CHECK(arglist != nullptr);
1309     PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
1310     Py_DECREF(arglist);
1311     Py_DECREF(list);
1312     return result;
1313   }
1314 
TensorId(PyObject * tensor) const1315   int64_t TensorId(PyObject* tensor) const final {
1316     return FastTensorId(tensor);
1317   }
1318 
MarkAsResult(PyObject * gradient) const1319   void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
1320 
Ones(PyObject * shape,PyObject * dtype) const1321   PyObject* Ones(PyObject* shape, PyObject* dtype) const {
1322     if (PyErr_Occurred()) {
1323       return nullptr;
1324     }
1325     PyObject* arg_list = Py_BuildValue("OO", shape, dtype);
1326     PyObject* result = PyEval_CallObject(ones_fn_, arg_list);
1327     Py_DECREF(arg_list);
1328     return result;
1329   }
1330 
OnesLike(PyObject * tensor) const1331   PyObject* OnesLike(PyObject* tensor) const {
1332     if (PyErr_Occurred()) {
1333       return nullptr;
1334     }
1335     return PyObject_CallFunctionObjArgs(ones_like_fn_, tensor, NULL);
1336   }
1337 
1338   // Builds a tensor filled with ones with the same shape and dtype as `t`.
BuildOnesLike(const PyTapeTensor & t,PyObject ** result) const1339   Status BuildOnesLike(const PyTapeTensor& t,
1340                        PyObject** result) const override {
1341     *result = t.OnesLike();
1342     return ::tensorflow::OkStatus();
1343   }
1344 
Zeros(PyObject * shape,PyObject * dtype) const1345   PyObject* Zeros(PyObject* shape, PyObject* dtype) const {
1346     if (PyErr_Occurred()) {
1347       return nullptr;
1348     }
1349     PyObject* arg_list = Py_BuildValue("OO", shape, dtype);
1350     PyObject* result = PyEval_CallObject(zeros_fn_, arg_list);
1351     Py_DECREF(arg_list);
1352     return result;
1353   }
1354 
ZerosLike(PyObject * tensor) const1355   PyObject* ZerosLike(PyObject* tensor) const {
1356     if (PyErr_Occurred()) {
1357       return nullptr;
1358     }
1359     return PyObject_CallFunctionObjArgs(zeros_like_fn_, tensor, NULL);
1360   }
1361 
GraphShape(PyObject * tensor) const1362   PyObject* GraphShape(PyObject* tensor) const {
1363     PyObject* arg_list = Py_BuildValue("(O)", tensor);
1364     PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list);
1365     Py_DECREF(arg_list);
1366     return result;
1367   }
1368 
CallBackwardFunction(const string & op_type,PyBackwardFunction * backward_function,const std::vector<int64_t> & unneeded_gradients,tensorflow::gtl::ArraySlice<PyObject * > output_gradients,absl::Span<PyObject * > result) const1369   tensorflow::Status CallBackwardFunction(
1370       const string& op_type, PyBackwardFunction* backward_function,
1371       const std::vector<int64_t>& unneeded_gradients,
1372       tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
1373       absl::Span<PyObject*> result) const final {
1374     PyObject* grads = PyTuple_New(output_gradients.size());
1375     for (int i = 0; i < output_gradients.size(); ++i) {
1376       if (output_gradients[i] == nullptr) {
1377         Py_INCREF(Py_None);
1378         PyTuple_SET_ITEM(grads, i, Py_None);
1379       } else {
1380         PyTuple_SET_ITEM(grads, i,
1381                          reinterpret_cast<PyObject*>(output_gradients[i]));
1382       }
1383     }
1384     PyObject* py_result = (*backward_function)(grads, unneeded_gradients);
1385     Py_DECREF(grads);
1386     if (py_result == nullptr) {
1387       return tensorflow::errors::Internal("gradient function threw exceptions");
1388     }
1389     PyObject* seq =
1390         PySequence_Fast(py_result, "expected a sequence of gradients");
1391     if (seq == nullptr) {
1392       return tensorflow::errors::InvalidArgument(
1393           "gradient function did not return a list");
1394     }
1395     int len = PySequence_Fast_GET_SIZE(seq);
1396     if (len != result.size()) {
1397       return tensorflow::errors::Internal(
1398           "Recorded operation '", op_type,
1399           "' returned too few gradients. Expected ", result.size(),
1400           " but received ", len);
1401     }
1402     PyObject** seq_array = PySequence_Fast_ITEMS(seq);
1403     VLOG(1) << "Gradient length is " << len;
1404     for (int i = 0; i < len; ++i) {
1405       PyObject* item = seq_array[i];
1406       if (item == Py_None) {
1407         result[i] = nullptr;
1408       } else {
1409         Py_INCREF(item);
1410         result[i] = item;
1411       }
1412     }
1413     Py_DECREF(seq);
1414     Py_DECREF(py_result);
1415     return ::tensorflow::OkStatus();
1416   }
1417 
DeleteGradient(PyObject * tensor) const1418   void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
1419 
TapeTensorFromGradient(PyObject * tensor) const1420   PyTapeTensor TapeTensorFromGradient(PyObject* tensor) const final {
1421     return TapeTensorFromTensor(tensor);
1422   }
1423 
1424  private:
1425   PyObject* py_vspace_;
1426 
1427   PyObject* num_elements_;
1428   PyObject* aggregate_fn_;
1429   PyObject* zeros_fn_;
1430   PyObject* zeros_like_fn_;
1431   PyObject* ones_fn_;
1432   PyObject* ones_like_fn_;
1433   PyObject* graph_shape_fn_;
1434 };
1435 PyVSpace* py_vspace = nullptr;
1436 
1437 bool HasAccumulator();
1438 
TFE_Py_RegisterVSpace(PyObject * e)1439 PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
1440   if (py_vspace != nullptr) {
1441     if (HasAccumulator()) {
1442       // Accumulators reference py_vspace, so we can't swap it out while one is
1443       // active. This is unlikely to ever happen.
1444       MaybeRaiseExceptionFromStatus(
1445           tensorflow::errors::Internal(
1446               "Can't change the vspace implementation while a "
1447               "forward accumulator is active."),
1448           nullptr);
1449     }
1450     delete py_vspace;
1451   }
1452 
1453   py_vspace = new PyVSpace(e);
1454   auto status = py_vspace->Initialize();
1455   if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
1456     delete py_vspace;
1457     return nullptr;
1458   }
1459 
1460   Py_RETURN_NONE;
1461 }
1462 
GetShape() const1463 PyObject* PyTapeTensor::GetShape() const {
1464   if (shape_.index() == 0) {
1465     auto& shape = absl::get<0>(shape_);
1466     PyObject* py_shape = PyTuple_New(shape.dims());
1467     for (int i = 0; i < shape.dims(); ++i) {
1468       PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
1469     }
1470 
1471     return py_shape;
1472   }
1473 
1474   return py_vspace->GraphShape(absl::get<1>(shape_));
1475 }
1476 
OnesLike() const1477 PyObject* PyTapeTensor::OnesLike() const {
1478   if (shape_.index() == 1) {
1479     PyObject* tensor = absl::get<1>(shape_);
1480     return py_vspace->OnesLike(tensor);
1481   }
1482   PyObject* py_shape = GetShape();
1483   PyObject* dtype_field = GetPyDType();
1484   PyObject* result = py_vspace->Ones(py_shape, dtype_field);
1485   Py_DECREF(dtype_field);
1486   Py_DECREF(py_shape);
1487   return result;
1488 }
1489 
ZerosLike() const1490 PyObject* PyTapeTensor::ZerosLike() const {
1491   if (GetDType() == tensorflow::DT_RESOURCE) {
1492     // Gradient functions for ops which return resource tensors accept
1493     // None. This is the behavior of py_vspace->Zeros, but checking here avoids
1494     // issues with ZerosLike.
1495     Py_RETURN_NONE;
1496   }
1497   if (shape_.index() == 1) {
1498     PyObject* tensor = absl::get<1>(shape_);
1499     return py_vspace->ZerosLike(tensor);
1500   }
1501   PyObject* py_shape = GetShape();
1502   PyObject* dtype_field = GetPyDType();
1503   PyObject* result = py_vspace->Zeros(py_shape, dtype_field);
1504   Py_DECREF(dtype_field);
1505   Py_DECREF(py_shape);
1506   return result;
1507 }
1508 
1509 // Keeps track of all variables that have been accessed during execution.
1510 class VariableWatcher {
1511  public:
VariableWatcher()1512   VariableWatcher() {}
1513 
~VariableWatcher()1514   ~VariableWatcher() {
1515     for (const IdAndVariable& v : watched_variables_) {
1516       Py_DECREF(v.variable);
1517     }
1518   }
1519 
WatchVariable(PyObject * v)1520   int64_t WatchVariable(PyObject* v) {
1521     tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
1522     if (handle == nullptr) {
1523       return -1;
1524     }
1525     int64_t id = FastTensorId(handle.get());
1526 
1527     tensorflow::mutex_lock l(watched_variables_mu_);
1528     auto insert_result = watched_variables_.emplace(id, v);
1529 
1530     if (insert_result.second) {
1531       // Only increment the reference count if we aren't already watching this
1532       // variable.
1533       Py_INCREF(v);
1534     }
1535 
1536     return id;
1537   }
1538 
GetVariablesAsPyTuple()1539   PyObject* GetVariablesAsPyTuple() {
1540     tensorflow::mutex_lock l(watched_variables_mu_);
1541     PyObject* result = PyTuple_New(watched_variables_.size());
1542     Py_ssize_t pos = 0;
1543     for (const IdAndVariable& id_and_variable : watched_variables_) {
1544       PyTuple_SET_ITEM(result, pos++, id_and_variable.variable);
1545       Py_INCREF(id_and_variable.variable);
1546     }
1547     return result;
1548   }
1549 
1550  private:
1551   // We store an IdAndVariable in the map since the map needs to be locked
1552   // during insert, but should not call back into python during insert to avoid
1553   // deadlocking with the GIL.
1554   struct IdAndVariable {
1555     int64_t id;
1556     PyObject* variable;
1557 
IdAndVariableVariableWatcher::IdAndVariable1558     IdAndVariable(int64_t id, PyObject* variable)
1559         : id(id), variable(variable) {}
1560   };
1561   struct CompareById {
operator ()VariableWatcher::CompareById1562     bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) const {
1563       return lhs.id < rhs.id;
1564     }
1565   };
1566 
1567   tensorflow::mutex watched_variables_mu_;
1568   std::set<IdAndVariable, CompareById> watched_variables_
1569       TF_GUARDED_BY(watched_variables_mu_);
1570 };
1571 
1572 class GradientTape
1573     : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
1574                                              PyTapeTensor> {
1575  public:
GradientTape(bool persistent,bool watch_accessed_variables)1576   explicit GradientTape(bool persistent, bool watch_accessed_variables)
1577       : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
1578                                         PyTapeTensor>(persistent),
1579         watch_accessed_variables_(watch_accessed_variables) {}
1580 
~GradientTape()1581   virtual ~GradientTape() {}
1582 
VariableAccessed(PyObject * v)1583   void VariableAccessed(PyObject* v) {
1584     if (watch_accessed_variables_) {
1585       WatchVariable(v);
1586     }
1587   }
1588 
WatchVariable(PyObject * v)1589   void WatchVariable(PyObject* v) {
1590     int64_t id = variable_watcher_.WatchVariable(v);
1591 
1592     if (!PyErr_Occurred()) {
1593       this->Watch(id);
1594     }
1595   }
1596 
GetVariablesAsPyTuple()1597   PyObject* GetVariablesAsPyTuple() {
1598     return variable_watcher_.GetVariablesAsPyTuple();
1599   }
1600 
1601  private:
1602   bool watch_accessed_variables_;
1603   VariableWatcher variable_watcher_;
1604 };
1605 
1606 typedef tensorflow::eager::ForwardAccumulator<PyObject, PyBackwardFunction,
1607                                               PyTapeTensor>
1608     ForwardAccumulator;
1609 
1610 // Incremented when a GradientTape or accumulator is newly added to a set, and
1611 // used to enforce an ordering between them.
1612 std::atomic_uint_fast64_t tape_nesting_id_counter(0);
1613 
// Python object wrapping a GradientTape. `tape` is allocated in
// TFE_Py_TapeSetNew and deleted by TFE_Py_Tape_Delete (the tp_dealloc slot).
// `nesting_id` is assigned from tape_nesting_id_counter whenever the tape is
// newly added to the thread's tape set.
typedef struct {
  PyObject_HEAD
      /* Type-specific fields go here. */
      GradientTape* tape;
  // A nesting order between GradientTapes and ForwardAccumulators, used to
  // ensure that GradientTapes do not watch the products of outer
  // ForwardAccumulators.
  int64_t nesting_id;
} TFE_Py_Tape;
1623 
TFE_Py_Tape_Delete(PyObject * tape)1624 static void TFE_Py_Tape_Delete(PyObject* tape) {
1625   delete reinterpret_cast<TFE_Py_Tape*>(tape)->tape;
1626   Py_TYPE(tape)->tp_free(tape);
1627 }
1628 
// Python type object for TFE_Py_Tape. Slots are given positionally up to
// tp_doc; all later slots are zero-initialized. tp_new (PyType_GenericNew) is
// filled in, and PyType_Ready called, by TFE_Py_TapeSetNew.
static PyTypeObject TFE_Py_Tape_Type = {
    PyVarObject_HEAD_INIT(nullptr, 0) "tfe.Tape", /* tp_name */
    sizeof(TFE_Py_Tape),                          /* tp_basicsize */
    0,                                            /* tp_itemsize */
    &TFE_Py_Tape_Delete,                          /* tp_dealloc */
    // The slot after tp_dealloc is tp_print before Python 3.8 and
    // tp_vectorcall_offset from 3.8 on.
#if PY_VERSION_HEX < 0x03080000
    nullptr, /* tp_print */
#else
    0, /* tp_vectorcall_offset */
#endif
    nullptr,               /* tp_getattr */
    nullptr,               /* tp_setattr */
    nullptr,               /* tp_reserved */
    nullptr,               /* tp_repr */
    nullptr,               /* tp_as_number */
    nullptr,               /* tp_as_sequence */
    nullptr,               /* tp_as_mapping */
    nullptr,               /* tp_hash  */
    nullptr,               /* tp_call */
    nullptr,               /* tp_str */
    nullptr,               /* tp_getattro */
    nullptr,               /* tp_setattro */
    nullptr,               /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT,    /* tp_flags */
    "TFE_Py_Tape objects", /* tp_doc */
};
1655 
// Python object wrapping a ForwardAccumulator. TFE_Py_ForwardAccumulatorDelete
// (the tp_dealloc slot) deletes `accumulator` when the object is deallocated.
typedef struct {
  PyObject_HEAD
      /* Type-specific fields go here. */
      ForwardAccumulator* accumulator;
  // A nesting order between GradientTapes and ForwardAccumulators, used to
  // ensure that GradientTapes do not watch the products of outer
  // ForwardAccumulators.
  int64_t nesting_id;
} TFE_Py_ForwardAccumulator;
1665 
TFE_Py_ForwardAccumulatorDelete(PyObject * accumulator)1666 static void TFE_Py_ForwardAccumulatorDelete(PyObject* accumulator) {
1667   delete reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)->accumulator;
1668   Py_TYPE(accumulator)->tp_free(accumulator);
1669 }
1670 
// Python type object for TFE_Py_ForwardAccumulator. Slots are given
// positionally up to tp_doc; all later slots are zero-initialized.
// NOTE(review): unlike the other type objects here, tp_name has no "tfe."
// prefix — confirm whether that is intentional.
static PyTypeObject TFE_Py_ForwardAccumulator_Type = {
    PyVarObject_HEAD_INIT(nullptr, 0) "ForwardAccumulator", /* tp_name */
    sizeof(TFE_Py_ForwardAccumulator),                      /* tp_basicsize */
    0,                                                      /* tp_itemsize */
    &TFE_Py_ForwardAccumulatorDelete,                       /* tp_dealloc */
    // The slot after tp_dealloc is tp_print before Python 3.8 and
    // tp_vectorcall_offset from 3.8 on.
#if PY_VERSION_HEX < 0x03080000
    nullptr, /* tp_print */
#else
    0, /* tp_vectorcall_offset */
#endif
    nullptr,                             /* tp_getattr */
    nullptr,                             /* tp_setattr */
    nullptr,                             /* tp_reserved */
    nullptr,                             /* tp_repr */
    nullptr,                             /* tp_as_number */
    nullptr,                             /* tp_as_sequence */
    nullptr,                             /* tp_as_mapping */
    nullptr,                             /* tp_hash  */
    nullptr,                             /* tp_call */
    nullptr,                             /* tp_str */
    nullptr,                             /* tp_getattro */
    nullptr,                             /* tp_setattro */
    nullptr,                             /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT,                  /* tp_flags */
    "TFE_Py_ForwardAccumulator objects", /* tp_doc */
};
1697 
// Python object wrapping a VariableWatcher. `variable_watcher` is allocated in
// TFE_Py_VariableWatcherNew and deleted by TFE_Py_VariableWatcher_Delete (the
// tp_dealloc slot).
typedef struct {
  PyObject_HEAD
      /* Type-specific fields go here. */
      VariableWatcher* variable_watcher;
} TFE_Py_VariableWatcher;
1703 
TFE_Py_VariableWatcher_Delete(PyObject * variable_watcher)1704 static void TFE_Py_VariableWatcher_Delete(PyObject* variable_watcher) {
1705   delete reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
1706       ->variable_watcher;
1707   Py_TYPE(variable_watcher)->tp_free(variable_watcher);
1708 }
1709 
// Python type object for TFE_Py_VariableWatcher. Slots are given positionally
// up to tp_doc; all later slots are zero-initialized. tp_new
// (PyType_GenericNew) is filled in, and PyType_Ready called, by
// TFE_Py_VariableWatcherNew.
static PyTypeObject TFE_Py_VariableWatcher_Type = {
    PyVarObject_HEAD_INIT(nullptr, 0) "tfe.VariableWatcher", /* tp_name */
    sizeof(TFE_Py_VariableWatcher),                          /* tp_basicsize */
    0,                                                       /* tp_itemsize */
    &TFE_Py_VariableWatcher_Delete,                          /* tp_dealloc */
    // The slot after tp_dealloc is tp_print before Python 3.8 and
    // tp_vectorcall_offset from 3.8 on.
#if PY_VERSION_HEX < 0x03080000
    nullptr, /* tp_print */
#else
    0, /* tp_vectorcall_offset */
#endif
    nullptr,                          /* tp_getattr */
    nullptr,                          /* tp_setattr */
    nullptr,                          /* tp_reserved */
    nullptr,                          /* tp_repr */
    nullptr,                          /* tp_as_number */
    nullptr,                          /* tp_as_sequence */
    nullptr,                          /* tp_as_mapping */
    nullptr,                          /* tp_hash  */
    nullptr,                          /* tp_call */
    nullptr,                          /* tp_str */
    nullptr,                          /* tp_getattro */
    nullptr,                          /* tp_setattro */
    nullptr,                          /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT,               /* tp_flags */
    "TFE_Py_VariableWatcher objects", /* tp_doc */
};
1736 
1737 // Note: in the current design no mutex is needed here because of the python
1738 // GIL, which is always held when any TFE_Py_* methods are called. We should
1739 // revisit this if/when decide to not hold the GIL while manipulating the tape
1740 // stack.
GetTapeSet()1741 tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
1742   thread_local std::unique_ptr<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>
1743       tape_set;
1744   thread_local ThreadLocalDestructionMarker marker;
1745   if (!marker.alive) {
1746     // This thread is being destroyed. It is unsafe to access tape_set.
1747     return nullptr;
1748   }
1749   if (tape_set == nullptr) {
1750     tape_set.reset(new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>);
1751   }
1752   return tape_set.get();
1753 }
1754 
1755 tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>*
GetVariableWatcherSet()1756 GetVariableWatcherSet() {
1757   thread_local std::unique_ptr<
1758       tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>
1759       variable_watcher_set;
1760   thread_local ThreadLocalDestructionMarker marker;
1761   if (!marker.alive) {
1762     // This thread is being destroyed. It is unsafe to access
1763     // variable_watcher_set.
1764     return nullptr;
1765   }
1766   if (variable_watcher_set == nullptr) {
1767     variable_watcher_set.reset(
1768         new tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>);
1769   }
1770   return variable_watcher_set.get();
1771 }
1772 
1773 // A linked hash set, where iteration is in insertion order.
1774 //
1775 // Nested accumulators rely on op recording happening in insertion order, so an
1776 // unordered data structure like CompactPointerSet is not suitable. Outer
1777 // accumulators need to observe operations first so they know to watch the inner
1778 // accumulator's jvp computation.
1779 //
1780 // Not thread safe.
1781 class AccumulatorSet {
1782  public:
1783   // Returns true if `element` was newly inserted, false if it already exists.
insert(TFE_Py_ForwardAccumulator * element)1784   bool insert(TFE_Py_ForwardAccumulator* element) {
1785     if (map_.find(element) != map_.end()) {
1786       return false;
1787     }
1788     ListType::iterator it = ordered_.insert(ordered_.end(), element);
1789     map_.insert(std::make_pair(element, it));
1790     return true;
1791   }
1792 
erase(TFE_Py_ForwardAccumulator * element)1793   void erase(TFE_Py_ForwardAccumulator* element) {
1794     MapType::iterator existing = map_.find(element);
1795     if (existing == map_.end()) {
1796       return;
1797     }
1798     ListType::iterator list_position = existing->second;
1799     map_.erase(existing);
1800     ordered_.erase(list_position);
1801   }
1802 
empty() const1803   bool empty() const { return ordered_.empty(); }
1804 
size() const1805   size_t size() const { return ordered_.size(); }
1806 
1807  private:
1808   typedef std::list<TFE_Py_ForwardAccumulator*> ListType;
1809   typedef tensorflow::gtl::FlatMap<TFE_Py_ForwardAccumulator*,
1810                                    ListType::iterator>
1811       MapType;
1812 
1813  public:
1814   typedef ListType::const_iterator const_iterator;
1815   typedef ListType::const_reverse_iterator const_reverse_iterator;
1816 
begin() const1817   const_iterator begin() const { return ordered_.begin(); }
end() const1818   const_iterator end() const { return ordered_.end(); }
1819 
rbegin() const1820   const_reverse_iterator rbegin() const { return ordered_.rbegin(); }
rend() const1821   const_reverse_iterator rend() const { return ordered_.rend(); }
1822 
1823  private:
1824   MapType map_;
1825   ListType ordered_;
1826 };
1827 
GetAccumulatorSet()1828 AccumulatorSet* GetAccumulatorSet() {
1829   thread_local std::unique_ptr<AccumulatorSet> accumulator_set;
1830   thread_local ThreadLocalDestructionMarker marker;
1831   if (!marker.alive) {
1832     // This thread is being destroyed. It is unsafe to access accumulator_set.
1833     return nullptr;
1834   }
1835   if (accumulator_set == nullptr) {
1836     accumulator_set.reset(new AccumulatorSet);
1837   }
1838   return accumulator_set.get();
1839 }
1840 
HasAccumulator()1841 inline bool HasAccumulator() { return !GetAccumulatorSet()->empty(); }
1842 
HasGradientTape()1843 inline bool HasGradientTape() { return !GetTapeSet()->empty(); }
1844 
HasAccumulatorOrTape()1845 inline bool HasAccumulatorOrTape() {
1846   return HasGradientTape() || HasAccumulator();
1847 }
1848 
// A safe copy of a set, used for tapes and accumulators. The copy is not
// affected by other python threads changing the set of active tapes.
//
// Every member is Py_INCREF'd for the lifetime of the copy and Py_DECREF'd on
// destruction, so no member can be deallocated while the copy is iterated.
// Copying a SafeSetCopy is disallowed: the implicit copy would duplicate the
// pointers without taking new references, and each member would then be
// Py_DECREF'd twice.
template <typename ContainerType>
class SafeSetCopy {
 public:
  explicit SafeSetCopy(const ContainerType& to_copy) : set_copy_(to_copy) {
    for (auto* member : set_copy_) {
      Py_INCREF(member);
    }
  }

  SafeSetCopy(const SafeSetCopy&) = delete;
  SafeSetCopy& operator=(const SafeSetCopy&) = delete;

  ~SafeSetCopy() {
    for (auto* member : set_copy_) {
      Py_DECREF(member);
    }
  }

  typename ContainerType::const_iterator begin() const {
    return set_copy_.begin();
  }

  typename ContainerType::const_iterator end() const { return set_copy_.end(); }

  bool empty() const { return set_copy_.empty(); }
  size_t size() const { return set_copy_.size(); }

 protected:
  ContainerType set_copy_;
};
1878 
1879 class SafeTapeSet
1880     : public SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>> {
1881  public:
SafeTapeSet()1882   SafeTapeSet()
1883       : SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>(
1884             *GetTapeSet()) {}
1885 };
1886 
1887 class SafeAccumulatorSet : public SafeSetCopy<AccumulatorSet> {
1888  public:
SafeAccumulatorSet()1889   SafeAccumulatorSet() : SafeSetCopy<AccumulatorSet>(*GetAccumulatorSet()) {}
1890 
rbegin() const1891   typename AccumulatorSet::const_reverse_iterator rbegin() const {
1892     return set_copy_.rbegin();
1893   }
1894 
rend() const1895   typename AccumulatorSet::const_reverse_iterator rend() const {
1896     return set_copy_.rend();
1897   }
1898 };
1899 
1900 class SafeVariableWatcherSet
1901     : public SafeSetCopy<
1902           tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> {
1903  public:
SafeVariableWatcherSet()1904   SafeVariableWatcherSet()
1905       : SafeSetCopy<
1906             tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>(
1907             *GetVariableWatcherSet()) {}
1908 };
1909 
// Returns the address of the calling thread's "recording stopped" flag. The
// flag is per thread and starts out false.
bool* ThreadTapeIsStopped() {
  static thread_local bool stopped = false;
  return &stopped;
}

// Stops tape/accumulator recording on the calling thread only.
void TFE_Py_TapeSetStopOnThread() {
  bool* const flag = ThreadTapeIsStopped();
  *flag = true;
}

// Resumes tape/accumulator recording on the calling thread.
void TFE_Py_TapeSetRestartOnThread() {
  bool* const flag = ThreadTapeIsStopped();
  *flag = false;
}
1918 
TFE_Py_TapeSetIsStopped()1919 PyObject* TFE_Py_TapeSetIsStopped() {
1920   if (*ThreadTapeIsStopped()) {
1921     Py_RETURN_TRUE;
1922   }
1923   Py_RETURN_FALSE;
1924 }
1925 
TFE_Py_TapeSetNew(PyObject * persistent,PyObject * watch_accessed_variables)1926 PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
1927                             PyObject* watch_accessed_variables) {
1928   TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
1929   if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
1930   TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
1931   tape->tape = new GradientTape(persistent == Py_True,
1932                                 watch_accessed_variables == Py_True);
1933   Py_INCREF(tape);
1934   tape->nesting_id = tape_nesting_id_counter.fetch_add(1);
1935   GetTapeSet()->insert(tape);
1936   return reinterpret_cast<PyObject*>(tape);
1937 }
1938 
TFE_Py_TapeSetAdd(PyObject * tape)1939 void TFE_Py_TapeSetAdd(PyObject* tape) {
1940   Py_INCREF(tape);
1941   TFE_Py_Tape* tfe_tape = reinterpret_cast<TFE_Py_Tape*>(tape);
1942   if (!GetTapeSet()->insert(tfe_tape).second) {
1943     // Already exists in the tape set.
1944     Py_DECREF(tape);
1945   } else {
1946     tfe_tape->nesting_id = tape_nesting_id_counter.fetch_add(1);
1947   }
1948 }
1949 
TFE_Py_TapeSetIsEmpty()1950 PyObject* TFE_Py_TapeSetIsEmpty() {
1951   if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
1952     Py_RETURN_TRUE;
1953   }
1954   Py_RETURN_FALSE;
1955 }
1956 
TFE_Py_TapeSetRemove(PyObject * tape)1957 void TFE_Py_TapeSetRemove(PyObject* tape) {
1958   auto* stack = GetTapeSet();
1959   if (stack != nullptr) {
1960     stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
1961   }
1962   // We kept a reference to the tape in the set to ensure it wouldn't get
1963   // deleted under us; cleaning it up here.
1964   Py_DECREF(tape);
1965 }
1966 
MakeIntList(PyObject * list)1967 static std::vector<int64_t> MakeIntList(PyObject* list) {
1968   if (list == Py_None) {
1969     return {};
1970   }
1971   PyObject* seq = PySequence_Fast(list, "expected a sequence");
1972   if (seq == nullptr) {
1973     return {};
1974   }
1975   int len = PySequence_Size(list);
1976   PyObject** seq_array = PySequence_Fast_ITEMS(seq);
1977   std::vector<int64_t> tensor_ids;
1978   tensor_ids.reserve(len);
1979   for (int i = 0; i < len; ++i) {
1980     PyObject* item = seq_array[i];
1981 #if PY_MAJOR_VERSION >= 3
1982     if (PyLong_Check(item)) {
1983 #else
1984     if (PyLong_Check(item) || PyInt_Check(item)) {
1985 #endif
1986       int64_t id = MakeInt(item);
1987       tensor_ids.push_back(id);
1988     } else {
1989       tensor_ids.push_back(-1);
1990     }
1991   }
1992   Py_DECREF(seq);
1993   return tensor_ids;
1994 }
1995 
1996 // Fill `tensor_ids` and `dtypes` from `tensors`, none of which may be
1997 // null. Returns true on success and false on a Python exception.
1998 bool TensorShapesAndDtypes(PyObject* tensors, std::vector<int64_t>* tensor_ids,
1999                            std::vector<tensorflow::DataType>* dtypes) {
2000   tensorflow::Safe_PyObjectPtr seq(
2001       PySequence_Fast(tensors, "expected a sequence"));
2002   if (seq == nullptr) {
2003     return false;
2004   }
2005   int len = PySequence_Fast_GET_SIZE(seq.get());
2006   PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
2007   tensor_ids->reserve(len);
2008   dtypes->reserve(len);
2009   for (int i = 0; i < len; ++i) {
2010     PyObject* item = seq_array[i];
2011     tensor_ids->push_back(FastTensorId(item));
2012     dtypes->push_back(tensorflow::PyTensor_DataType(item));
2013   }
2014   return true;
2015 }
2016 
2017 bool TapeCouldPossiblyRecord(PyObject* tensors) {
2018   if (tensors == Py_None) {
2019     return false;
2020   }
2021   if (*ThreadTapeIsStopped()) {
2022     return false;
2023   }
2024   if (!HasAccumulatorOrTape()) {
2025     return false;
2026   }
2027   return true;
2028 }
2029 
2030 bool CouldBackprop() { return !*ThreadTapeIsStopped() && HasGradientTape(); }
2031 
2032 bool CouldForwardprop() { return !*ThreadTapeIsStopped() && HasAccumulator(); }
2033 
2034 PyObject* TFE_Py_TapeSetShouldRecordBackprop(PyObject* tensors) {
2035   if (!TapeCouldPossiblyRecord(tensors) || !CouldBackprop()) {
2036     Py_RETURN_FALSE;
2037   }
2038   // TODO(apassos) consider not building a list and changing the API to check
2039   // each tensor individually.
2040   std::vector<int64_t> tensor_ids;
2041   std::vector<tensorflow::DataType> dtypes;
2042   if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
2043     return nullptr;
2044   }
2045   auto& tape_set = *GetTapeSet();
2046   for (TFE_Py_Tape* tape : tape_set) {
2047     if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
2048       Py_RETURN_TRUE;
2049     }
2050   }
2051 
2052   Py_RETURN_FALSE;
2053 }
2054 
2055 PyObject* TFE_Py_ForwardAccumulatorPushState() {
2056   auto& forward_accumulators = *GetAccumulatorSet();
2057   for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
2058     accumulator->accumulator->PushState();
2059   }
2060   Py_RETURN_NONE;
2061 }
2062 
2063 PyObject* TFE_Py_ForwardAccumulatorPopState() {
2064   auto& forward_accumulators = *GetAccumulatorSet();
2065   for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
2066     accumulator->accumulator->PopState();
2067   }
2068   Py_RETURN_NONE;
2069 }
2070 
2071 PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) {
2072   if (!TapeCouldPossiblyRecord(tensors)) {
2073     return GetPythonObjectFromInt(0);
2074   }
2075   std::vector<int64_t> tensor_ids;
2076   std::vector<tensorflow::DataType> dtypes;
2077   if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
2078     return nullptr;
2079   }
2080 
2081   // If there is a persistent tape watching, or if there are multiple tapes
2082   // watching, we'll return immediately indicating that higher-order tape
2083   // gradients are possible.
2084   bool some_tape_watching = false;
2085   if (CouldBackprop()) {
2086     auto& tape_set = *GetTapeSet();
2087     for (TFE_Py_Tape* tape : tape_set) {
2088       if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
2089         if (tape->tape->IsPersistent() || some_tape_watching) {
2090           // Either this is the second tape watching, or this tape is
2091           // persistent: higher-order gradients are possible.
2092           return GetPythonObjectFromInt(2);
2093         }
2094         some_tape_watching = true;
2095       }
2096     }
2097   }
2098   if (CouldForwardprop()) {
2099     auto& forward_accumulators = *GetAccumulatorSet();
2100     for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
2101       if (accumulator->accumulator->ShouldRecord(tensor_ids, dtypes)) {
2102         if (some_tape_watching) {
2103           // This is the second tape watching: higher-order gradients are
2104           // possible. Note that there's no equivalent of persistence for
2105           // forward-mode.
2106           return GetPythonObjectFromInt(2);
2107         }
2108         some_tape_watching = true;
2109       }
2110     }
2111   }
2112   if (some_tape_watching) {
2113     // There's exactly one non-persistent tape. The user can request first-order
2114     // gradients but won't be able to get higher-order tape gradients.
2115     return GetPythonObjectFromInt(1);
2116   } else {
2117     // There are no tapes. The user can't request tape gradients.
2118     return GetPythonObjectFromInt(0);
2119   }
2120 }
2121 
2122 void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
2123   if (!CouldBackprop()) {
2124     return;
2125   }
2126   int64_t tensor_id = FastTensorId(tensor);
2127   if (PyErr_Occurred()) {
2128     return;
2129   }
2130   reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
2131 }
2132 
2133 bool ListContainsNone(PyObject* list) {
2134   if (list == Py_None) return true;
2135   tensorflow::Safe_PyObjectPtr seq(
2136       PySequence_Fast(list, "expected a sequence"));
2137   if (seq == nullptr) {
2138     return false;
2139   }
2140 
2141   int len = PySequence_Size(list);
2142   PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
2143   for (int i = 0; i < len; ++i) {
2144     PyObject* item = seq_array[i];
2145     if (item == Py_None) return true;
2146   }
2147 
2148   return false;
2149 }
2150 
2151 // As an optimization, the tape generally keeps only the shape and dtype of
2152 // tensors, and uses this information to generate ones/zeros tensors. However,
2153 // some tensors require OnesLike/ZerosLike because their gradients do not match
2154 // their inference shape/dtype.
2155 bool DTypeNeedsHandleData(tensorflow::DataType dtype) {
2156   return dtype == tensorflow::DT_VARIANT || dtype == tensorflow::DT_RESOURCE;
2157 }
2158 
2159 static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
2160   if (EagerTensor_CheckExact(tensor)) {
2161     tensorflow::ImmediateExecutionTensorHandle* handle =
2162         tensorflow::unwrap(EagerTensor_Handle(tensor));
2163     int64_t id = PyEagerTensor_ID(tensor);
2164     tensorflow::DataType dtype =
2165         static_cast<tensorflow::DataType>(handle->DataType());
2166     if (DTypeNeedsHandleData(dtype)) {
2167       return PyTapeTensor(id, dtype, tensor);
2168     }
2169 
2170     tensorflow::TensorShape tensor_shape;
2171     int num_dims;
2172     tensorflow::Status status = handle->NumDims(&num_dims);
2173     if (status.ok()) {
2174       for (int i = 0; i < num_dims; ++i) {
2175         int64_t dim_size;
2176         status = handle->Dim(i, &dim_size);
2177         if (!status.ok()) break;
2178         tensor_shape.AddDim(dim_size);
2179       }
2180     }
2181 
2182     if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
2183       return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2184                           tensorflow::TensorShape({}));
2185     } else {
2186       return PyTapeTensor(id, dtype, tensor_shape);
2187     }
2188   }
2189   int64_t id = FastTensorId(tensor);
2190   if (PyErr_Occurred()) {
2191     return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2192                         tensorflow::TensorShape({}));
2193   }
2194   PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
2195   PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
2196   Py_DECREF(dtype_object);
2197   tensorflow::DataType dtype =
2198       static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
2199   Py_DECREF(dtype_enum);
2200   if (PyErr_Occurred()) {
2201     return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2202                         tensorflow::TensorShape({}));
2203   }
2204   static char _shape_tuple[] = "_shape_tuple";
2205   tensorflow::Safe_PyObjectPtr shape_tuple(
2206       PyObject_CallMethod(tensor, _shape_tuple, nullptr));
2207   if (PyErr_Occurred()) {
2208     return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2209                         tensorflow::TensorShape({}));
2210   }
2211 
2212   if (ListContainsNone(shape_tuple.get()) || DTypeNeedsHandleData(dtype)) {
2213     return PyTapeTensor(id, dtype, tensor);
2214   }
2215 
2216   auto l = MakeIntList(shape_tuple.get());
2217   // Replace -1, which represents accidental Nones which can occur in graph mode
2218   // and can cause errors in shape construction with 0s.
2219   for (auto& c : l) {
2220     if (c < 0) {
2221       c = 0;
2222     }
2223   }
2224   tensorflow::TensorShape shape(l);
2225   return PyTapeTensor(id, dtype, shape);
2226 }
2227 
2228 // Populates output_info from output_seq, which must come from PySequence_Fast.
2229 //
2230 // Does not take ownership of output_seq. Returns true on success and false if a
2231 // Python exception has been set.
2232 bool TapeTensorsFromTensorSequence(PyObject* output_seq,
2233                                    std::vector<PyTapeTensor>* output_info) {
2234   Py_ssize_t output_len = PySequence_Fast_GET_SIZE(output_seq);
2235   PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq);
2236   output_info->reserve(output_len);
2237   for (Py_ssize_t i = 0; i < output_len; ++i) {
2238     output_info->push_back(TapeTensorFromTensor(output_seq_array[i]));
2239     if (PyErr_Occurred() != nullptr) {
2240       return false;
2241     }
2242   }
2243   return true;
2244 }
2245 
2246 std::vector<int64_t> MakeTensorIDList(PyObject* tensors) {
2247   PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
2248   if (seq == nullptr) {
2249     return {};
2250   }
2251   int len = PySequence_Fast_GET_SIZE(seq);
2252   PyObject** seq_array = PySequence_Fast_ITEMS(seq);
2253   std::vector<int64_t> list;
2254   list.reserve(len);
2255   for (int i = 0; i < len; ++i) {
2256     PyObject* tensor = seq_array[i];
2257     list.push_back(FastTensorId(tensor));
2258     if (PyErr_Occurred()) {
2259       Py_DECREF(seq);
2260       return list;
2261     }
2262   }
2263   Py_DECREF(seq);
2264   return list;
2265 }
2266 
2267 void TFE_Py_TapeVariableAccessed(PyObject* variable) {
2268   if (!CouldBackprop()) {
2269     return;
2270   }
2271   for (TFE_Py_Tape* tape : SafeTapeSet()) {
2272     tape->tape->VariableAccessed(variable);
2273   }
2274 }
2275 
2276 void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) {
2277   if (!CouldBackprop()) {
2278     return;
2279   }
2280   reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable);
2281 }
2282 
2283 PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
2284   return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
2285 }
2286 
2287 PyObject* TFE_Py_VariableWatcherNew() {
2288   TFE_Py_VariableWatcher_Type.tp_new = PyType_GenericNew;
2289   if (PyType_Ready(&TFE_Py_VariableWatcher_Type) < 0) return nullptr;
2290   TFE_Py_VariableWatcher* variable_watcher =
2291       PyObject_NEW(TFE_Py_VariableWatcher, &TFE_Py_VariableWatcher_Type);
2292   variable_watcher->variable_watcher = new VariableWatcher();
2293   Py_INCREF(variable_watcher);
2294   GetVariableWatcherSet()->insert(variable_watcher);
2295   return reinterpret_cast<PyObject*>(variable_watcher);
2296 }
2297 
2298 void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher) {
2299   auto* stack = GetVariableWatcherSet();
2300   stack->erase(reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher));
2301   // We kept a reference to the variable watcher in the set to ensure it
2302   // wouldn't get deleted under us; cleaning it up here.
2303   Py_DECREF(variable_watcher);
2304 }
2305 
2306 void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable) {
2307   for (TFE_Py_VariableWatcher* variable_watcher : SafeVariableWatcherSet()) {
2308     variable_watcher->variable_watcher->WatchVariable(variable);
2309   }
2310 }
2311 
2312 PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher) {
2313   return reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
2314       ->variable_watcher->GetVariablesAsPyTuple();
2315 }
2316 
2317 namespace {
2318 std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
2319   PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
2320   if (seq == nullptr) {
2321     return {};
2322   }
2323   int len = PySequence_Fast_GET_SIZE(seq);
2324   PyObject** seq_array = PySequence_Fast_ITEMS(seq);
2325   std::vector<tensorflow::DataType> list;
2326   list.reserve(len);
2327   for (int i = 0; i < len; ++i) {
2328     PyObject* tensor = seq_array[i];
2329     list.push_back(tensorflow::PyTensor_DataType(tensor));
2330   }
2331   Py_DECREF(seq);
2332   return list;
2333 }
2334 
2335 PyObject* ForwardAccumulatorDeleteGradient(PyObject* tensor_id,
2336                                            PyObject* weak_tensor_ref) {
2337   auto* accumulator_set = GetAccumulatorSet();
2338   if (accumulator_set != nullptr) {
2339     int64_t parsed_tensor_id = MakeInt(tensor_id);
2340     for (TFE_Py_ForwardAccumulator* accumulator : *accumulator_set) {
2341       accumulator->accumulator->DeleteGradient(parsed_tensor_id);
2342     }
2343   }
2344   Py_DECREF(weak_tensor_ref);
2345   Py_DECREF(tensor_id);
2346   Py_INCREF(Py_None);
2347   return Py_None;
2348 }
2349 
// Method-table entry used with PyCFunction_New to turn
// ForwardAccumulatorDeleteGradient into a Python callable. METH_O means the
// C function receives `self` plus exactly one positional argument.
static PyMethodDef forward_accumulator_delete_gradient_method_def = {
    "ForwardAccumulatorDeleteGradient", ForwardAccumulatorDeleteGradient,
    METH_O, "ForwardAccumulatorDeleteGradient"};
2353 
2354 void RegisterForwardAccumulatorCleanup(PyObject* tensor, int64_t tensor_id) {
2355   tensorflow::Safe_PyObjectPtr callback(
2356       PyCFunction_New(&forward_accumulator_delete_gradient_method_def,
2357                       PyLong_FromLong(tensor_id)));
2358   // We need to keep a reference to the weakref active if we want our callback
2359   // called. The callback itself now owns the weakref object and the tensor ID
2360   // object.
2361   PyWeakref_NewRef(tensor, callback.get());
2362 }
2363 
2364 void TapeSetRecordBackprop(
2365     const string& op_type, const std::vector<PyTapeTensor>& output_info,
2366     const std::vector<int64_t>& input_ids,
2367     const std::vector<tensorflow::DataType>& input_dtypes,
2368     const std::function<PyBackwardFunction*()>& backward_function_getter,
2369     const std::function<void(PyBackwardFunction*)>& backward_function_killer,
2370     tensorflow::uint64 max_gradient_tape_id) {
2371   if (!CouldBackprop()) {
2372     return;
2373   }
2374   for (TFE_Py_Tape* tape : SafeTapeSet()) {
2375     if (tape->nesting_id < max_gradient_tape_id) {
2376       tape->tape->RecordOperation(op_type, output_info, input_ids, input_dtypes,
2377                                   backward_function_getter,
2378                                   backward_function_killer);
2379     }
2380   }
2381 }
2382 
2383 bool TapeSetRecordForwardprop(
2384     const string& op_type, PyObject* output_seq,
2385     const std::vector<PyTapeTensor>& output_info, PyObject* input_tensors,
2386     const std::vector<int64_t>& input_ids,
2387     const std::vector<tensorflow::DataType>& input_dtypes,
2388     const std::function<PyBackwardFunction*()>& backward_function_getter,
2389     const std::function<void(PyBackwardFunction*)>& backward_function_killer,
2390     const tensorflow::eager::ForwardFunction<PyObject>* forward_function,
2391     PyObject* forwardprop_output_indices,
2392     tensorflow::uint64* max_gradient_tape_id) {
2393   *max_gradient_tape_id = std::numeric_limits<tensorflow::uint64>::max();
2394   if (!CouldForwardprop()) {
2395     return true;
2396   }
2397   auto accumulator_set = SafeAccumulatorSet();
2398   tensorflow::Safe_PyObjectPtr input_seq(
2399       PySequence_Fast(input_tensors, "expected a sequence of tensors"));
2400   if (input_seq == nullptr || PyErr_Occurred()) return false;
2401   Py_ssize_t input_len = PySequence_Fast_GET_SIZE(input_seq.get());
2402   PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq);
2403   for (int i = 0; i < output_info.size(); ++i) {
2404     RegisterForwardAccumulatorCleanup(output_seq_array[i],
2405                                       output_info[i].GetID());
2406   }
2407   if (forwardprop_output_indices != nullptr &&
2408       forwardprop_output_indices != Py_None) {
2409     tensorflow::Safe_PyObjectPtr indices_fast(PySequence_Fast(
2410         forwardprop_output_indices, "Expected a sequence of indices"));
2411     if (indices_fast == nullptr || PyErr_Occurred()) {
2412       return false;
2413     }
2414     if (PySequence_Fast_GET_SIZE(indices_fast.get()) !=
2415         accumulator_set.size()) {
2416       MaybeRaiseExceptionFromStatus(
2417           tensorflow::errors::Internal(
2418               "Accumulators were added or removed from the active set "
2419               "between packing and unpacking."),
2420           nullptr);
2421     }
2422     PyObject** indices_fast_array = PySequence_Fast_ITEMS(indices_fast.get());
2423     Py_ssize_t accumulator_index = 0;
2424     for (AccumulatorSet::const_reverse_iterator it = accumulator_set.rbegin();
2425          it != accumulator_set.rend(); ++it, ++accumulator_index) {
2426       tensorflow::Safe_PyObjectPtr jvp_index_seq(
2427           PySequence_Fast(indices_fast_array[accumulator_index],
2428                           "Expected a sequence of jvp indices."));
2429       if (jvp_index_seq == nullptr || PyErr_Occurred()) {
2430         return false;
2431       }
2432       Py_ssize_t num_jvps = PySequence_Fast_GET_SIZE(jvp_index_seq.get());
2433       PyObject** jvp_index_seq_array =
2434           PySequence_Fast_ITEMS(jvp_index_seq.get());
2435       for (Py_ssize_t jvp_index = 0; jvp_index < num_jvps; ++jvp_index) {
2436         PyObject* tuple = jvp_index_seq_array[jvp_index];
2437         int64_t primal_tensor_id =
2438             output_info[MakeInt(PyTuple_GetItem(tuple, 0))].GetID();
2439         (*it)->accumulator->Watch(
2440             primal_tensor_id,
2441             output_seq_array[MakeInt(PyTuple_GetItem(tuple, 1))]);
2442       }
2443     }
2444   } else {
2445     std::vector<PyTapeTensor> input_info;
2446     input_info.reserve(input_len);
2447     PyObject** input_seq_array = PySequence_Fast_ITEMS(input_seq.get());
2448     for (Py_ssize_t i = 0; i < input_len; ++i) {
2449       input_info.push_back(TapeTensorFromTensor(input_seq_array[i]));
2450     }
2451     for (TFE_Py_ForwardAccumulator* accumulator : accumulator_set) {
2452       tensorflow::Status status = accumulator->accumulator->Accumulate(
2453           op_type, input_info, output_info, input_ids, input_dtypes,
2454           forward_function, backward_function_getter, backward_function_killer);
2455       if (PyErr_Occurred()) return false;  // Don't swallow Python exceptions.
2456       if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
2457         return false;
2458       }
2459       if (accumulator->accumulator->BusyAccumulating()) {
2460         // Ensure inner accumulators don't see outer accumulators' jvps. This
2461         // mostly happens on its own, with some potentially surprising
2462         // exceptions, so the blanket policy is for consistency.
2463         *max_gradient_tape_id = accumulator->nesting_id;
2464         break;
2465       }
2466     }
2467   }
2468   return true;
2469 }
2470 
2471 PyObject* TangentsAsPyTuple(const std::vector<PyObject*>& input_tangents) {
2472   PyObject* py_input_tangents = PyTuple_New(input_tangents.size());
2473   for (int i = 0; i < input_tangents.size(); ++i) {
2474     PyObject* element;
2475     if (input_tangents[i] == nullptr) {
2476       element = Py_None;
2477     } else {
2478       element = input_tangents[i];
2479     }
2480     Py_INCREF(element);
2481     PyTuple_SET_ITEM(py_input_tangents, i, element);
2482   }
2483   return py_input_tangents;
2484 }
2485 
2486 tensorflow::Status ParseTangentOutputs(
2487     PyObject* user_output, std::vector<PyObject*>* output_tangents) {
2488   if (user_output == Py_None) {
2489     // No connected gradients.
2490     return ::tensorflow::OkStatus();
2491   }
2492   tensorflow::Safe_PyObjectPtr fast_result(
2493       PySequence_Fast(user_output, "expected a sequence of forward gradients"));
2494   if (fast_result == nullptr) {
2495     return tensorflow::errors::InvalidArgument(
2496         "forward gradient function did not return a sequence.");
2497   }
2498   int len = PySequence_Fast_GET_SIZE(fast_result.get());
2499   PyObject** fast_result_array = PySequence_Fast_ITEMS(fast_result.get());
2500   output_tangents->reserve(len);
2501   for (int i = 0; i < len; ++i) {
2502     PyObject* item = fast_result_array[i];
2503     if (item == Py_None) {
2504       output_tangents->push_back(nullptr);
2505     } else {
2506       Py_INCREF(item);
2507       output_tangents->push_back(item);
2508     }
2509   }
2510   return ::tensorflow::OkStatus();
2511 }
2512 
2513 // Calls the registered forward_gradient_function, computing `output_tangents`
2514 // from `input_tangents`. `output_tangents` must not be null.
2515 //
2516 // `op_name`, `attrs`, `inputs`, and `results` describe the operation for which
2517 // the forward function is being called.
2518 tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs,
2519                                    PyObject* inputs, PyObject* results,
2520                                    const std::vector<PyObject*>& input_tangents,
2521                                    std::vector<PyObject*>* output_tangents,
2522                                    bool use_batch) {
2523   if (forward_gradient_function == nullptr) {
2524     return tensorflow::errors::Internal(
2525         "No forward gradient function registered.");
2526   }
2527   tensorflow::Safe_PyObjectPtr py_input_tangents(
2528       TangentsAsPyTuple(input_tangents));
2529 
2530   // Normalize the input sequence to a tuple so it works with function
2531   // caching; otherwise it may be an opaque _InputList object.
2532   tensorflow::Safe_PyObjectPtr input_tuple(PySequence_Tuple(inputs));
2533   PyObject* to_batch = (use_batch) ? Py_True : Py_False;
2534   tensorflow::Safe_PyObjectPtr callback_args(
2535       Py_BuildValue("OOOOOO", op_name, attrs, input_tuple.get(), results,
2536                     py_input_tangents.get(), to_batch));
2537   tensorflow::Safe_PyObjectPtr py_result(
2538       PyObject_CallObject(forward_gradient_function, callback_args.get()));
2539   if (py_result == nullptr || PyErr_Occurred()) {
2540     return tensorflow::errors::Internal(
2541         "forward gradient function threw exceptions");
2542   }
2543   return ParseTangentOutputs(py_result.get(), output_tangents);
2544 }
2545 
2546 // Like CallJVPFunction, but calls a pre-bound forward function.
2547 // These are passed in from a record_gradient argument.
2548 tensorflow::Status CallOpSpecificJVPFunction(
2549     PyObject* op_specific_forward_function,
2550     const std::vector<PyObject*>& input_tangents,
2551     std::vector<PyObject*>* output_tangents) {
2552   tensorflow::Safe_PyObjectPtr py_input_tangents(
2553       TangentsAsPyTuple(input_tangents));
2554 
2555   tensorflow::Safe_PyObjectPtr py_result(PyObject_CallObject(
2556       op_specific_forward_function, py_input_tangents.get()));
2557   if (py_result == nullptr || PyErr_Occurred()) {
2558     return tensorflow::errors::Internal(
2559         "forward gradient function threw exceptions");
2560   }
2561   return ParseTangentOutputs(py_result.get(), output_tangents);
2562 }
2563 
2564 bool ParseOpTypeString(PyObject* op_type, string* op_type_string) {
2565   if (PyBytes_Check(op_type)) {
2566     *op_type_string = PyBytes_AsString(op_type);
2567   } else if (PyUnicode_Check(op_type)) {
2568 #if PY_MAJOR_VERSION >= 3
2569     *op_type_string = PyUnicode_AsUTF8(op_type);
2570 #else
2571     PyObject* py_str = PyUnicode_AsUTF8String(op_type);
2572     if (py_str == nullptr) {
2573       return false;
2574     }
2575     *op_type_string = PyBytes_AS_STRING(py_str);
2576     Py_DECREF(py_str);
2577 #endif
2578   } else {
2579     PyErr_SetString(PyExc_RuntimeError, "op_type should be a string.");
2580     return false;
2581   }
2582   return true;
2583 }
2584 
2585 bool TapeSetRecordOperation(
2586     PyObject* op_type, PyObject* input_tensors, PyObject* output_tensors,
2587     const std::vector<int64_t>& input_ids,
2588     const std::vector<tensorflow::DataType>& input_dtypes,
2589     const std::function<PyBackwardFunction*()>& backward_function_getter,
2590     const std::function<void(PyBackwardFunction*)>& backward_function_killer,
2591     const tensorflow::eager::ForwardFunction<PyObject>* forward_function) {
2592   std::vector<PyTapeTensor> output_info;
2593   tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2594       output_tensors, "expected a sequence of integer tensor ids"));
2595   if (PyErr_Occurred() ||
2596       !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2597     return false;
2598   }
2599   string op_type_str;
2600   if (!ParseOpTypeString(op_type, &op_type_str)) {
2601     return false;
2602   }
2603   tensorflow::uint64 max_gradient_tape_id;
2604   if (!TapeSetRecordForwardprop(
2605           op_type_str, output_seq.get(), output_info, input_tensors, input_ids,
2606           input_dtypes, backward_function_getter, backward_function_killer,
2607           forward_function, nullptr /* No special-cased jvps. */,
2608           &max_gradient_tape_id)) {
2609     return false;
2610   }
2611   TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes,
2612                         backward_function_getter, backward_function_killer,
2613                         max_gradient_tape_id);
2614   return true;
2615 }
2616 }  // namespace
2617 
2618 PyObject* TFE_Py_TapeSetRecordOperation(PyObject* op_type,
2619                                         PyObject* output_tensors,
2620                                         PyObject* input_tensors,
2621                                         PyObject* backward_function,
2622                                         PyObject* forward_function) {
2623   if (!HasAccumulatorOrTape() || *ThreadTapeIsStopped()) {
2624     Py_RETURN_NONE;
2625   }
2626   std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors);
2627   if (PyErr_Occurred()) return nullptr;
2628 
2629   std::vector<tensorflow::DataType> input_dtypes =
2630       MakeTensorDtypeList(input_tensors);
2631   if (PyErr_Occurred()) return nullptr;
2632 
2633   std::function<PyBackwardFunction*()> backward_function_getter(
2634       [backward_function]() {
2635         Py_INCREF(backward_function);
2636         PyBackwardFunction* function = new PyBackwardFunction(
2637             [backward_function](PyObject* out_grads,
2638                                 const std::vector<int64_t>& unused) {
2639               return PyObject_CallObject(backward_function, out_grads);
2640             });
2641         return function;
2642       });
2643   std::function<void(PyBackwardFunction*)> backward_function_killer(
2644       [backward_function](PyBackwardFunction* py_backward_function) {
2645         Py_DECREF(backward_function);
2646         delete py_backward_function;
2647       });
2648 
2649   if (forward_function == Py_None) {
2650     if (!TapeSetRecordOperation(
2651             op_type, input_tensors, output_tensors, input_ids, input_dtypes,
2652             backward_function_getter, backward_function_killer,
2653             nullptr /* No special-cased forward function */)) {
2654       return nullptr;
2655     }
2656   } else {
2657     tensorflow::eager::ForwardFunction<PyObject> wrapped_forward_function(
2658         [forward_function](const std::vector<PyObject*>& input_tangents,
2659                            std::vector<PyObject*>* output_tangents,
2660                            bool use_batch = false) {
2661           return CallOpSpecificJVPFunction(forward_function, input_tangents,
2662                                            output_tangents);
2663         });
2664     if (!TapeSetRecordOperation(
2665             op_type, input_tensors, output_tensors, input_ids, input_dtypes,
2666             backward_function_getter, backward_function_killer,
2667             &wrapped_forward_function)) {
2668       return nullptr;
2669     }
2670   }
2671   Py_RETURN_NONE;
2672 }
2673 
2674 PyObject* TFE_Py_TapeSetRecordOperationForwardprop(
2675     PyObject* op_type, PyObject* output_tensors, PyObject* input_tensors,
2676     PyObject* backward_function, PyObject* forwardprop_output_indices) {
2677   if (!HasAccumulator() || *ThreadTapeIsStopped()) {
2678     Py_RETURN_NONE;
2679   }
2680   std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors);
2681   if (PyErr_Occurred()) return nullptr;
2682 
2683   std::vector<tensorflow::DataType> input_dtypes =
2684       MakeTensorDtypeList(input_tensors);
2685   if (PyErr_Occurred()) return nullptr;
2686 
2687   std::function<PyBackwardFunction*()> backward_function_getter(
2688       [backward_function]() {
2689         Py_INCREF(backward_function);
2690         PyBackwardFunction* function = new PyBackwardFunction(
2691             [backward_function](PyObject* out_grads,
2692                                 const std::vector<int64_t>& unused) {
2693               return PyObject_CallObject(backward_function, out_grads);
2694             });
2695         return function;
2696       });
2697   std::function<void(PyBackwardFunction*)> backward_function_killer(
2698       [backward_function](PyBackwardFunction* py_backward_function) {
2699         Py_DECREF(backward_function);
2700         delete py_backward_function;
2701       });
2702   std::vector<PyTapeTensor> output_info;
2703   tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2704       output_tensors, "expected a sequence of integer tensor ids"));
2705   if (PyErr_Occurred() ||
2706       !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2707     return nullptr;
2708   }
2709   string op_type_str;
2710   if (!ParseOpTypeString(op_type, &op_type_str)) {
2711     return nullptr;
2712   }
2713   tensorflow::uint64 max_gradient_tape_id;
2714   if (!TapeSetRecordForwardprop(
2715           op_type_str, output_seq.get(), output_info, input_tensors, input_ids,
2716           input_dtypes, backward_function_getter, backward_function_killer,
2717           nullptr /* no special-cased forward function */,
2718           forwardprop_output_indices, &max_gradient_tape_id)) {
2719     return nullptr;
2720   }
2721   Py_RETURN_NONE;
2722 }
2723 
2724 PyObject* TFE_Py_TapeSetRecordOperationBackprop(PyObject* op_type,
2725                                                 PyObject* output_tensors,
2726                                                 PyObject* input_tensors,
2727                                                 PyObject* backward_function) {
2728   if (!CouldBackprop()) {
2729     Py_RETURN_NONE;
2730   }
2731   std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors);
2732   if (PyErr_Occurred()) return nullptr;
2733 
2734   std::vector<tensorflow::DataType> input_dtypes =
2735       MakeTensorDtypeList(input_tensors);
2736   if (PyErr_Occurred()) return nullptr;
2737 
2738   std::function<PyBackwardFunction*()> backward_function_getter(
2739       [backward_function]() {
2740         Py_INCREF(backward_function);
2741         PyBackwardFunction* function = new PyBackwardFunction(
2742             [backward_function](PyObject* out_grads,
2743                                 const std::vector<int64_t>& unused) {
2744               return PyObject_CallObject(backward_function, out_grads);
2745             });
2746         return function;
2747       });
2748   std::function<void(PyBackwardFunction*)> backward_function_killer(
2749       [backward_function](PyBackwardFunction* py_backward_function) {
2750         Py_DECREF(backward_function);
2751         delete py_backward_function;
2752       });
2753   std::vector<PyTapeTensor> output_info;
2754   tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2755       output_tensors, "expected a sequence of integer tensor ids"));
2756   if (PyErr_Occurred() ||
2757       !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2758     return nullptr;
2759   }
2760   string op_type_str;
2761   if (!ParseOpTypeString(op_type, &op_type_str)) {
2762     return nullptr;
2763   }
2764   TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes,
2765                         backward_function_getter, backward_function_killer,
2766                         // No filtering based on relative ordering with forward
2767                         // accumulators.
2768                         std::numeric_limits<tensorflow::uint64>::max());
2769   Py_RETURN_NONE;
2770 }
2771 
2772 void TFE_Py_TapeSetDeleteTrace(int64_t tensor_id) {
2773   auto* tape_set = GetTapeSet();
2774   if (tape_set == nullptr) {
2775     // Current thread is being destructed, and the tape set has already
2776     // been cleared.
2777     return;
2778   }
2779   for (TFE_Py_Tape* tape : *tape_set) {
2780     tape->tape->DeleteTrace(tensor_id);
2781   }
2782 }
2783 
2784 std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
2785   PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
2786   if (seq == nullptr) {
2787     return {};
2788   }
2789   int len = PySequence_Fast_GET_SIZE(seq);
2790   PyObject** seq_array = PySequence_Fast_ITEMS(seq);
2791   std::vector<PyObject*> list(seq_array, seq_array + len);
2792   Py_DECREF(seq);
2793   return list;
2794 }
2795 
2796 PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
2797                               PyObject* sources, PyObject* output_gradients,
2798                               PyObject* sources_raw,
2799                               PyObject* unconnected_gradients,
2800                               TF_Status* status) {
2801   TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
2802   if (!tape_obj->tape->IsPersistent()) {
2803     auto* tape_set = GetTapeSet();
2804     if (tape_set->find(tape_obj) != tape_set->end()) {
2805       PyErr_SetString(PyExc_RuntimeError,
2806                       "gradient() cannot be invoked within the "
2807                       "GradientTape context (i.e., while operations are being "
2808                       "recorded). Either move the call to gradient() to be "
2809                       "outside the 'with tf.GradientTape' block, or "
2810                       "use a persistent tape: "
2811                       "'with tf.GradientTape(persistent=true)'");
2812       return nullptr;
2813     }
2814   }
2815 
2816   std::vector<int64_t> target_vec = MakeTensorIDList(target);
2817   if (PyErr_Occurred()) {
2818     return nullptr;
2819   }
2820   std::vector<int64_t> sources_vec = MakeTensorIDList(sources);
2821   if (PyErr_Occurred()) {
2822     return nullptr;
2823   }
2824   tensorflow::gtl::FlatSet<int64_t> sources_set(sources_vec.begin(),
2825                                                 sources_vec.end());
2826 
2827   tensorflow::Safe_PyObjectPtr seq =
2828       tensorflow::make_safe(PySequence_Fast(target, "expected a sequence"));
2829   int len = PySequence_Fast_GET_SIZE(seq.get());
2830   PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
2831   std::unordered_map<int64_t, PyTapeTensor> source_tensors_that_are_targets;
2832   for (int i = 0; i < len; ++i) {
2833     int64_t target_id = target_vec[i];
2834     if (sources_set.find(target_id) != sources_set.end()) {
2835       auto tensor = seq_array[i];
2836       source_tensors_that_are_targets.insert(
2837           std::make_pair(target_id, TapeTensorFromTensor(tensor)));
2838     }
2839     if (PyErr_Occurred()) {
2840       return nullptr;
2841     }
2842   }
2843   if (PyErr_Occurred()) {
2844     return nullptr;
2845   }
2846 
2847   std::vector<PyObject*> outgrad_vec;
2848   if (output_gradients != Py_None) {
2849     outgrad_vec = MakeTensorList(output_gradients);
2850     if (PyErr_Occurred()) {
2851       return nullptr;
2852     }
2853     for (PyObject* tensor : outgrad_vec) {
2854       // Calling the backward function will eat a reference to the tensors in
2855       // outgrad_vec, so we need to increase their reference count.
2856       Py_INCREF(tensor);
2857     }
2858   }
2859   std::vector<PyObject*> result(sources_vec.size());
2860   status->status = tape_obj->tape->ComputeGradient(
2861       *py_vspace, target_vec, sources_vec, source_tensors_that_are_targets,
2862       outgrad_vec, absl::MakeSpan(result));
2863   if (!status->status.ok()) {
2864     if (PyErr_Occurred()) {
2865       // Do not propagate the erroneous status as that would swallow the
2866       // exception which caused the problem.
2867       status->status = ::tensorflow::OkStatus();
2868     }
2869     return nullptr;
2870   }
2871 
2872   bool unconnected_gradients_zero =
2873       strcmp(TFE_GetPythonString(unconnected_gradients), "zero") == 0;
2874   std::vector<PyObject*> sources_obj;
2875   if (unconnected_gradients_zero) {
2876     // Uses the "raw" sources here so it can properly make a zeros tensor even
2877     // if there are resource variables as sources.
2878     sources_obj = MakeTensorList(sources_raw);
2879   }
2880 
2881   if (!result.empty()) {
2882     PyObject* py_result = PyList_New(result.size());
2883     tensorflow::gtl::FlatSet<PyObject*> seen_results(result.size());
2884     for (int i = 0; i < result.size(); ++i) {
2885       if (result[i] == nullptr) {
2886         if (unconnected_gradients_zero) {
2887           // generate a zeros tensor in the shape of sources[i]
2888           tensorflow::DataType dtype =
2889               tensorflow::PyTensor_DataType(sources_obj[i]);
2890           PyTapeTensor tensor =
2891               PyTapeTensor(sources_vec[i], dtype, sources_obj[i]);
2892           result[i] = tensor.ZerosLike();
2893         } else {
2894           Py_INCREF(Py_None);
2895           result[i] = Py_None;
2896         }
2897       } else if (seen_results.find(result[i]) != seen_results.end()) {
2898         Py_INCREF(result[i]);
2899       }
2900       seen_results.insert(result[i]);
2901       PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]));
2902     }
2903     return py_result;
2904   }
2905   return PyList_New(0);
2906 }
2907 
2908 PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch) {
2909   TFE_Py_ForwardAccumulator_Type.tp_new = PyType_GenericNew;
2910   if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr;
2911   TFE_Py_ForwardAccumulator* accumulator =
2912       PyObject_NEW(TFE_Py_ForwardAccumulator, &TFE_Py_ForwardAccumulator_Type);
2913   if (py_vspace == nullptr) {
2914     MaybeRaiseExceptionFromStatus(
2915         tensorflow::errors::Internal(
2916             "ForwardAccumulator requires a PyVSpace to be registered."),
2917         nullptr);
2918   }
2919   accumulator->accumulator = new ForwardAccumulator(*py_vspace, use_batch);
2920   return reinterpret_cast<PyObject*>(accumulator);
2921 }
2922 
2923 PyObject* TFE_Py_ForwardAccumulatorSetAdd(PyObject* accumulator) {
2924   TFE_Py_ForwardAccumulator* c_accumulator(
2925       reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
2926   c_accumulator->nesting_id = tape_nesting_id_counter.fetch_add(1);
2927   if (GetAccumulatorSet()->insert(c_accumulator)) {
2928     Py_INCREF(accumulator);
2929     Py_RETURN_NONE;
2930   } else {
2931     MaybeRaiseExceptionFromStatus(
2932         tensorflow::errors::Internal(
2933             "A ForwardAccumulator was added to the active set twice."),
2934         nullptr);
2935     return nullptr;
2936   }
2937 }
2938 
2939 void TFE_Py_ForwardAccumulatorSetRemove(PyObject* accumulator) {
2940   auto* accumulator_set = GetAccumulatorSet();
2941   if (accumulator_set != nullptr) {
2942     accumulator_set->erase(
2943         reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
2944   }
2945   Py_DECREF(accumulator);
2946 }
2947 
2948 void TFE_Py_ForwardAccumulatorWatch(PyObject* accumulator, PyObject* tensor,
2949                                     PyObject* tangent) {
2950   int64_t tensor_id = FastTensorId(tensor);
2951   reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
2952       ->accumulator->Watch(tensor_id, tangent);
2953   RegisterForwardAccumulatorCleanup(tensor, tensor_id);
2954 }
2955 
2956 // Returns a new reference to the JVP Tensor.
2957 PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator,
2958                                        PyObject* tensor) {
2959   PyObject* jvp = reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
2960                       ->accumulator->FetchJVP(FastTensorId(tensor));
2961   if (jvp == nullptr) {
2962     jvp = Py_None;
2963   }
2964   Py_INCREF(jvp);
2965   return jvp;
2966 }
2967 
2968 PyObject* TFE_Py_PackJVPs(PyObject* tensors) {
2969   if (!TapeCouldPossiblyRecord(tensors)) {
2970     tensorflow::Safe_PyObjectPtr empty_tuple(PyTuple_New(0));
2971     tensorflow::Safe_PyObjectPtr empty_list(PyList_New(0));
2972     return PyTuple_Pack(2, empty_tuple.get(), empty_list.get());
2973   }
2974   auto& accumulators = *GetAccumulatorSet();
2975   tensorflow::Safe_PyObjectPtr tensors_fast(
2976       PySequence_Fast(tensors, "Expected a sequence of input Tensors."));
2977   if (tensors_fast == nullptr || PyErr_Occurred()) {
2978     return nullptr;
2979   }
2980   std::vector<int64_t> augmented_input_ids;
2981   int len = PySequence_Fast_GET_SIZE(tensors_fast.get());
2982   PyObject** tensors_fast_array = PySequence_Fast_ITEMS(tensors_fast.get());
2983   for (Py_ssize_t position = 0; position < len; ++position) {
2984     PyObject* input = tensors_fast_array[position];
2985     if (input == Py_None) {
2986       continue;
2987     }
2988     tensorflow::DataType input_dtype(tensorflow::PyTensor_DataType(input));
2989     if (input_dtype == tensorflow::DT_INVALID) {
2990       return nullptr;
2991     }
2992     augmented_input_ids.push_back(FastTensorId(input));
2993   }
2994   if (PyErr_Occurred()) {
2995     return nullptr;
2996   }
2997   // Find the innermost accumulator such that all outer accumulators are
2998   // recording. Any more deeply nested accumulators will not have their JVPs
2999   // saved.
3000   AccumulatorSet::const_iterator innermost_all_recording = accumulators.begin();
3001   for (; innermost_all_recording != accumulators.end();
3002        ++innermost_all_recording) {
3003     if ((*innermost_all_recording)->accumulator->BusyAccumulating()) {
3004       break;
3005     }
3006   }
3007   AccumulatorSet::const_reverse_iterator reverse_innermost_all_recording(
3008       innermost_all_recording);
3009 
3010   bool saving_jvps = false;
3011   tensorflow::Safe_PyObjectPtr all_indices(PyTuple_New(accumulators.size()));
3012   std::vector<PyObject*> new_tensors;
3013   Py_ssize_t accumulator_index = 0;
3014   // Start with the innermost accumulators to give outer accumulators a chance
3015   // to find their higher-order JVPs.
3016   for (AccumulatorSet::const_reverse_iterator it = accumulators.rbegin();
3017        it != accumulators.rend(); ++it, ++accumulator_index) {
3018     std::vector<int64_t> new_input_ids;
3019     std::vector<std::pair<int64_t, int64_t>> accumulator_indices;
3020     if (it == reverse_innermost_all_recording) {
3021       saving_jvps = true;
3022     }
3023     if (saving_jvps) {
3024       for (int input_index = 0; input_index < augmented_input_ids.size();
3025            ++input_index) {
3026         int64_t existing_input = augmented_input_ids[input_index];
3027         PyObject* jvp = (*it)->accumulator->FetchJVP(existing_input);
3028         if (jvp != nullptr) {
3029           new_tensors.push_back(jvp);
3030           new_input_ids.push_back(FastTensorId(jvp));
3031           accumulator_indices.emplace_back(
3032               input_index,
3033               augmented_input_ids.size() + new_input_ids.size() - 1);
3034         }
3035       }
3036     }
3037     tensorflow::Safe_PyObjectPtr accumulator_indices_py(
3038         PyTuple_New(accumulator_indices.size()));
3039     for (int i = 0; i < accumulator_indices.size(); ++i) {
3040       tensorflow::Safe_PyObjectPtr from_index(
3041           GetPythonObjectFromInt(accumulator_indices[i].first));
3042       tensorflow::Safe_PyObjectPtr to_index(
3043           GetPythonObjectFromInt(accumulator_indices[i].second));
3044       PyTuple_SetItem(accumulator_indices_py.get(), i,
3045                       PyTuple_Pack(2, from_index.get(), to_index.get()));
3046     }
3047     PyTuple_SetItem(all_indices.get(), accumulator_index,
3048                     accumulator_indices_py.release());
3049     augmented_input_ids.insert(augmented_input_ids.end(), new_input_ids.begin(),
3050                                new_input_ids.end());
3051   }
3052 
3053   tensorflow::Safe_PyObjectPtr new_tensors_py(PyList_New(new_tensors.size()));
3054   for (int i = 0; i < new_tensors.size(); ++i) {
3055     PyObject* jvp = new_tensors[i];
3056     Py_INCREF(jvp);
3057     PyList_SET_ITEM(new_tensors_py.get(), i, jvp);
3058   }
3059   return PyTuple_Pack(2, all_indices.get(), new_tensors_py.get());
3060 }
3061 
3062 namespace {
3063 
// Positions of the fixed leading entries of the "args" tuple passed to
// TFE_Py_FastPathExecute_C. The op's inputs start at
// FAST_PATH_EXECUTE_ARG_INPUT_START; TFE_Py_FastPathExecute_C reads whatever
// follows the inputs as consecutive (attr name, attr value) pairs.
enum FastPathExecuteArgIndex {
  FAST_PATH_EXECUTE_ARG_CONTEXT,     // 0: the Python eager context object.
  FAST_PATH_EXECUTE_ARG_OP_NAME,     // 1: the op type name string.
  FAST_PATH_EXECUTE_ARG_NAME,        // 2: name forwarded to post-exec callbacks.
  FAST_PATH_EXECUTE_ARG_INPUT_START  // 3: first op input.
};
3071 
3072 PyObject* GetPythonObjectFromString(tensorflow::StringPiece s) {
3073 #if PY_MAJOR_VERSION >= 3
3074   return PyUnicode_FromStringAndSize(s.data(), s.size());
3075 #else
3076   return PyBytes_FromStringAndSize(s.data(), s.size());
3077 #endif
3078 }
3079 
3080 bool CheckResourceVariable(PyObject* item) {
3081   if (tensorflow::swig::IsResourceVariable(item)) {
3082     tensorflow::Safe_PyObjectPtr handle(
3083         PyObject_GetAttrString(item, "_handle"));
3084     return EagerTensor_CheckExact(handle.get());
3085   }
3086 
3087   return false;
3088 }
3089 
3090 bool IsNumberType(PyObject* item) {
3091 #if PY_MAJOR_VERSION >= 3
3092   return PyFloat_Check(item) || PyLong_Check(item);
3093 #else
3094   return PyFloat_Check(item) || PyInt_Check(item) || PyLong_Check(item);
3095 #endif
3096 }
3097 
3098 bool CheckOneInput(PyObject* item) {
3099   if (EagerTensor_CheckExact(item) || CheckResourceVariable(item) ||
3100       PyArray_Check(item) || IsNumberType(item)) {
3101     return true;
3102   }
3103 
3104   // Sequences are not properly handled. Sequences with purely python numeric
3105   // types work, but sequences with mixes of EagerTensors and python numeric
3106   // types don't work.
3107   // TODO(nareshmodi): fix
3108   return false;
3109 }
3110 
3111 bool CheckInputsOk(PyObject* seq, int start_index,
3112                    const tensorflow::OpDef& op_def) {
3113   for (int i = 0; i < op_def.input_arg_size(); i++) {
3114     PyObject* item = PyTuple_GET_ITEM(seq, i + start_index);
3115     if (!op_def.input_arg(i).number_attr().empty() ||
3116         !op_def.input_arg(i).type_list_attr().empty()) {
3117       // This item should be a seq input.
3118       if (!PySequence_Check(item)) {
3119         VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
3120                 << "\", Input \"" << op_def.input_arg(i).name()
3121                 << "\" since we expected a sequence, but got "
3122                 << item->ob_type->tp_name;
3123         return false;
3124       }
3125       tensorflow::Safe_PyObjectPtr fast_item(
3126           PySequence_Fast(item, "Could not parse sequence."));
3127       if (fast_item.get() == nullptr) {
3128         return false;
3129       }
3130       int len = PySequence_Fast_GET_SIZE(fast_item.get());
3131       PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get());
3132       for (Py_ssize_t j = 0; j < len; j++) {
3133         PyObject* inner_item = fast_item_array[j];
3134         if (!CheckOneInput(inner_item)) {
3135           VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
3136                   << "\", Input \"" << op_def.input_arg(i).name()
3137                   << "\", Index " << j
3138                   << " since we expected an EagerTensor/ResourceVariable, "
3139                      "but got "
3140                   << inner_item->ob_type->tp_name;
3141           return false;
3142         }
3143       }
3144     } else if (!CheckOneInput(item)) {
3145       VLOG(1)
3146           << "Falling back to slow path for Op \"" << op_def.name()
3147           << "\", Input \"" << op_def.input_arg(i).name()
3148           << "\" since we expected an EagerTensor/ResourceVariable, but got "
3149           << item->ob_type->tp_name;
3150       return false;
3151     }
3152   }
3153 
3154   return true;
3155 }
3156 
3157 tensorflow::DataType MaybeGetDType(PyObject* item) {
3158   if (EagerTensor_CheckExact(item) || CheckResourceVariable(item)) {
3159     return tensorflow::PyTensor_DataType(item);
3160   }
3161 
3162   return tensorflow::DT_INVALID;
3163 }
3164 
3165 tensorflow::DataType MaybeGetDTypeForAttr(const string& attr,
3166                                           FastPathOpExecInfo* op_exec_info) {
3167   auto cached_it = op_exec_info->cached_dtypes.find(attr);
3168   if (cached_it != op_exec_info->cached_dtypes.end()) {
3169     return cached_it->second;
3170   }
3171 
3172   auto it = op_exec_info->attr_to_inputs_map->find(attr);
3173   if (it == op_exec_info->attr_to_inputs_map->end()) {
3174     // No other inputs - this should never happen.
3175     return tensorflow::DT_INVALID;
3176   }
3177 
3178   for (const auto& input_info : it->second) {
3179     PyObject* item = PyTuple_GET_ITEM(
3180         op_exec_info->args, FAST_PATH_EXECUTE_ARG_INPUT_START + input_info.i);
3181     if (input_info.is_list) {
3182       tensorflow::Safe_PyObjectPtr fast_item(
3183           PySequence_Fast(item, "Unable to allocate"));
3184       int len = PySequence_Fast_GET_SIZE(fast_item.get());
3185       PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get());
3186       for (int i = 0; i < len; i++) {
3187         auto dtype = MaybeGetDType(fast_item_array[i]);
3188         if (dtype != tensorflow::DT_INVALID) return dtype;
3189       }
3190     } else {
3191       auto dtype = MaybeGetDType(item);
3192       if (dtype != tensorflow::DT_INVALID) return dtype;
3193     }
3194   }
3195 
3196   auto default_it = op_exec_info->default_dtypes->find(attr);
3197   if (default_it != op_exec_info->default_dtypes->end()) {
3198     return default_it->second;
3199   }
3200 
3201   return tensorflow::DT_INVALID;
3202 }
3203 
3204 PyObject* CopySequenceSettingIndicesToNull(
3205     PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) {
3206   tensorflow::Safe_PyObjectPtr fast_seq(
3207       PySequence_Fast(seq, "unable to allocate"));
3208   int len = PySequence_Fast_GET_SIZE(fast_seq.get());
3209   PyObject** fast_seq_array = PySequence_Fast_ITEMS(fast_seq.get());
3210   PyObject* result = PyTuple_New(len);
3211   for (int i = 0; i < len; i++) {
3212     PyObject* item;
3213     if (indices.find(i) != indices.end()) {
3214       item = Py_None;
3215     } else {
3216       item = fast_seq_array[i];
3217     }
3218     Py_INCREF(item);
3219     PyTuple_SET_ITEM(result, i, item);
3220   }
3221   return result;
3222 }
3223 
3224 PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
3225                          PyObject* results,
3226                          PyObject* forward_pass_name_scope = nullptr) {
3227   std::vector<int64_t> input_ids = MakeTensorIDList(inputs);
3228   if (PyErr_Occurred()) return nullptr;
3229   std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
3230   if (PyErr_Occurred()) return nullptr;
3231 
3232   bool should_record = false;
3233   for (TFE_Py_Tape* tape : SafeTapeSet()) {
3234     if (tape->tape->ShouldRecord(input_ids, input_dtypes)) {
3235       should_record = true;
3236       break;
3237     }
3238   }
3239   if (!should_record) {
3240     for (TFE_Py_ForwardAccumulator* accumulator : SafeAccumulatorSet()) {
3241       if (accumulator->accumulator->ShouldRecord(input_ids, input_dtypes)) {
3242         should_record = true;
3243         break;
3244       }
3245     }
3246   }
3247   if (!should_record) Py_RETURN_NONE;
3248 
3249   string c_op_name = TFE_GetPythonString(op_name);
3250 
3251   PyObject* op_outputs;
3252   bool op_outputs_tuple_created = false;
3253 
3254   if (const auto unused_output_indices =
3255           OpGradientUnusedOutputIndices(c_op_name)) {
3256     if (unused_output_indices->empty()) {
3257       op_outputs = Py_None;
3258     } else {
3259       op_outputs_tuple_created = true;
3260       op_outputs =
3261           CopySequenceSettingIndicesToNull(results, *unused_output_indices);
3262     }
3263   } else {
3264     op_outputs = results;
3265   }
3266 
3267   PyObject* op_inputs;
3268   bool op_inputs_tuple_created = false;
3269 
3270   if (const auto unused_input_indices =
3271           OpGradientUnusedInputIndices(c_op_name)) {
3272     if (unused_input_indices->empty()) {
3273       op_inputs = Py_None;
3274     } else {
3275       op_inputs_tuple_created = true;
3276       op_inputs =
3277           CopySequenceSettingIndicesToNull(inputs, *unused_input_indices);
3278     }
3279   } else {
3280     op_inputs = inputs;
3281   }
3282 
3283   tensorflow::eager::ForwardFunction<PyObject> py_forward_function(
3284       [op_name, attrs, inputs, results](
3285           const std::vector<PyObject*>& input_tangents,
3286           std::vector<PyObject*>* output_tangents, bool use_batch) {
3287         return CallJVPFunction(op_name, attrs, inputs, results, input_tangents,
3288                                output_tangents, use_batch);
3289       });
3290   tensorflow::eager::ForwardFunction<PyObject>* forward_function;
3291   if (c_op_name == "While" || c_op_name == "StatelessWhile" ||
3292       c_op_name == "If" || c_op_name == "StatelessIf") {
3293     // Control flow contains non-hashable attributes. Handling them in Python is
3294     // a headache, so instead we'll stay as close to GradientTape's handling as
3295     // possible (a null forward function means the accumulator forwards to a
3296     // tape).
3297     //
3298     // This is safe to do since we'll only see control flow when graph building,
3299     // in which case we can rely on pruning.
3300     forward_function = nullptr;
3301   } else {
3302     forward_function = &py_forward_function;
3303   }
3304 
3305   PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));
3306 
3307   if (!forward_pass_name_scope) forward_pass_name_scope = Py_None;
3308 
3309   TapeSetRecordOperation(
3310       op_name, inputs, results, input_ids, input_dtypes,
3311       [op_name, attrs, num_inputs, op_inputs, op_outputs,
3312        forward_pass_name_scope]() {
3313         Py_INCREF(op_name);
3314         Py_INCREF(attrs);
3315         Py_INCREF(num_inputs);
3316         Py_INCREF(op_inputs);
3317         Py_INCREF(op_outputs);
3318         Py_INCREF(forward_pass_name_scope);
3319         PyBackwardFunction* function = new PyBackwardFunction(
3320             [op_name, attrs, num_inputs, op_inputs, op_outputs,
3321              forward_pass_name_scope](
3322                 PyObject* output_grads,
3323                 const std::vector<int64_t>& unneeded_gradients) {
3324               if (PyErr_Occurred()) {
3325                 return static_cast<PyObject*>(nullptr);
3326               }
3327               tensorflow::Safe_PyObjectPtr skip_input_indices;
3328               if (!unneeded_gradients.empty()) {
3329                 skip_input_indices.reset(
3330                     PyTuple_New(unneeded_gradients.size()));
3331                 for (int i = 0; i < unneeded_gradients.size(); i++) {
3332                   PyTuple_SET_ITEM(
3333                       skip_input_indices.get(), i,
3334                       GetPythonObjectFromInt(unneeded_gradients[i]));
3335                 }
3336               } else {
3337                 Py_INCREF(Py_None);
3338                 skip_input_indices.reset(Py_None);
3339               }
3340               tensorflow::Safe_PyObjectPtr callback_args(Py_BuildValue(
3341                   "OOOOOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs,
3342                   output_grads, skip_input_indices.get(),
3343                   forward_pass_name_scope));
3344 
3345               tensorflow::Safe_PyObjectPtr result(
3346                   PyObject_CallObject(gradient_function, callback_args.get()));
3347 
3348               if (PyErr_Occurred()) return static_cast<PyObject*>(nullptr);
3349 
3350               return tensorflow::swig::Flatten(result.get());
3351             });
3352         return function;
3353       },
3354       [op_name, attrs, num_inputs, op_inputs, op_outputs,
3355        forward_pass_name_scope](PyBackwardFunction* backward_function) {
3356         Py_DECREF(op_name);
3357         Py_DECREF(attrs);
3358         Py_DECREF(num_inputs);
3359         Py_DECREF(op_inputs);
3360         Py_DECREF(op_outputs);
3361         Py_DECREF(forward_pass_name_scope);
3362 
3363         delete backward_function;
3364       },
3365       forward_function);
3366 
3367   Py_DECREF(num_inputs);
3368   if (op_outputs_tuple_created) Py_DECREF(op_outputs);
3369   if (op_inputs_tuple_created) Py_DECREF(op_inputs);
3370 
3371   if (PyErr_Occurred()) {
3372     return nullptr;
3373   }
3374 
3375   Py_RETURN_NONE;
3376 }
3377 
3378 void MaybeNotifyVariableAccessed(PyObject* input) {
3379   DCHECK(CheckResourceVariable(input));
3380   DCHECK(PyObject_HasAttrString(input, "_trainable"));
3381 
3382   tensorflow::Safe_PyObjectPtr trainable(
3383       PyObject_GetAttrString(input, "_trainable"));
3384   if (trainable.get() == Py_False) return;
3385   TFE_Py_TapeVariableAccessed(input);
3386   TFE_Py_VariableWatcherVariableAccessed(input);
3387 }
3388 
3389 bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
3390                     PyObject* input, tensorflow::Safe_PyObjectPtr* output,
3391                     TF_Status* status) {
3392   MaybeNotifyVariableAccessed(input);
3393 
3394   TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status);
3395   auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); });
3396   if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr))
3397     return false;
3398 
3399   TFE_OpSetDevice(op, parent_op_exec_info.device_name, status);
3400   if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr))
3401     return false;
3402 
3403   // Set dtype
3404   DCHECK(PyObject_HasAttrString(input, "_dtype"));
3405   tensorflow::Safe_PyObjectPtr dtype(PyObject_GetAttrString(input, "_dtype"));
3406   int value;
3407   if (!ParseTypeValue("_dtype", dtype.get(), status, &value)) {
3408     return false;
3409   }
3410   TFE_OpSetAttrType(op, "dtype", static_cast<TF_DataType>(value));
3411 
3412   // Get handle
3413   tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(input, "_handle"));
3414   if (!EagerTensor_CheckExact(handle.get())) return false;
3415   TFE_OpAddInput(op, EagerTensor_Handle(handle.get()), status);
3416   if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr))
3417     return false;
3418 
3419   int num_retvals = 1;
3420   TFE_TensorHandle* output_handle;
3421   TFE_Execute(op, &output_handle, &num_retvals, status);
3422   if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr))
3423     return false;
3424 
3425   // Always create the py object (and correctly DECREF it) from the returned
3426   // value, else the data will leak.
3427   output->reset(EagerTensorFromHandle(output_handle));
3428 
3429   // TODO(nareshmodi): Should we run post exec callbacks here?
3430   if (parent_op_exec_info.run_gradient_callback) {
3431     tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(1));
3432     PyTuple_SET_ITEM(inputs.get(), 0, handle.release());
3433 
3434     tensorflow::Safe_PyObjectPtr outputs(PyTuple_New(1));
3435     Py_INCREF(output->get());  // stay alive after since tuple steals.
3436     PyTuple_SET_ITEM(outputs.get(), 0, output->get());
3437 
3438     tensorflow::Safe_PyObjectPtr op_string(
3439         GetPythonObjectFromString("ReadVariableOp"));
3440     if (!RecordGradient(op_string.get(), inputs.get(), Py_None,
3441                         outputs.get())) {
3442       return false;
3443     }
3444   }
3445 
3446   return true;
3447 }
3448 
3449 // Supports 3 cases at the moment:
3450 //  i) input is an EagerTensor.
3451 //  ii) input is a ResourceVariable - in this case, the is_variable param is
3452 //  set to true.
3453 //  iii) input is an arbitrary python list/tuple (note, this handling doesn't
3454 //  support packing).
3455 //
3456 //  NOTE: dtype_hint_getter must *always* return a PyObject that can be
3457 //  decref'd. So if no hint is found, Py_RETURN_NONE (which correctly
3458 //  increfs Py_None).
3459 //
3460 //  NOTE: This function sets a python error directly, and returns false.
3461 //  TF_Status is only passed since we don't want to have to reallocate it.
3462 bool ConvertToTensor(
3463     const FastPathOpExecInfo& op_exec_info, PyObject* input,
3464     tensorflow::Safe_PyObjectPtr* output_handle,
3465     // This gets a hint for this particular input.
3466     const std::function<tensorflow::DataType()>& dtype_hint_getter,
3467     // This sets the dtype after conversion is complete.
3468     const std::function<void(const tensorflow::DataType dtype)>& dtype_setter,
3469     TF_Status* status) {
3470   if (EagerTensor_CheckExact(input)) {
3471     Py_INCREF(input);
3472     output_handle->reset(input);
3473     return true;
3474   } else if (CheckResourceVariable(input)) {
3475     return ReadVariableOp(op_exec_info, input, output_handle, status);
3476   }
3477 
3478   // The hint comes from a supposedly similarly typed tensor.
3479   tensorflow::DataType dtype_hint = dtype_hint_getter();
3480 
3481   TFE_TensorHandle* handle = tensorflow::ConvertToEagerTensor(
3482       op_exec_info.ctx, input, dtype_hint, op_exec_info.device_name);
3483   if (handle == nullptr) {
3484     return tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr);
3485   }
3486 
3487   output_handle->reset(EagerTensorFromHandle(handle));
3488   dtype_setter(
3489       static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(handle)));
3490 
3491   return true;
3492 }
3493 
3494 // Adds input and type attr to the op, and to the list of flattened
3495 // inputs/attrs.
3496 bool AddInputToOp(FastPathOpExecInfo* op_exec_info, PyObject* input,
3497                   const bool add_type_attr,
3498                   const tensorflow::OpDef::ArgDef& input_arg,
3499                   std::vector<tensorflow::Safe_PyObjectPtr>* flattened_attrs,
3500                   std::vector<tensorflow::Safe_PyObjectPtr>* flattened_inputs,
3501                   TFE_Op* op, TF_Status* status) {
3502   // py_eager_tensor's ownership is transferred to flattened_inputs if it is
3503   // required, else the object is destroyed and DECREF'd when the object goes
3504   // out of scope in this function.
3505   tensorflow::Safe_PyObjectPtr py_eager_tensor = nullptr;
3506 
3507   if (!ConvertToTensor(
3508           *op_exec_info, input, &py_eager_tensor,
3509           [&]() {
3510             if (input_arg.type() != tensorflow::DataType::DT_INVALID) {
3511               return input_arg.type();
3512             }
3513             return MaybeGetDTypeForAttr(input_arg.type_attr(), op_exec_info);
3514           },
3515           [&](const tensorflow::DataType dtype) {
3516             op_exec_info->cached_dtypes[input_arg.type_attr()] = dtype;
3517           },
3518           status)) {
3519     return false;
3520   }
3521 
3522   TFE_TensorHandle* input_handle = EagerTensor_Handle(py_eager_tensor.get());
3523 
3524   if (add_type_attr && !input_arg.type_attr().empty()) {
3525     auto dtype = TFE_TensorHandleDataType(input_handle);
3526     TFE_OpSetAttrType(op, input_arg.type_attr().data(), dtype);
3527     if (flattened_attrs != nullptr) {
3528       flattened_attrs->emplace_back(
3529           GetPythonObjectFromString(input_arg.type_attr()));
3530       flattened_attrs->emplace_back(PyLong_FromLong(dtype));
3531     }
3532   }
3533 
3534   if (flattened_inputs != nullptr) {
3535     flattened_inputs->emplace_back(std::move(py_eager_tensor));
3536   }
3537 
3538   TFE_OpAddInput(op, input_handle, status);
3539   if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3540     return false;
3541   }
3542 
3543   return true;
3544 }
3545 
3546 const char* GetDeviceName(PyObject* py_device_name) {
3547   if (py_device_name != Py_None) {
3548     return TFE_GetPythonString(py_device_name);
3549   }
3550   return nullptr;
3551 }
3552 
3553 bool RaiseIfNotPySequence(PyObject* seq, const string& attr_name) {
3554   if (!PySequence_Check(seq)) {
3555     PyErr_SetString(PyExc_TypeError,
3556                     Printf("expected a sequence for attr %s, got %s instead",
3557                            attr_name.data(), seq->ob_type->tp_name)
3558                         .data());
3559 
3560     return false;
3561   }
3562   if (PyArray_Check(seq) &&
3563       PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)) != 1) {
3564     PyErr_SetString(PyExc_ValueError,
3565                     Printf("expected a sequence for attr %s, got an ndarray "
3566                            "with rank %d instead",
3567                            attr_name.data(),
3568                            PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)))
3569                         .data());
3570     return false;
3571   }
3572   return true;
3573 }
3574 
3575 bool RunCallbacks(
3576     const FastPathOpExecInfo& op_exec_info, PyObject* args,
3577     int num_inferred_attrs,
3578     const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_inputs,
3579     const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_attrs,
3580     PyObject* flattened_result) {
3581   DCHECK(op_exec_info.run_callbacks);
3582 
3583   tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(flattened_inputs.size()));
3584   for (int i = 0; i < flattened_inputs.size(); i++) {
3585     PyObject* input = flattened_inputs[i].get();
3586     Py_INCREF(input);
3587     PyTuple_SET_ITEM(inputs.get(), i, input);
3588   }
3589 
3590   int num_non_inferred_attrs = PyTuple_GET_SIZE(args) - num_inferred_attrs;
3591   int num_attrs = flattened_attrs.size() + num_non_inferred_attrs;
3592   tensorflow::Safe_PyObjectPtr attrs(PyTuple_New(num_attrs));
3593 
3594   for (int i = 0; i < num_non_inferred_attrs; i++) {
3595     auto* attr = PyTuple_GET_ITEM(args, num_inferred_attrs + i);
3596     Py_INCREF(attr);
3597     PyTuple_SET_ITEM(attrs.get(), i, attr);
3598   }
3599 
3600   for (int i = num_non_inferred_attrs; i < num_attrs; i++) {
3601     PyObject* attr_or_name =
3602         flattened_attrs.at(i - num_non_inferred_attrs).get();
3603     Py_INCREF(attr_or_name);
3604     PyTuple_SET_ITEM(attrs.get(), i, attr_or_name);
3605   }
3606 
3607   if (op_exec_info.run_gradient_callback) {
3608     if (!RecordGradient(op_exec_info.op_name, inputs.get(), attrs.get(),
3609                         flattened_result)) {
3610       return false;
3611     }
3612   }
3613 
3614   if (op_exec_info.run_post_exec_callbacks) {
3615     tensorflow::Safe_PyObjectPtr callback_args(
3616         Py_BuildValue("OOOOO", op_exec_info.op_name, inputs.get(), attrs.get(),
3617                       flattened_result, op_exec_info.name));
3618     for (Py_ssize_t i = 0; i < PyList_Size(op_exec_info.callbacks); i++) {
3619       PyObject* callback_fn = PyList_GET_ITEM(op_exec_info.callbacks, i);
3620       if (!PyCallable_Check(callback_fn)) {
3621         PyErr_SetString(
3622             PyExc_TypeError,
3623             Printf("expected a function for "
3624                    "post execution callback in index %ld, got %s instead",
3625                    i, callback_fn->ob_type->tp_name)
3626                 .c_str());
3627         return false;
3628       }
3629       PyObject* callback_result =
3630           PyObject_CallObject(callback_fn, callback_args.get());
3631       if (!callback_result) {
3632         return false;
3633       }
3634       Py_DECREF(callback_result);
3635     }
3636   }
3637 
3638   return true;
3639 }
3640 
3641 }  // namespace
3642 
3643 PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
3644   tensorflow::profiler::TraceMe activity(
3645       "TFE_Py_FastPathExecute_C", tensorflow::profiler::TraceMeLevel::kInfo);
3646   Py_ssize_t args_size = PyTuple_GET_SIZE(args);
3647   if (args_size < FAST_PATH_EXECUTE_ARG_INPUT_START) {
3648     PyErr_SetString(
3649         PyExc_ValueError,
3650         Printf("There must be at least %d items in the input tuple.",
3651                FAST_PATH_EXECUTE_ARG_INPUT_START)
3652             .c_str());
3653     return nullptr;
3654   }
3655 
3656   FastPathOpExecInfo op_exec_info;
3657 
3658   PyObject* py_eager_context =
3659       PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_CONTEXT);
3660 
3661   // TODO(edoper): Use interned string here
3662   PyObject* eager_context_handle =
3663       PyObject_GetAttrString(py_eager_context, "_context_handle");
3664 
3665   TFE_Context* ctx = reinterpret_cast<TFE_Context*>(
3666       PyCapsule_GetPointer(eager_context_handle, nullptr));
3667   op_exec_info.ctx = ctx;
3668   op_exec_info.args = args;
3669 
3670   if (ctx == nullptr) {
3671     // The context hasn't been initialized. It will be in the slow path.
3672     RaiseFallbackException(
3673         "This function does not handle the case of the path where "
3674         "all inputs are not already EagerTensors.");
3675     return nullptr;
3676   }
3677 
3678   auto* tld = tensorflow::GetEagerContextThreadLocalData(py_eager_context);
3679   if (tld == nullptr) {
3680     return nullptr;
3681   }
3682   op_exec_info.device_name = GetDeviceName(tld->device_name.get());
3683   op_exec_info.callbacks = tld->op_callbacks.get();
3684 
3685   op_exec_info.op_name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_OP_NAME);
3686   op_exec_info.name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_NAME);
3687 
3688   // TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks
3689   // (similar to benchmark_tf_gradient_function_*). Also consider using an
3690   // InlinedVector for flattened_attrs and flattened_inputs if the benchmarks
3691   // point out problems with heap allocs.
3692   op_exec_info.run_gradient_callback =
3693       !*ThreadTapeIsStopped() && HasAccumulatorOrTape();
3694   op_exec_info.run_post_exec_callbacks =
3695       op_exec_info.callbacks != Py_None &&
3696       PyList_Size(op_exec_info.callbacks) > 0;
3697   op_exec_info.run_callbacks = op_exec_info.run_gradient_callback ||
3698                                op_exec_info.run_post_exec_callbacks;
3699 
3700   TF_Status* status = GetStatus();
3701   const char* op_name = TFE_GetPythonString(op_exec_info.op_name);
3702   if (op_name == nullptr) {
3703     PyErr_SetString(PyExc_TypeError,
3704                     Printf("expected a string for op_name, got %s instead",
3705                            op_exec_info.op_name->ob_type->tp_name)
3706                         .c_str());
3707     return nullptr;
3708   }
3709 
3710   TFE_Op* op = GetOp(ctx, op_name, op_exec_info.device_name, status);
3711 
3712   auto cleaner = tensorflow::gtl::MakeCleanup([status, ctx, op] {
3713     ReturnStatus(status);
3714     ReturnOp(ctx, op);
3715   });
3716 
3717   if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3718     return nullptr;
3719   }
3720 
3721   tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace(
3722       tensorflow::StackTrace::kStackTraceInitialSize));
3723 
3724   const tensorflow::OpDef* op_def = tensorflow::unwrap(op)->OpDef();
3725   if (op_def == nullptr) return nullptr;
3726 
3727   if (args_size <
3728       FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size()) {
3729     PyErr_SetString(
3730         PyExc_ValueError,
3731         Printf("Tuple size smaller than intended. Expected to be at least %d, "
3732                "was %ld",
3733                FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
3734                args_size)
3735             .c_str());
3736     return nullptr;
3737   }
3738 
3739   if (!CheckInputsOk(args, FAST_PATH_EXECUTE_ARG_INPUT_START, *op_def)) {
3740     RaiseFallbackException(
3741         "This function does not handle the case of the path where "
3742         "all inputs are not already EagerTensors.");
3743     return nullptr;
3744   }
3745 
3746   op_exec_info.attr_to_inputs_map = GetAttrToInputsMapHoldingGIL(*op_def);
3747   op_exec_info.default_dtypes = GetAttrToDefaultsMapHoldingGIL(*op_def);
3748 
3749   // Mapping of attr name to size - used to calculate the number of values
3750   // to be expected by the TFE_Execute run.
3751   tensorflow::gtl::FlatMap<string, int64_t> attr_list_sizes;
3752 
3753   // Set non-inferred attrs, including setting defaults if the attr is passed in
3754   // as None.
3755   for (int i = FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size();
3756        i < args_size; i += 2) {
3757     PyObject* py_attr_name = PyTuple_GET_ITEM(args, i);
3758     const char* attr_name = TFE_GetPythonString(py_attr_name);
3759     PyObject* py_attr_value = PyTuple_GET_ITEM(args, i + 1);
3760 
3761     // Not creating an index since most of the time there are not more than a
3762     // few attrs.
3763     // TODO(nareshmodi): Maybe include the index as part of the
3764     // OpRegistrationData.
3765     for (const auto& attr : op_def->attr()) {
3766       if (tensorflow::StringPiece(attr_name) == attr.name()) {
3767         SetOpAttrWithDefaults(ctx, op, attr, attr_name, py_attr_value,
3768                               &attr_list_sizes, status);
3769 
3770         if (!status->status.ok()) {
3771           VLOG(1) << "Falling back to slow path for Op \"" << op_def->name()
3772                   << "\" since we are unable to set the value for attr \""
3773                   << attr.name() << "\" due to: " << TF_Message(status);
3774           RaiseFallbackException(TF_Message(status));
3775           return nullptr;
3776         }
3777 
3778         break;
3779       }
3780     }
3781   }
3782 
3783   // Flat attrs and inputs as required by the record_gradient call. The attrs
3784   // here only contain inferred attrs (non-inferred attrs are added directly
3785   // from the input args).
3786   // All items in flattened_attrs and flattened_inputs contain
3787   // Safe_PyObjectPtr - any time something steals a reference to this, it must
3788   // INCREF.
3789   // TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work
3790   // directly.
3791   std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_attrs =
3792       nullptr;
3793   std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_inputs =
3794       nullptr;
3795 
3796   // TODO(nareshmodi): Encapsulate callbacks information into a struct.
3797   if (op_exec_info.run_callbacks) {
3798     flattened_attrs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
3799     flattened_inputs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
3800   }
3801 
3802   // Add inferred attrs and inputs.
3803   // The following code might set duplicate type attrs. This will result in
3804   // the CacheKey for the generated AttrBuilder possibly differing from
3805   // those where the type attrs are correctly set. Inconsistent CacheKeys
3806   // for ops means that there might be unnecessarily duplicated kernels.
3807   // TODO(nareshmodi): Fix this.
3808   for (int i = 0; i < op_def->input_arg_size(); i++) {
3809     const auto& input_arg = op_def->input_arg(i);
3810 
3811     PyObject* input =
3812         PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_INPUT_START + i);
3813     if (!input_arg.number_attr().empty()) {
3814       // The item is a homogeneous list.
3815       if (!RaiseIfNotPySequence(input, input_arg.number_attr())) return nullptr;
3816       tensorflow::Safe_PyObjectPtr fast_input(
3817           PySequence_Fast(input, "Could not parse sequence."));
3818       if (fast_input.get() == nullptr) {
3819         return nullptr;
3820       }
3821       Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
3822       PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get());
3823 
3824       TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len);
3825       if (op_exec_info.run_callbacks) {
3826         flattened_attrs->emplace_back(
3827             GetPythonObjectFromString(input_arg.number_attr()));
3828         flattened_attrs->emplace_back(PyLong_FromLong(len));
3829       }
3830       attr_list_sizes[input_arg.number_attr()] = len;
3831 
3832       if (len > 0) {
3833         // First item adds the type attr.
3834         if (!AddInputToOp(&op_exec_info, fast_input_array[0], true, input_arg,
3835                           flattened_attrs.get(), flattened_inputs.get(), op,
3836                           status)) {
3837           return nullptr;
3838         }
3839 
3840         for (Py_ssize_t j = 1; j < len; j++) {
3841           // Since the list is homogeneous, we don't need to re-add the attr.
3842           if (!AddInputToOp(&op_exec_info, fast_input_array[j], false,
3843                             input_arg, nullptr /* flattened_attrs */,
3844                             flattened_inputs.get(), op, status)) {
3845             return nullptr;
3846           }
3847         }
3848       }
3849     } else if (!input_arg.type_list_attr().empty()) {
3850       // The item is a heterogeneous list.
3851       if (!RaiseIfNotPySequence(input, input_arg.type_list_attr())) {
3852         return nullptr;
3853       }
3854       tensorflow::Safe_PyObjectPtr fast_input(
3855           PySequence_Fast(input, "Could not parse sequence."));
3856       if (fast_input.get() == nullptr) {
3857         return nullptr;
3858       }
3859       const string& attr_name = input_arg.type_list_attr();
3860       Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
3861       PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get());
3862       tensorflow::gtl::InlinedVector<TF_DataType, 4> attr_value(len);
3863       PyObject* py_attr_value = nullptr;
3864       if (op_exec_info.run_callbacks) {
3865         py_attr_value = PyTuple_New(len);
3866       }
3867       for (Py_ssize_t j = 0; j < len; j++) {
3868         PyObject* py_input = fast_input_array[j];
3869         tensorflow::Safe_PyObjectPtr py_eager_tensor;
3870         if (!ConvertToTensor(
3871                 op_exec_info, py_input, &py_eager_tensor,
3872                 []() { return tensorflow::DT_INVALID; },
3873                 [](const tensorflow::DataType dtype) {}, status)) {
3874           return nullptr;
3875         }
3876 
3877         TFE_TensorHandle* input_handle =
3878             EagerTensor_Handle(py_eager_tensor.get());
3879 
3880         attr_value[j] = TFE_TensorHandleDataType(input_handle);
3881 
3882         TFE_OpAddInput(op, input_handle, status);
3883         if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3884           return nullptr;
3885         }
3886 
3887         if (op_exec_info.run_callbacks) {
3888           flattened_inputs->emplace_back(std::move(py_eager_tensor));
3889 
3890           PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j]));
3891         }
3892       }
3893       if (op_exec_info.run_callbacks) {
3894         flattened_attrs->emplace_back(GetPythonObjectFromString(attr_name));
3895         flattened_attrs->emplace_back(py_attr_value);
3896       }
3897       TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(),
3898                             attr_value.size());
3899       attr_list_sizes[attr_name] = len;
3900     } else {
3901       // The item is a single item.
3902       if (!AddInputToOp(&op_exec_info, input, true, input_arg,
3903                         flattened_attrs.get(), flattened_inputs.get(), op,
3904                         status)) {
3905         return nullptr;
3906       }
3907     }
3908   }
3909 
3910   int64_t num_outputs = 0;
3911   for (int i = 0; i < op_def->output_arg_size(); i++) {
3912     const auto& output_arg = op_def->output_arg(i);
3913     int64_t delta = 1;
3914     if (!output_arg.number_attr().empty()) {
3915       delta = attr_list_sizes[output_arg.number_attr()];
3916     } else if (!output_arg.type_list_attr().empty()) {
3917       delta = attr_list_sizes[output_arg.type_list_attr()];
3918     }
3919     if (delta < 0) {
3920       RaiseFallbackException(
3921           "Attributes suggest that the size of an output list is less than 0");
3922       return nullptr;
3923     }
3924     num_outputs += delta;
3925   }
3926 
3927   // If number of retvals is larger than int32, we error out.
3928   if (static_cast<int64_t>(static_cast<int32_t>(num_outputs)) != num_outputs) {
3929     PyErr_SetString(
3930         PyExc_ValueError,
3931         Printf("Number of outputs is too big: %ld", num_outputs).c_str());
3932     return nullptr;
3933   }
3934   int num_retvals = num_outputs;
3935 
3936   tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);
3937 
3938   Py_BEGIN_ALLOW_THREADS;
3939   TFE_Execute(op, retvals.data(), &num_retvals, status);
3940   Py_END_ALLOW_THREADS;
3941 
3942   if (!status->status.ok()) {
3943     // Augment the status with the op_name for easier debugging similar to
3944     // TFE_Py_Execute.
3945     status->status = tensorflow::errors::CreateWithUpdatedMessage(
3946         status->status, tensorflow::strings::StrCat(
3947                             TF_Message(status), " [Op:",
3948                             TFE_GetPythonString(op_exec_info.op_name), "]"));
3949     tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr);
3950     return nullptr;
3951   }
3952 
3953   tensorflow::Safe_PyObjectPtr flat_result(PyList_New(num_retvals));
3954   for (int i = 0; i < num_retvals; ++i) {
3955     PyList_SET_ITEM(flat_result.get(), i, EagerTensorFromHandle(retvals[i]));
3956   }
3957 
3958   if (op_exec_info.run_callbacks) {
3959     if (!RunCallbacks(
3960             op_exec_info, args,
3961             FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
3962             *flattened_inputs, *flattened_attrs, flat_result.get())) {
3963       return nullptr;
3964     }
3965   }
3966 
3967   // Unflatten results.
3968   if (op_def->output_arg_size() == 0) {
3969     Py_RETURN_NONE;
3970   }
3971 
3972   if (op_def->output_arg_size() == 1) {
3973     if (!op_def->output_arg(0).number_attr().empty() ||
3974         !op_def->output_arg(0).type_list_attr().empty()) {
3975       return flat_result.release();
3976     } else {
3977       auto* result = PyList_GET_ITEM(flat_result.get(), 0);
3978       Py_INCREF(result);
3979       return result;
3980     }
3981   }
3982 
3983   // Correctly output the results that are made into a namedtuple.
3984   PyObject* result = PyList_New(op_def->output_arg_size());
3985   int flat_result_index = 0;
3986   for (int i = 0; i < op_def->output_arg_size(); i++) {
3987     if (!op_def->output_arg(i).number_attr().empty()) {
3988       int list_length = attr_list_sizes[op_def->output_arg(i).number_attr()];
3989       PyObject* inner_list = PyList_New(list_length);
3990       for (int j = 0; j < list_length; j++) {
3991         PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
3992         Py_INCREF(obj);
3993         PyList_SET_ITEM(inner_list, j, obj);
3994       }
3995       PyList_SET_ITEM(result, i, inner_list);
3996     } else if (!op_def->output_arg(i).type_list_attr().empty()) {
3997       int list_length = attr_list_sizes[op_def->output_arg(i).type_list_attr()];
3998       PyObject* inner_list = PyList_New(list_length);
3999       for (int j = 0; j < list_length; j++) {
4000         PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
4001         Py_INCREF(obj);
4002         PyList_SET_ITEM(inner_list, j, obj);
4003       }
4004       PyList_SET_ITEM(result, i, inner_list);
4005     } else {
4006       PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
4007       Py_INCREF(obj);
4008       PyList_SET_ITEM(result, i, obj);
4009     }
4010   }
4011   return result;
4012 }
4013 
4014 PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
4015                                 PyObject* attrs, PyObject* results,
4016                                 PyObject* forward_pass_name_scope) {
4017   if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
4018     Py_RETURN_NONE;
4019   }
4020 
4021   return RecordGradient(op_name, inputs, attrs, results,
4022                         forward_pass_name_scope);
4023 }
4024 
4025 // A method prints incoming messages directly to Python's
4026 // stdout using Python's C API. This is necessary in Jupyter notebooks
4027 // and colabs where messages to the C stdout don't go to the notebook
4028 // cell outputs, but calls to Python's stdout do.
4029 void PrintToPythonStdout(const char* msg) {
4030   if (Py_IsInitialized()) {
4031     PyGILState_STATE py_threadstate;
4032     py_threadstate = PyGILState_Ensure();
4033 
4034     string string_msg = msg;
4035     // PySys_WriteStdout truncates strings over 1000 bytes, so
4036     // we write the message in chunks small enough to not be truncated.
4037     int CHUNK_SIZE = 900;
4038     auto len = string_msg.length();
4039     for (int i = 0; i < len; i += CHUNK_SIZE) {
4040       PySys_WriteStdout("%s", string_msg.substr(i, CHUNK_SIZE).c_str());
4041     }
4042 
4043     // Force flushing to make sure print newlines aren't interleaved in
4044     // some colab environments
4045     PyRun_SimpleString("import sys; sys.stdout.flush()");
4046 
4047     PyGILState_Release(py_threadstate);
4048   }
4049 }
4050 
4051 // Register PrintToPythonStdout as a log listener, to allow
4052 // printing in colabs and jupyter notebooks to work.
4053 void TFE_Py_EnableInteractivePythonLogging() {
4054   static bool enabled_interactive_logging = false;
4055   if (!enabled_interactive_logging) {
4056     enabled_interactive_logging = true;
4057     TF_RegisterLogListener(PrintToPythonStdout);
4058   }
4059 }
4060 
4061 namespace {
4062 // TODO(mdan): Clean this. Maybe by decoupling context lifetime from Python GC?
4063 // Weak reference to the Python Context (see tensorflow/python/eager/context.py)
4064 // object currently active. This object is opaque and wrapped inside a Python
4065 // Capsule. However, the EagerContext object it holds is tracked by the
4066 // global_c_eager_context object.
4067 // Also see common_runtime/eager/context.cc.
4068 PyObject* global_py_eager_context = nullptr;
4069 }  // namespace
4070 
4071 PyObject* TFE_Py_SetEagerContext(PyObject* py_context) {
4072   Py_XDECREF(global_py_eager_context);
4073   global_py_eager_context = PyWeakref_NewRef(py_context, nullptr);
4074   if (global_py_eager_context == nullptr) {
4075     return nullptr;
4076   }
4077   Py_RETURN_NONE;
4078 }
4079 
4080 PyObject* GetPyEagerContext() {
4081   if (global_py_eager_context == nullptr) {
4082     PyErr_SetString(PyExc_RuntimeError, "Python eager context is not set");
4083     return nullptr;
4084   }
4085   PyObject* py_context = PyWeakref_GET_OBJECT(global_py_eager_context);
4086   if (py_context == Py_None) {
4087     PyErr_SetString(PyExc_RuntimeError,
4088                     "Python eager context has been destroyed");
4089     return nullptr;
4090   }
4091   Py_INCREF(py_context);
4092   return py_context;
4093 }
4094 
4095 namespace {
4096 
4097 // Default values for thread_local_data fields.
4098 struct EagerContextThreadLocalDataDefaults {
4099   tensorflow::Safe_PyObjectPtr is_eager;
4100   tensorflow::Safe_PyObjectPtr device_spec;
4101 };
4102 
4103 // Maps each py_eager_context object to its thread_local_data.
4104 //
4105 // Note: we need to use the python Context object as the key here (and not
4106 // its handle object), because the handle object isn't created until the
4107 // context is initialized; but thread_local_data is potentially accessed
4108 // before then.
4109 using EagerContextThreadLocalDataMap = absl::flat_hash_map<
4110     PyObject*, std::unique_ptr<tensorflow::EagerContextThreadLocalData>>;
4111 thread_local EagerContextThreadLocalDataMap*
4112     eager_context_thread_local_data_map = nullptr;
4113 
4114 // Maps each py_eager_context object to default values.
4115 using EagerContextThreadLocalDataDefaultsMap =
4116     absl::flat_hash_map<PyObject*, EagerContextThreadLocalDataDefaults>;
4117 EagerContextThreadLocalDataDefaultsMap*
4118     eager_context_thread_local_data_defaults = nullptr;
4119 
4120 }  // namespace
4121 
4122 namespace tensorflow {
4123 
4124 void MakeEagerContextThreadLocalData(PyObject* py_eager_context,
4125                                      PyObject* is_eager,
4126                                      PyObject* device_spec) {
4127   DCheckPyGilState();
4128   if (eager_context_thread_local_data_defaults == nullptr) {
4129     absl::LeakCheckDisabler disabler;
4130     eager_context_thread_local_data_defaults =
4131         new EagerContextThreadLocalDataDefaultsMap();
4132   }
4133   if (eager_context_thread_local_data_defaults->count(py_eager_context) > 0) {
4134     PyErr_SetString(PyExc_AssertionError,
4135                     "MakeEagerContextThreadLocalData may not be called "
4136                     "twice on the same eager Context object.");
4137   }
4138 
4139   auto& defaults =
4140       (*eager_context_thread_local_data_defaults)[py_eager_context];
4141   Py_INCREF(is_eager);
4142   defaults.is_eager.reset(is_eager);
4143   Py_INCREF(device_spec);
4144   defaults.device_spec.reset(device_spec);
4145 }
4146 
4147 EagerContextThreadLocalData* GetEagerContextThreadLocalData(
4148     PyObject* py_eager_context) {
4149   if (eager_context_thread_local_data_defaults == nullptr) {
4150     PyErr_SetString(PyExc_AssertionError,
4151                     "MakeEagerContextThreadLocalData must be called "
4152                     "before GetEagerContextThreadLocalData.");
4153     return nullptr;
4154   }
4155   auto defaults =
4156       eager_context_thread_local_data_defaults->find(py_eager_context);
4157   if (defaults == eager_context_thread_local_data_defaults->end()) {
4158     PyErr_SetString(PyExc_AssertionError,
4159                     "MakeEagerContextThreadLocalData must be called "
4160                     "before GetEagerContextThreadLocalData.");
4161     return nullptr;
4162   }
4163 
4164   if (eager_context_thread_local_data_map == nullptr) {
4165     absl::LeakCheckDisabler disabler;
4166     eager_context_thread_local_data_map = new EagerContextThreadLocalDataMap();
4167   }
4168   auto& thread_local_data =
4169       (*eager_context_thread_local_data_map)[py_eager_context];
4170 
4171   if (!thread_local_data) {
4172     thread_local_data.reset(new EagerContextThreadLocalData());
4173 
4174     Safe_PyObjectPtr is_eager(
4175         PyObject_CallFunctionObjArgs(defaults->second.is_eager.get(), nullptr));
4176     if (!is_eager) return nullptr;
4177     thread_local_data->is_eager = PyObject_IsTrue(is_eager.get());
4178 
4179 #if PY_MAJOR_VERSION >= 3
4180     PyObject* scope_name = PyUnicode_FromString("");
4181 #else
4182     PyObject* scope_name = PyString_FromString("");
4183 #endif
4184     thread_local_data->scope_name.reset(scope_name);
4185 
4186 #if PY_MAJOR_VERSION >= 3
4187     PyObject* device_name = PyUnicode_FromString("");
4188 #else
4189     PyObject* device_name = PyString_FromString("");
4190 #endif
4191     thread_local_data->device_name.reset(device_name);
4192 
4193     Py_INCREF(defaults->second.device_spec.get());
4194     thread_local_data->device_spec.reset(defaults->second.device_spec.get());
4195 
4196     Py_INCREF(Py_None);
4197     thread_local_data->function_call_options.reset(Py_None);
4198 
4199     Py_INCREF(Py_None);
4200     thread_local_data->executor.reset(Py_None);
4201 
4202     thread_local_data->op_callbacks.reset(PyList_New(0));
4203   }
4204   return thread_local_data.get();
4205 }
4206 
4207 void DestroyEagerContextThreadLocalData(PyObject* py_eager_context) {
4208   DCheckPyGilState();
4209   if (eager_context_thread_local_data_defaults) {
4210     eager_context_thread_local_data_defaults->erase(py_eager_context);
4211   }
4212   if (eager_context_thread_local_data_map) {
4213     eager_context_thread_local_data_map->erase(py_eager_context);
4214   }
4215 }
4216 
4217 }  // namespace tensorflow
4218