1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <atomic>
17 #include <cstring>
18 #include <unordered_map>
19
20 #include "absl/debugging/leak_check.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_replace.h"
23 #include "absl/types/variant.h"
24 #include "tensorflow/c/c_api.h"
25 #include "tensorflow/c/c_api_internal.h"
26 #include "tensorflow/c/eager/c_api.h"
27 #include "tensorflow/c/eager/c_api_internal.h"
28 #include "tensorflow/c/eager/tape.h"
29 #include "tensorflow/c/eager/tfe_context_internal.h"
30 #include "tensorflow/c/eager/tfe_op_internal.h"
31 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
32 #include "tensorflow/c/tf_status.h"
33 #include "tensorflow/core/framework/types.pb.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/lib/gtl/cleanup.h"
36 #include "tensorflow/core/lib/gtl/compactptrset.h"
37 #include "tensorflow/core/lib/gtl/flatmap.h"
38 #include "tensorflow/core/lib/gtl/flatset.h"
39 #include "tensorflow/core/lib/strings/strcat.h"
40 #include "tensorflow/core/lib/strings/stringprintf.h"
41 #include "tensorflow/core/platform/casts.h"
42 #include "tensorflow/core/platform/errors.h"
43 #include "tensorflow/core/platform/mutex.h"
44 #include "tensorflow/core/platform/protobuf.h"
45 #include "tensorflow/core/platform/status.h"
46 #include "tensorflow/core/platform/statusor.h"
47 #include "tensorflow/core/platform/types.h"
48 #include "tensorflow/core/profiler/lib/traceme.h"
49 #include "tensorflow/core/util/managed_stack_trace.h"
50 #include "tensorflow/python/eager/pywrap_gradient_exclusions.h"
51 #include "tensorflow/python/eager/pywrap_tensor.h"
52 #include "tensorflow/python/eager/pywrap_tfe.h"
53 #include "tensorflow/python/lib/core/py_util.h"
54 #include "tensorflow/python/lib/core/safe_ptr.h"
55 #include "tensorflow/python/util/stack_trace.h"
56 #include "tensorflow/python/util/util.h"
57
58 using tensorflow::Status;
59 using tensorflow::string;
60 using tensorflow::strings::Printf;
61
62 namespace {
63 // NOTE: Items are retrieved from and returned to these unique_ptrs, and they
64 // act as arenas. This is important if the same thread requests 2 items without
65 // releasing one.
66 // The following sequence of events on the same thread will still succeed:
67 // - GetOp <- Returns existing.
68 // - GetOp <- Allocates and returns a new pointer.
69 // - ReleaseOp <- Sets the item in the unique_ptr.
70 // - ReleaseOp <- Sets the item in the unique_ptr, deleting the old one.
// This occurs when a PyFunc kernel is run. This behavior makes it safe in that
// case, as well as in the case where Python reuses the same underlying C++
// thread for two Python threads.
74 struct OpDeleter {
  void operator()(TFE_Op* op) const { TFE_DeleteOp(op); }
76 };
77 thread_local std::unordered_map<TFE_Context*,
78 std::unique_ptr<TFE_Op, OpDeleter>>
79 thread_local_eager_operation_map; // NOLINT
80 thread_local std::unique_ptr<TF_Status> thread_local_tf_status = // NOLINT
81 nullptr;
82
std::unique_ptr<TFE_Op, OpDeleter> ReleaseThreadLocalOp(TFE_Context* ctx) {
84 auto it = thread_local_eager_operation_map.find(ctx);
85 if (it == thread_local_eager_operation_map.end()) {
86 return nullptr;
87 }
88 return std::move(it->second);
89 }
90
TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
92 const char* raw_device_name, TF_Status* status) {
93 auto op = ReleaseThreadLocalOp(ctx);
94 if (!op) {
95 op.reset(tensorflow::wrap(tensorflow::unwrap(ctx)->CreateOperation()));
96 }
97 status->status =
98 tensorflow::unwrap(op.get())->Reset(op_or_function_name, raw_device_name);
99 if (!status->status.ok()) {
100 op.reset();
101 }
102 return op.release();
103 }
104
void ReturnOp(TFE_Context* ctx, TFE_Op* op) {
106 if (op) {
107 tensorflow::unwrap(op)->Clear();
108 thread_local_eager_operation_map[ctx].reset(op);
109 }
110 }
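// A typical fast-path use of the cached op (illustrative sketch only):
//   TFE_Op* op = GetOp(ctx, "MatMul", /*raw_device_name=*/"", status);
//   ... add inputs and attrs, then TFE_Execute(op, ...) ...
//   ReturnOp(ctx, op);  // Clears the op and caches it for the next call.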
111
TF_Status* ReleaseThreadLocalStatus() {
113 if (thread_local_tf_status == nullptr) {
114 return nullptr;
115 }
116 return thread_local_tf_status.release();
117 }
118
119 struct InputInfo {
  InputInfo(int i, bool is_list) : i(i), is_list(is_list) {}
121
122 int i;
123 bool is_list = false;
124 };
125
126 // Takes in output gradients, returns input gradients.
127 typedef std::function<PyObject*(PyObject*, const std::vector<int64_t>&)>
128 PyBackwardFunction;
129
130 using AttrToInputsMap =
131 tensorflow::gtl::FlatMap<string,
132 tensorflow::gtl::InlinedVector<InputInfo, 4>>;
133
tensorflow::gtl::FlatMap<string, AttrToInputsMap*>* GetAllAttrToInputsMaps() {
135 static auto* all_attr_to_input_maps =
136 new tensorflow::gtl::FlatMap<string, AttrToInputsMap*>;
137 return all_attr_to_input_maps;
138 }
139
140 // This function doesn't use a lock, since we depend on the GIL directly.
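// For example, for an OpDef whose inputs "a" and "b" both use the type attr
// "T" (as MatMul does), the returned map is {"T": [InputInfo(0, false),
// InputInfo(1, false)]}.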
AttrToInputsMap* GetAttrToInputsMapHoldingGIL(const tensorflow::OpDef& op_def) {
142 #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4
143 DCHECK(PyGILState_Check())
144 << "This function needs to hold the GIL when called.";
145 #endif
146 auto* all_attr_to_input_maps = GetAllAttrToInputsMaps();
147 auto* output =
148 tensorflow::gtl::FindPtrOrNull(*all_attr_to_input_maps, op_def.name());
149 if (output != nullptr) {
150 return output;
151 }
152
153 std::unique_ptr<AttrToInputsMap> m(new AttrToInputsMap);
154
155 // Store a list of InputIndex -> List of corresponding inputs.
156 for (int i = 0; i < op_def.input_arg_size(); i++) {
157 if (!op_def.input_arg(i).type_attr().empty()) {
158 auto it = m->find(op_def.input_arg(i).type_attr());
159 if (it == m->end()) {
160 it = m->insert({op_def.input_arg(i).type_attr(), {}}).first;
161 }
162 it->second.emplace_back(i, !op_def.input_arg(i).number_attr().empty());
163 }
164 }
165
166 auto* retval = m.get();
167 (*all_attr_to_input_maps)[op_def.name()] = m.release();
168
169 return retval;
170 }
171
172 // This function doesn't use a lock, since we depend on the GIL directly.
173 tensorflow::gtl::FlatMap<
174 string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>*
GetAllAttrToDefaultsMaps() {
176 static auto* all_attr_to_defaults_maps = new tensorflow::gtl::FlatMap<
177 string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>;
178 return all_attr_to_defaults_maps;
179 }
180
181 tensorflow::gtl::FlatMap<string, tensorflow::DataType>*
GetAttrToDefaultsMapHoldingGIL(const tensorflow::OpDef& op_def) {
183 #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4
184 DCHECK(PyGILState_Check())
185 << "This function needs to hold the GIL when called.";
186 #endif
187 auto* all_attr_to_defaults_maps = GetAllAttrToDefaultsMaps();
188 auto* output =
189 tensorflow::gtl::FindPtrOrNull(*all_attr_to_defaults_maps, op_def.name());
190 if (output != nullptr) {
191 return output;
192 }
193
194 auto* new_map = new tensorflow::gtl::FlatMap<string, tensorflow::DataType>;
195
196 for (const auto& attr : op_def.attr()) {
197 if (attr.type() == "type" && attr.has_default_value()) {
198 new_map->insert({attr.name(), attr.default_value().type()});
199 }
200 }
201
202 (*all_attr_to_defaults_maps)[op_def.name()] = new_map;
203
204 return new_map;
205 }
206
207 struct FastPathOpExecInfo {
208 TFE_Context* ctx;
209 const char* device_name;
210
211 bool run_callbacks;
212 bool run_post_exec_callbacks;
213 bool run_gradient_callback;
214
215 // The op name of the main op being executed.
216 PyObject* name;
217 // The op type name of the main op being executed.
218 PyObject* op_name;
219 PyObject* callbacks;
220
221 // All the args passed into the FastPathOpExecInfo.
222 PyObject* args;
223
224 // DTypes can come from another input that has the same attr. So build that
225 // map.
226 const AttrToInputsMap* attr_to_inputs_map;
227 const tensorflow::gtl::FlatMap<string, tensorflow::DataType>* default_dtypes;
228 tensorflow::gtl::FlatMap<string, tensorflow::DataType> cached_dtypes;
229 };
230
231 #define PARSE_VALUE(fn_name, type, check_fn, parse_fn) \
232 bool fn_name(const string& key, PyObject* py_value, TF_Status* status, \
233 type* value) { \
234 if (check_fn(py_value)) { \
235 *value = static_cast<type>(parse_fn(py_value)); \
236 return true; \
237 } else { \
238 TF_SetStatus(status, TF_INVALID_ARGUMENT, \
239 tensorflow::strings::StrCat( \
240 "Expecting " #type " value for attr ", key, ", got ", \
241 py_value->ob_type->tp_name) \
242 .c_str()); \
243 return false; \
244 } \
245 }
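// For instance, PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
// below generates a ParseIntValue() that writes the parsed int and returns
// true on success, and sets an InvalidArgument status and returns false when
// py_value is not a Python integer.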
246
247 #if PY_MAJOR_VERSION >= 3
PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
249 PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLongLong)
250 #else
251 PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
252 #endif
253 PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)
254 #undef PARSE_VALUE
255
256 #if PY_MAJOR_VERSION < 3
257 bool ParseInt64Value(const string& key, PyObject* py_value, TF_Status* status,
258 int64_t* value) {
259 if (PyInt_Check(py_value)) {
260 *value = static_cast<int64_t>(PyInt_AsLong(py_value));
261 return true;
262 } else if (PyLong_Check(py_value)) {
263 *value = static_cast<int64_t>(PyLong_AsLong(py_value));
264 return true;
265 }
266 TF_SetStatus(
267 status, TF_INVALID_ARGUMENT,
268 tensorflow::strings::StrCat("Expecting int or long value for attr ", key,
269 ", got ", py_value->ob_type->tp_name)
270 .c_str());
271 return false;
272 }
273 #endif
274
Py_ssize_t TensorShapeNumDims(PyObject* value) {
276 const auto size = PySequence_Size(value);
277 if (size == -1) {
    // TensorShape.__len__ raises an error when the shape is unknown; that
    // error needs to be cleared here.
280 // TODO(nareshmodi): ensure that this is actually a TensorShape.
281 PyErr_Clear();
282 }
283 return size;
284 }
285
bool IsInteger(PyObject* py_value) {
287 #if PY_MAJOR_VERSION >= 3
288 return PyLong_Check(py_value);
289 #else
290 return PyInt_Check(py_value) || PyLong_Check(py_value);
291 #endif
292 }
293
294 // This function considers a Dimension._value of None to be valid, and sets the
295 // value to be -1 in that case.
bool ParseDimensionValue(const string& key, PyObject* py_value,
297 TF_Status* status, int64_t* value) {
298 if (IsInteger(py_value)) {
299 return ParseInt64Value(key, py_value, status, value);
300 }
301
302 tensorflow::Safe_PyObjectPtr dimension_value(
303 PyObject_GetAttrString(py_value, "_value"));
304 if (dimension_value == nullptr) {
305 PyErr_Clear();
306 TF_SetStatus(
307 status, TF_INVALID_ARGUMENT,
308 tensorflow::strings::StrCat("Expecting a Dimension for attr ", key,
309 ", got ", py_value->ob_type->tp_name)
310 .c_str());
311 return false;
312 }
313
314 if (dimension_value.get() == Py_None) {
315 *value = -1;
316 return true;
317 }
318
319 return ParseInt64Value(key, dimension_value.get(), status, value);
320 }
321
bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status,
323 tensorflow::StringPiece* value) {
324 if (PyBytes_Check(py_value)) {
325 Py_ssize_t size = 0;
326 char* buf = nullptr;
327 if (PyBytes_AsStringAndSize(py_value, &buf, &size) < 0) return false;
328 *value = tensorflow::StringPiece(buf, size);
329 return true;
330 }
331 #if PY_MAJOR_VERSION >= 3
332 if (PyUnicode_Check(py_value)) {
333 Py_ssize_t size = 0;
334 const char* buf = PyUnicode_AsUTF8AndSize(py_value, &size);
335 if (buf == nullptr) return false;
336 *value = tensorflow::StringPiece(buf, size);
337 return true;
338 }
339 #endif
340 TF_SetStatus(
341 status, TF_INVALID_ARGUMENT,
342 tensorflow::strings::StrCat("Expecting a string value for attr ", key,
343 ", got ", py_value->ob_type->tp_name)
344 .c_str());
345 return false;
346 }
347
bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status,
349 unsigned char* value) {
350 if (PyBool_Check(py_value)) {
351 *value = PyObject_IsTrue(py_value);
352 return true;
353 }
354 TF_SetStatus(
355 status, TF_INVALID_ARGUMENT,
356 tensorflow::strings::StrCat("Expecting bool value for attr ", key,
357 ", got ", py_value->ob_type->tp_name)
358 .c_str());
359 return false;
360 }
361
// The passed-in py_value is expected to be an object of the Python type
// dtypes.DType or an int.
bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status,
365 int* value) {
366 if (IsInteger(py_value)) {
367 return ParseIntValue(key, py_value, status, value);
368 }
369
370 tensorflow::Safe_PyObjectPtr py_type_enum(
371 PyObject_GetAttrString(py_value, "_type_enum"));
372 if (py_type_enum == nullptr) {
373 PyErr_Clear();
374 TF_SetStatus(
375 status, TF_INVALID_ARGUMENT,
376 tensorflow::strings::StrCat("Expecting a DType.dtype for attr ", key,
377 ", got ", py_value->ob_type->tp_name)
378 .c_str());
379 return false;
380 }
381
382 return ParseIntValue(key, py_type_enum.get(), status, value);
383 }
384
bool SetOpAttrList(TFE_Context* ctx, TFE_Op* op, const char* key,
386 PyObject* py_list, TF_AttrType type,
387 tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
388 TF_Status* status) {
389 if (!PySequence_Check(py_list)) {
390 TF_SetStatus(
391 status, TF_INVALID_ARGUMENT,
392 tensorflow::strings::StrCat("Expecting sequence value for attr ", key,
393 ", got ", py_list->ob_type->tp_name)
394 .c_str());
395 return false;
396 }
397 const int num_values = PySequence_Size(py_list);
398 if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values;
399
400 #define PARSE_LIST(c_type, parse_fn) \
401 std::unique_ptr<c_type[]> values(new c_type[num_values]); \
402 for (int i = 0; i < num_values; ++i) { \
403 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)); \
404 if (!parse_fn(key, py_value.get(), status, &values[i])) return false; \
405 }
406
407 if (type == TF_ATTR_STRING) {
408 std::unique_ptr<const void*[]> values(new const void*[num_values]);
409 std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
410 for (int i = 0; i < num_values; ++i) {
411 tensorflow::StringPiece value;
412 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
413 if (!ParseStringValue(key, py_value.get(), status, &value)) return false;
414 values[i] = value.data();
415 lengths[i] = value.size();
416 }
417 TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
418 } else if (type == TF_ATTR_INT) {
419 PARSE_LIST(int64_t, ParseInt64Value);
420 TFE_OpSetAttrIntList(op, key, values.get(), num_values);
421 } else if (type == TF_ATTR_FLOAT) {
422 PARSE_LIST(float, ParseFloatValue);
423 TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
424 } else if (type == TF_ATTR_BOOL) {
425 PARSE_LIST(unsigned char, ParseBoolValue);
426 TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
427 } else if (type == TF_ATTR_TYPE) {
428 PARSE_LIST(int, ParseTypeValue);
429 TFE_OpSetAttrTypeList(op, key,
430 reinterpret_cast<const TF_DataType*>(values.get()),
431 num_values);
432 } else if (type == TF_ATTR_SHAPE) {
433 // Make one pass through the input counting the total number of
434 // dims across all the input lists.
435 int total_dims = 0;
436 for (int i = 0; i < num_values; ++i) {
437 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
438 if (py_value.get() != Py_None) {
439 if (!PySequence_Check(py_value.get())) {
440 TF_SetStatus(
441 status, TF_INVALID_ARGUMENT,
442 tensorflow::strings::StrCat(
                  "Expecting None or sequence value for element ", i,
444 " of attr ", key, ", got ", py_value->ob_type->tp_name)
445 .c_str());
446 return false;
447 }
448 const auto size = TensorShapeNumDims(py_value.get());
449 if (size >= 0) {
450 total_dims += size;
451 }
452 }
453 }
454 // Allocate a buffer that can fit all of the dims together.
455 std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
456 // Copy the input dims into the buffer and set dims to point to
457 // the start of each list's dims.
458 std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
459 std::unique_ptr<int[]> num_dims(new int[num_values]);
460 int64_t* offset = buffer.get();
461 for (int i = 0; i < num_values; ++i) {
462 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
463 if (py_value.get() == Py_None) {
464 dims[i] = nullptr;
465 num_dims[i] = -1;
466 } else {
467 const auto size = TensorShapeNumDims(py_value.get());
468 if (size == -1) {
469 dims[i] = nullptr;
470 num_dims[i] = -1;
471 continue;
472 }
473 dims[i] = offset;
474 num_dims[i] = size;
475 for (int j = 0; j < size; ++j) {
476 tensorflow::Safe_PyObjectPtr inner_py_value(
477 PySequence_ITEM(py_value.get(), j));
478 if (inner_py_value.get() == Py_None) {
479 *offset = -1;
480 } else if (!ParseDimensionValue(key, inner_py_value.get(), status,
481 offset)) {
482 return false;
483 }
484 ++offset;
485 }
486 }
487 }
488 TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
489 status);
490 if (!status->status.ok()) return false;
491 } else if (type == TF_ATTR_FUNC) {
492 std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
493 for (int i = 0; i < num_values; ++i) {
494 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
495 // Allow:
496 // (1) String function name, OR
497 // (2) A Python object with a .name attribute
498 // (A crude test for being a
499 // tensorflow.python.framework.function._DefinedFunction)
500 // (which is what the various "defun" or "Defun" decorators do).
501 // And in the future also allow an object that can encapsulate
502 // the function name and its attribute values.
503 tensorflow::StringPiece func_name;
504 if (!ParseStringValue(key, py_value.get(), status, &func_name)) {
505 PyObject* name_attr = PyObject_GetAttrString(py_value.get(), "name");
506 if (name_attr == nullptr ||
507 !ParseStringValue(key, name_attr, status, &func_name)) {
508 TF_SetStatus(
509 status, TF_INVALID_ARGUMENT,
510 tensorflow::strings::StrCat(
511 "unable to set function value attribute from a ",
512 py_value.get()->ob_type->tp_name,
513 " object. If you think this is an error, please file an "
514 "issue at "
515 "https://github.com/tensorflow/tensorflow/issues/new")
516 .c_str());
517 return false;
518 }
519 }
520 funcs[i] = TFE_NewOp(ctx, func_name.data(), status);
521 if (!status->status.ok()) return false;
522 }
523 TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
524 if (!status->status.ok()) return false;
525 } else {
526 TF_SetStatus(status, TF_UNIMPLEMENTED,
527 tensorflow::strings::StrCat("Attr ", key,
528 " has unhandled list type ", type)
529 .c_str());
530 return false;
531 }
532 #undef PARSE_LIST
533 return true;
534 }
535
TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
537 TF_Status* status) {
538 TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
539 for (const auto& attr : func.attr()) {
540 if (!status->status.ok()) return nullptr;
541 SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
542 if (!status->status.ok()) return nullptr;
543 }
544 return func_op;
545 }
546
void SetOpAttrListDefault(
548 TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
549 const char* key, TF_AttrType type,
550 tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
551 TF_Status* status) {
552 if (type == TF_ATTR_STRING) {
553 int num_values = attr.default_value().list().s_size();
554 std::unique_ptr<const void*[]> values(new const void*[num_values]);
555 std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
556 (*attr_list_sizes)[key] = num_values;
557 for (int i = 0; i < num_values; i++) {
558 const string& v = attr.default_value().list().s(i);
559 values[i] = v.data();
560 lengths[i] = v.size();
561 }
562 TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
563 } else if (type == TF_ATTR_INT) {
564 int num_values = attr.default_value().list().i_size();
565 std::unique_ptr<int64_t[]> values(new int64_t[num_values]);
566 (*attr_list_sizes)[key] = num_values;
567 for (int i = 0; i < num_values; i++) {
568 values[i] = attr.default_value().list().i(i);
569 }
570 TFE_OpSetAttrIntList(op, key, values.get(), num_values);
571 } else if (type == TF_ATTR_FLOAT) {
572 int num_values = attr.default_value().list().f_size();
573 std::unique_ptr<float[]> values(new float[num_values]);
574 (*attr_list_sizes)[key] = num_values;
575 for (int i = 0; i < num_values; i++) {
576 values[i] = attr.default_value().list().f(i);
577 }
578 TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
579 } else if (type == TF_ATTR_BOOL) {
580 int num_values = attr.default_value().list().b_size();
581 std::unique_ptr<unsigned char[]> values(new unsigned char[num_values]);
582 (*attr_list_sizes)[key] = num_values;
583 for (int i = 0; i < num_values; i++) {
584 values[i] = attr.default_value().list().b(i);
585 }
586 TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
587 } else if (type == TF_ATTR_TYPE) {
588 int num_values = attr.default_value().list().type_size();
589 std::unique_ptr<int[]> values(new int[num_values]);
590 (*attr_list_sizes)[key] = num_values;
591 for (int i = 0; i < num_values; i++) {
592 values[i] = attr.default_value().list().type(i);
593 }
594 TFE_OpSetAttrTypeList(op, key,
595 reinterpret_cast<const TF_DataType*>(values.get()),
596 attr.default_value().list().type_size());
597 } else if (type == TF_ATTR_SHAPE) {
598 int num_values = attr.default_value().list().shape_size();
599 (*attr_list_sizes)[key] = num_values;
600 int total_dims = 0;
601 for (int i = 0; i < num_values; ++i) {
602 if (!attr.default_value().list().shape(i).unknown_rank()) {
603 total_dims += attr.default_value().list().shape(i).dim_size();
604 }
605 }
606 // Allocate a buffer that can fit all of the dims together.
607 std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
608 // Copy the input dims into the buffer and set dims to point to
609 // the start of each list's dims.
610 std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
611 std::unique_ptr<int[]> num_dims(new int[num_values]);
612 int64_t* offset = buffer.get();
613 for (int i = 0; i < num_values; ++i) {
614 const auto& shape = attr.default_value().list().shape(i);
615 if (shape.unknown_rank()) {
616 dims[i] = nullptr;
617 num_dims[i] = -1;
618 } else {
619 for (int j = 0; j < shape.dim_size(); j++) {
620 *offset = shape.dim(j).size();
621 ++offset;
622 }
623 }
624 }
625 TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
626 status);
627 } else if (type == TF_ATTR_FUNC) {
628 int num_values = attr.default_value().list().func_size();
629 (*attr_list_sizes)[key] = num_values;
630 std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
631 for (int i = 0; i < num_values; i++) {
632 funcs[i] = GetFunc(ctx, attr.default_value().list().func(i), status);
633 }
634 TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
635 } else {
636 TF_SetStatus(status, TF_UNIMPLEMENTED,
637 "Lists of tensors are not yet implemented for default valued "
638 "attributes for an operation.");
639 }
640 }
641
bool SetOpAttrScalar(TFE_Context* ctx, TFE_Op* op, const char* key,
643 PyObject* py_value, TF_AttrType type,
644 tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
645 TF_Status* status) {
646 if (type == TF_ATTR_STRING) {
647 tensorflow::StringPiece value;
648 if (!ParseStringValue(key, py_value, status, &value)) return false;
649 TFE_OpSetAttrString(op, key, value.data(), value.size());
650 } else if (type == TF_ATTR_INT) {
651 int64_t value;
652 if (!ParseInt64Value(key, py_value, status, &value)) return false;
653 TFE_OpSetAttrInt(op, key, value);
    // attr_list_sizes is set for all int attributes (since at this point we do
    // not yet know whether that attribute might be used to calculate the size
    // of an output list).
657 if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = value;
658 } else if (type == TF_ATTR_FLOAT) {
659 float value;
660 if (!ParseFloatValue(key, py_value, status, &value)) return false;
661 TFE_OpSetAttrFloat(op, key, value);
662 } else if (type == TF_ATTR_BOOL) {
663 unsigned char value;
664 if (!ParseBoolValue(key, py_value, status, &value)) return false;
665 TFE_OpSetAttrBool(op, key, value);
666 } else if (type == TF_ATTR_TYPE) {
667 int value;
668 if (!ParseTypeValue(key, py_value, status, &value)) return false;
669 TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value));
670 } else if (type == TF_ATTR_SHAPE) {
671 if (py_value == Py_None) {
672 TFE_OpSetAttrShape(op, key, nullptr, -1, status);
673 } else {
674 if (!PySequence_Check(py_value)) {
675 TF_SetStatus(status, TF_INVALID_ARGUMENT,
676 tensorflow::strings::StrCat(
                       "Expecting None or sequence value for attr ", key,
678 ", got ", py_value->ob_type->tp_name)
679 .c_str());
680 return false;
681 }
682 const auto num_dims = TensorShapeNumDims(py_value);
683 if (num_dims == -1) {
684 TFE_OpSetAttrShape(op, key, nullptr, -1, status);
685 return true;
686 }
687 std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
688 for (int i = 0; i < num_dims; ++i) {
689 tensorflow::Safe_PyObjectPtr inner_py_value(
690 PySequence_ITEM(py_value, i));
691 // If an error is generated when iterating through object, we can
692 // sometimes get a nullptr.
693 if (inner_py_value.get() == Py_None) {
694 dims[i] = -1;
695 } else if (inner_py_value.get() == nullptr ||
696 !ParseDimensionValue(key, inner_py_value.get(), status,
697 &dims[i])) {
698 return false;
699 }
700 }
701 TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status);
702 }
703 if (!status->status.ok()) return false;
704 } else if (type == TF_ATTR_FUNC) {
705 // Allow:
706 // (1) String function name, OR
707 // (2) A Python object with a .name attribute
708 // (A crude test for being a
709 // tensorflow.python.framework.function._DefinedFunction)
710 // (which is what the various "defun" or "Defun" decorators do).
711 // And in the future also allow an object that can encapsulate
712 // the function name and its attribute values.
713 tensorflow::StringPiece func_name;
714 if (!ParseStringValue(key, py_value, status, &func_name)) {
715 PyObject* name_attr = PyObject_GetAttrString(py_value, "name");
716 if (name_attr == nullptr ||
717 !ParseStringValue(key, name_attr, status, &func_name)) {
718 TF_SetStatus(
719 status, TF_INVALID_ARGUMENT,
720 tensorflow::strings::StrCat(
721 "unable to set function value attribute from a ",
722 py_value->ob_type->tp_name,
723 " object. If you think this is an error, please file an issue "
724 "at https://github.com/tensorflow/tensorflow/issues/new")
725 .c_str());
726 return false;
727 }
728 }
729 TF_SetStatus(status, TF_OK, "");
730 TFE_OpSetAttrFunctionName(op, key, func_name.data(), func_name.size());
731 } else {
732 TF_SetStatus(
733 status, TF_UNIMPLEMENTED,
734 tensorflow::strings::StrCat("Attr ", key, " has unhandled type ", type)
735 .c_str());
736 return false;
737 }
738 return true;
739 }
740
void SetOpAttrScalarDefault(
742 TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value,
743 const char* attr_name,
744 tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
745 TF_Status* status) {
746 SetOpAttrValueScalar(ctx, op, default_value, attr_name, status);
747 if (default_value.value_case() == tensorflow::AttrValue::kI) {
748 (*attr_list_sizes)[attr_name] = default_value.i();
749 }
750 }
751
752 // start_index is the index at which the Tuple/List attrs will start getting
753 // processed.
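// For example (illustrative only), an attrs tuple of
//   ("transpose_a", False, "transpose_b", False)
// with start_index 0 sets two scalar bool attrs on the op.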
void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index,
755 TF_Status* out_status) {
756 if (attrs == Py_None) return;
757 Py_ssize_t len = PyTuple_GET_SIZE(attrs) - start_index;
758 if ((len & 1) != 0) {
759 TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
760 "Expecting attrs tuple to have even length.");
761 return;
762 }
763 // Parse attrs
764 for (Py_ssize_t i = 0; i < len; i += 2) {
765 PyObject* py_key = PyTuple_GET_ITEM(attrs, start_index + i);
766 PyObject* py_value = PyTuple_GET_ITEM(attrs, start_index + i + 1);
767 #if PY_MAJOR_VERSION >= 3
768 const char* key = PyBytes_Check(py_key) ? PyBytes_AsString(py_key)
769 : PyUnicode_AsUTF8(py_key);
770 #else
771 const char* key = PyBytes_AsString(py_key);
772 #endif
773 unsigned char is_list = 0;
774 const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status);
775 if (!out_status->status.ok()) return;
776 if (is_list != 0) {
777 if (!SetOpAttrList(ctx, op, key, py_value, type, nullptr, out_status))
778 return;
779 } else {
780 if (!SetOpAttrScalar(ctx, op, key, py_value, type, nullptr, out_status))
781 return;
782 }
783 }
784 }
785
786 // This function will set the op attrs required. If an attr has the value of
787 // None, then it will read the AttrDef to get the default value and set that
788 // instead. Any failure in this function will simply fall back to the slow
789 // path.
void SetOpAttrWithDefaults(
791 TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
792 const char* attr_name, PyObject* attr_value,
793 tensorflow::gtl::FlatMap<string, int64_t>* attr_list_sizes,
794 TF_Status* status) {
795 unsigned char is_list = 0;
796 const TF_AttrType type = TFE_OpGetAttrType(op, attr_name, &is_list, status);
797 if (!status->status.ok()) return;
798 if (attr_value == Py_None) {
799 if (is_list != 0) {
800 SetOpAttrListDefault(ctx, op, attr, attr_name, type, attr_list_sizes,
801 status);
802 } else {
803 SetOpAttrScalarDefault(ctx, op, attr.default_value(), attr_name,
804 attr_list_sizes, status);
805 }
806 } else {
807 if (is_list != 0) {
808 SetOpAttrList(ctx, op, attr_name, attr_value, type, attr_list_sizes,
809 status);
810 } else {
811 SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes,
812 status);
813 }
814 }
815 }
816
PyObject* GetPythonObjectFromInt(int num) {
818 #if PY_MAJOR_VERSION >= 3
819 return PyLong_FromLong(num);
820 #else
821 return PyInt_FromLong(num);
822 #endif
823 }
824
// Python subclass of Exception that is raised when a Status is not OK.
826 tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
827 PyObject* exception_class TF_GUARDED_BY(exception_class_mutex) = nullptr;
828
829 // Python subclass of Exception that is created to signal fallback.
830 PyObject* fallback_exception_class = nullptr;
831
832 // Python function that returns input gradients given output gradients.
833 PyObject* gradient_function = nullptr;
834
835 // Python function that returns output gradients given input gradients.
836 PyObject* forward_gradient_function = nullptr;
837
838 static std::atomic<int64_t> _uid;
839
840 // This struct is responsible for marking thread_local storage as destroyed.
// Access to the `alive` field of an already-destroyed ThreadLocalDestructionMarker
// is safe because it is a trivial type, as long as nobody creates a new
// thread_local in the space where the now-destroyed marker used to be.
// Hopefully creating new thread_locals while a thread is being destroyed is rare.
845 struct ThreadLocalDestructionMarker {
  ~ThreadLocalDestructionMarker() { alive = false; }
847 bool alive = true;
848 };
849
850 } // namespace
851
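// GetStatus()/ReturnStatus() mirror GetOp()/ReturnOp(): they recycle a
// thread-local TF_Status so the fast path avoids allocating a fresh status
// object on every call.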
TF_Status* GetStatus() {
853 TF_Status* maybe_status = ReleaseThreadLocalStatus();
854 if (maybe_status) {
855 TF_SetStatus(maybe_status, TF_OK, "");
856 return maybe_status;
857 } else {
858 return TF_NewStatus();
859 }
860 }
861
void ReturnStatus(TF_Status* status) {
863 TF_SetStatus(status, TF_OK, "");
864 thread_local_tf_status.reset(status);
865 }
866
void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
868 const char* op_name, TFE_InputTensorHandles* inputs,
869 PyObject* attrs, TFE_OutputTensorHandles* outputs,
870 TF_Status* out_status) {
871 TFE_Py_ExecuteCancelable(ctx, device_name, op_name, inputs, attrs,
872 /*cancellation_manager=*/nullptr, outputs,
873 out_status);
874 }
875
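// Executes a single op: fetches a cached thread-local TFE_Op, attaches the
// Python stack trace, adds inputs and attrs, and runs TFE_Execute with the GIL
// released. On failure, " [Op:<op_name>]" is appended to the status message.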
void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name,
877 const char* op_name,
878 TFE_InputTensorHandles* inputs, PyObject* attrs,
879 TFE_CancellationManager* cancellation_manager,
880 TFE_OutputTensorHandles* outputs,
881 TF_Status* out_status) {
882 tensorflow::profiler::TraceMe activity(
883 "TFE_Py_ExecuteCancelable", tensorflow::profiler::TraceMeLevel::kInfo);
884
885 TFE_Op* op = GetOp(ctx, op_name, device_name, out_status);
886
887 auto cleaner = tensorflow::gtl::MakeCleanup([ctx, op] { ReturnOp(ctx, op); });
888 if (!out_status->status.ok()) return;
889
890 tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace(
891 tensorflow::StackTrace::kStackTraceInitialSize));
892
893 for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) {
894 TFE_OpAddInput(op, inputs->at(i), out_status);
895 }
896 if (cancellation_manager && out_status->status.ok()) {
897 TFE_OpSetCancellationManager(op, cancellation_manager, out_status);
898 }
899 if (out_status->status.ok()) {
900 SetOpAttrs(ctx, op, attrs, 0, out_status);
901 }
902 Py_BEGIN_ALLOW_THREADS;
903
904 int num_outputs = outputs->size();
905
906 if (out_status->status.ok()) {
907 TFE_Execute(op, outputs->data(), &num_outputs, out_status);
908 }
909
910 if (out_status->status.ok()) {
911 outputs->resize(num_outputs);
912 } else {
913 TF_SetStatus(out_status, TF_GetCode(out_status),
914 tensorflow::strings::StrCat(TF_Message(out_status),
915 " [Op:", op_name, "]")
916 .c_str());
917 }
918
919 Py_END_ALLOW_THREADS;
920 }
921
PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
923 tensorflow::mutex_lock l(exception_class_mutex);
924 if (exception_class != nullptr) {
925 Py_DECREF(exception_class);
926 }
927 if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
928 exception_class = nullptr;
929 PyErr_SetString(PyExc_TypeError,
930 "TFE_Py_RegisterExceptionClass: "
931 "Registered class should be subclass of Exception.");
932 return nullptr;
933 }
934
935 Py_INCREF(e);
936 exception_class = e;
937 Py_RETURN_NONE;
938 }
939
PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) {
941 if (fallback_exception_class != nullptr) {
942 Py_DECREF(fallback_exception_class);
943 }
944 if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
945 fallback_exception_class = nullptr;
946 PyErr_SetString(PyExc_TypeError,
947 "TFE_Py_RegisterFallbackExceptionClass: "
948 "Registered class should be subclass of Exception.");
949 return nullptr;
950 } else {
951 Py_INCREF(e);
952 fallback_exception_class = e;
953 Py_RETURN_NONE;
954 }
955 }
956
PyObject* TFE_Py_RegisterGradientFunction(PyObject* e) {
958 if (gradient_function != nullptr) {
959 Py_DECREF(gradient_function);
960 }
961 if (!PyCallable_Check(e)) {
962 gradient_function = nullptr;
963 PyErr_SetString(PyExc_TypeError,
964 "TFE_Py_RegisterGradientFunction: "
965 "Registered object should be function.");
966 return nullptr;
967 } else {
968 Py_INCREF(e);
969 gradient_function = e;
970 Py_RETURN_NONE;
971 }
972 }
973
PyObject* TFE_Py_RegisterJVPFunction(PyObject* e) {
975 if (forward_gradient_function != nullptr) {
976 Py_DECREF(forward_gradient_function);
977 }
978 if (!PyCallable_Check(e)) {
979 forward_gradient_function = nullptr;
980 PyErr_SetString(PyExc_TypeError,
981 "TFE_Py_RegisterJVPFunction: "
982 "Registered object should be function.");
983 return nullptr;
984 } else {
985 Py_INCREF(e);
986 forward_gradient_function = e;
987 Py_RETURN_NONE;
988 }
989 }
990
void RaiseFallbackException(const char* message) {
992 if (fallback_exception_class != nullptr) {
993 PyErr_SetString(fallback_exception_class, message);
994 return;
995 }
996
997 PyErr_SetString(
998 PyExc_RuntimeError,
999 tensorflow::strings::StrCat(
1000 "Fallback exception type not set, attempting to fallback due to ",
1001 message)
1002 .data());
1003 }
1004
1005 // Format and return `status`' error message with the attached stack trace if
1006 // available. `status` must have an error.
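// The returned message looks roughly like (illustrative):
//   Exception originated from
//
//     File "<file>", line <line>, in <function>
//       <source line>
//
//   <original status error message>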
std::string FormatErrorStatusStackTrace(const tensorflow::Status& status) {
1008 tensorflow::DCheckPyGilState();
1009 DCHECK(!status.ok());
1010
1011 std::vector<tensorflow::StackFrame> stack_trace =
1012 tensorflow::errors::GetStackTrace(status);
1013
1014 if (stack_trace.empty()) return status.error_message();
1015
1016 PyObject* linecache = PyImport_ImportModule("linecache");
1017 PyObject* getline =
1018 PyObject_GetAttr(linecache, PyUnicode_FromString("getline"));
1019 DCHECK(getline);
1020
1021 std::ostringstream result;
1022 result << "Exception originated from\n\n";
1023
1024 for (const tensorflow::StackFrame& stack_frame : stack_trace) {
1025 PyObject* line_str_obj = PyObject_CallFunction(
1026 getline, const_cast<char*>("si"), stack_frame.file_name.c_str(),
1027 stack_frame.line_number);
1028 tensorflow::StringPiece line_str = TFE_GetPythonString(line_str_obj);
1029 tensorflow::str_util::RemoveWhitespaceContext(&line_str);
1030 result << " File \"" << stack_frame.file_name << "\", line "
1031 << stack_frame.line_number << ", in " << stack_frame.function_name
1032 << '\n';
1033
1034 if (!line_str.empty()) result << " " << line_str << '\n';
1035 Py_XDECREF(line_str_obj);
1036 }
1037
1038 Py_DecRef(getline);
1039 Py_DecRef(linecache);
1040
1041 result << '\n' << status.error_message();
1042 return result.str();
1043 }
1044
1045 namespace tensorflow {
1046
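// Converts a non-OK `status` into a Python exception. If an exception class
// was registered via TFE_Py_RegisterExceptionClass, it is raised with
// (message-with-stack-trace, error code, payloads); otherwise `exception` (or
// RuntimeError if none was given) is raised with the plain message. Returns -1
// when an exception was set, 0 when `status` was OK.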
int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
1048 if (status->status.ok()) return 0;
1049 const char* msg = TF_Message(status);
1050 if (exception == nullptr) {
1051 tensorflow::mutex_lock l(exception_class_mutex);
1052 if (exception_class != nullptr) {
1053 tensorflow::Safe_PyObjectPtr payloads(PyDict_New());
1054 for (const auto& payload :
1055 tensorflow::errors::GetPayloads(status->status)) {
1056 PyDict_SetItem(payloads.get(),
1057 PyBytes_FromString(payload.first.c_str()),
1058 PyBytes_FromString(payload.second.c_str()));
1059 }
1060 tensorflow::Safe_PyObjectPtr val(Py_BuildValue(
1061 "siO", FormatErrorStatusStackTrace(status->status).c_str(),
1062 TF_GetCode(status), payloads.get()));
1063 if (PyErr_Occurred()) {
1064 // NOTE: This hides the actual error (i.e. the reason `status` was not
1065 // TF_OK), but there is nothing we can do at this point since we can't
1066 // generate a reasonable error from the status.
1067 // Consider adding a message explaining this.
1068 return -1;
1069 }
1070 PyErr_SetObject(exception_class, val.get());
1071 return -1;
1072 } else {
1073 exception = PyExc_RuntimeError;
1074 }
1075 }
  // Maybe update an already-set exception.
1077 PyErr_SetString(exception, msg);
1078 return -1;
1079 }
1080
1081 } // namespace tensorflow
1082
int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
1084 PyObject* exception) {
1085 if (status.ok()) return 0;
1086 const char* msg = status.error_message().c_str();
1087 if (exception == nullptr) {
1088 tensorflow::mutex_lock l(exception_class_mutex);
1089 if (exception_class != nullptr) {
1090 tensorflow::Safe_PyObjectPtr payloads(PyDict_New());
1091 for (const auto& element : tensorflow::errors::GetPayloads(status)) {
1092 PyDict_SetItem(payloads.get(),
1093 PyBytes_FromString(element.first.c_str()),
1094 PyBytes_FromString(element.second.c_str()));
1095 }
1096 tensorflow::Safe_PyObjectPtr val(
1097 Py_BuildValue("siO", FormatErrorStatusStackTrace(status).c_str(),
1098 status.code(), payloads.get()));
1099 PyErr_SetObject(exception_class, val.get());
1100 return -1;
1101 } else {
1102 exception = PyExc_RuntimeError;
1103 }
1104 }
  // Maybe update an already-set exception.
1106 PyErr_SetString(exception, msg);
1107 return -1;
1108 }
1109
const char* TFE_GetPythonString(PyObject* o) {
1111 #if PY_MAJOR_VERSION >= 3
1112 if (PyBytes_Check(o)) {
1113 return PyBytes_AsString(o);
1114 } else {
1115 return PyUnicode_AsUTF8(o);
1116 }
1117 #else
1118 return PyBytes_AsString(o);
1119 #endif
1120 }
1121
int64_t get_uid() { return _uid++; }
1123
PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }
1125
void TFE_DeleteContextCapsule(PyObject* context) {
1127 TFE_Context* ctx =
1128 reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr));
1129 auto op = ReleaseThreadLocalOp(ctx);
1130 op.reset();
1131 TFE_DeleteContext(ctx);
1132 }
1133
static int64_t MakeInt(PyObject* integer) {
1135 #if PY_MAJOR_VERSION >= 3
1136 return PyLong_AsLong(integer);
1137 #else
1138 return PyInt_AsLong(integer);
1139 #endif
1140 }
1141
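// Returns a unique id for `tensor`: the EagerTensor id for eager tensors,
// otherwise the value of the Python "_id" attribute; -1 (with a Python error
// set) if that attribute is missing.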
static int64_t FastTensorId(PyObject* tensor) {
1143 if (EagerTensor_CheckExact(tensor)) {
1144 return PyEagerTensor_ID(tensor);
1145 }
1146 PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
1147 if (id_field == nullptr) {
1148 return -1;
1149 }
1150 int64_t id = MakeInt(id_field);
1151 Py_DECREF(id_field);
1152 return id;
1153 }
1154
1155 namespace tensorflow {
DataType PyTensor_DataType(PyObject* tensor) {
1157 if (EagerTensor_CheckExact(tensor)) {
1158 return PyEagerTensor_Dtype(tensor);
1159 } else {
1160 #if PY_MAJOR_VERSION < 3
1161 // Python 2.x:
1162 static PyObject* dtype_attr = PyString_InternFromString("dtype");
1163 static PyObject* type_enum_attr = PyString_InternFromString("_type_enum");
1164 #else
1165 // Python 3.x:
1166 static PyObject* dtype_attr = PyUnicode_InternFromString("dtype");
1167 static PyObject* type_enum_attr = PyUnicode_InternFromString("_type_enum");
1168 #endif
1169 Safe_PyObjectPtr dtype_field(PyObject_GetAttr(tensor, dtype_attr));
1170 if (!dtype_field) {
1171 return DT_INVALID;
1172 }
1173
1174 Safe_PyObjectPtr enum_field(
1175 PyObject_GetAttr(dtype_field.get(), type_enum_attr));
1176 if (!enum_field) {
1177 return DT_INVALID;
1178 }
1179
1180 return static_cast<DataType>(MakeInt(enum_field.get()));
1181 }
1182 }
1183 } // namespace tensorflow
1184
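// Lightweight record of a tensor watched by the tape: its id, dtype, and
// either a fully-known TensorShape or the tensor object itself (used for
// partially-known shapes and variant-dtype tensors), from which ones/zeros
// gradients can be constructed.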
1185 class PyTapeTensor {
1186 public:
  PyTapeTensor(int64_t id, tensorflow::DataType dtype,
1188 const tensorflow::TensorShape& shape)
1189 : id_(id), dtype_(dtype), shape_(shape) {}
  PyTapeTensor(int64_t id, tensorflow::DataType dtype, PyObject* shape)
1191 : id_(id), dtype_(dtype), shape_(shape) {
1192 Py_INCREF(absl::get<1>(shape_));
1193 }
  PyTapeTensor(const PyTapeTensor& other) {
1195 id_ = other.id_;
1196 dtype_ = other.dtype_;
1197 shape_ = other.shape_;
1198 if (shape_.index() == 1) {
1199 Py_INCREF(absl::get<1>(shape_));
1200 }
1201 }
1202
  ~PyTapeTensor() {
1204 if (shape_.index() == 1) {
1205 Py_DECREF(absl::get<1>(shape_));
1206 }
1207 }
1208 PyObject* GetShape() const;
  PyObject* GetPyDType() const { return PyLong_FromLong(dtype_); }
  int64_t GetID() const { return id_; }
  tensorflow::DataType GetDType() const { return dtype_; }
1212
1213 PyObject* OnesLike() const;
1214 PyObject* ZerosLike() const;
1215
1216 private:
1217 int64_t id_;
1218 tensorflow::DataType dtype_;
1219
1220 // Note that if shape_.index() == 1, meaning shape_ contains a PyObject, that
1221 // PyObject is the tensor itself. This is used to support tf.shape(tensor) for
1222 // partially-defined shapes and tf.zeros_like(tensor) for variant-dtype
1223 // tensors.
1224 absl::variant<tensorflow::TensorShape, PyObject*> shape_;
1225 };
1226
1227 static PyTapeTensor TapeTensorFromTensor(PyObject* tensor);
1228
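// Adapts the C++ tape's VSpace interface to the Python callbacks registered
// via TFE_Py_RegisterVSpace (num_elements_fn, aggregate_fn, zeros_fn, ones_fn,
// zeros_like_fn, ones_like_fn, graph_shape_fn), so the generic gradient
// machinery can create and combine Python-side gradient values.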
1229 class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
1230 PyTapeTensor> {
1231 public:
  explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
1233 Py_INCREF(py_vspace_);
1234 }
1235
  tensorflow::Status Initialize() {
1237 num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
1238 if (num_elements_ == nullptr) {
1239 return tensorflow::errors::InvalidArgument("invalid vspace");
1240 }
1241 aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
1242 if (aggregate_fn_ == nullptr) {
1243 return tensorflow::errors::InvalidArgument("invalid vspace");
1244 }
1245 zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn");
1246 if (zeros_fn_ == nullptr) {
1247 return tensorflow::errors::InvalidArgument("invalid vspace");
1248 }
1249 zeros_like_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_like_fn");
1250 if (zeros_like_fn_ == nullptr) {
1251 return tensorflow::errors::InvalidArgument("invalid vspace");
1252 }
1253 ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn");
1254 if (ones_fn_ == nullptr) {
1255 return tensorflow::errors::InvalidArgument("invalid vspace");
1256 }
1257 ones_like_fn_ = PyObject_GetAttrString(py_vspace_, "ones_like_fn");
1258 if (ones_like_fn_ == nullptr) {
1259 return tensorflow::errors::InvalidArgument("invalid vspace");
1260 }
1261 graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn");
1262 if (graph_shape_fn_ == nullptr) {
1263 return tensorflow::errors::InvalidArgument("invalid vspace");
1264 }
1265 return ::tensorflow::OkStatus();
1266 }
1267
  ~PyVSpace() override {
1269 Py_XDECREF(num_elements_);
1270 Py_XDECREF(aggregate_fn_);
1271 Py_XDECREF(zeros_fn_);
1272 Py_XDECREF(zeros_like_fn_);
1273 Py_XDECREF(ones_fn_);
1274 Py_XDECREF(ones_like_fn_);
1275 Py_XDECREF(graph_shape_fn_);
1276
1277 Py_DECREF(py_vspace_);
1278 }
1279
  int64_t NumElements(PyObject* tensor) const final {
1281 if (EagerTensor_CheckExact(tensor)) {
1282 return PyEagerTensor_NumElements(tensor);
1283 }
1284 PyObject* arglist =
1285 Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
1286 PyObject* result = PyEval_CallObject(num_elements_, arglist);
1287 Py_DECREF(arglist);
1288 if (result == nullptr) {
1289 // The caller detects whether a python exception has been raised.
1290 return -1;
1291 }
1292 int64_t r = MakeInt(result);
1293 Py_DECREF(result);
1294 return r;
1295 }
1296
  PyObject* AggregateGradients(
1298 tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
1299 PyObject* list = PyList_New(gradient_tensors.size());
1300 for (int i = 0; i < gradient_tensors.size(); ++i) {
1301 // Note: stealing a reference to the gradient tensors.
1302 CHECK(gradient_tensors[i] != nullptr);
1303 CHECK(gradient_tensors[i] != Py_None);
1304 PyList_SET_ITEM(list, i,
1305 reinterpret_cast<PyObject*>(gradient_tensors[i]));
1306 }
1307 PyObject* arglist = Py_BuildValue("(O)", list);
1308 CHECK(arglist != nullptr);
1309 PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
1310 Py_DECREF(arglist);
1311 Py_DECREF(list);
1312 return result;
1313 }
1314
  int64_t TensorId(PyObject* tensor) const final {
1316 return FastTensorId(tensor);
1317 }
1318
  void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
1320
  PyObject* Ones(PyObject* shape, PyObject* dtype) const {
1322 if (PyErr_Occurred()) {
1323 return nullptr;
1324 }
1325 PyObject* arg_list = Py_BuildValue("OO", shape, dtype);
1326 PyObject* result = PyEval_CallObject(ones_fn_, arg_list);
1327 Py_DECREF(arg_list);
1328 return result;
1329 }
1330
  PyObject* OnesLike(PyObject* tensor) const {
1332 if (PyErr_Occurred()) {
1333 return nullptr;
1334 }
1335 return PyObject_CallFunctionObjArgs(ones_like_fn_, tensor, NULL);
1336 }
1337
1338 // Builds a tensor filled with ones with the same shape and dtype as `t`.
  Status BuildOnesLike(const PyTapeTensor& t,
1340 PyObject** result) const override {
1341 *result = t.OnesLike();
1342 return ::tensorflow::OkStatus();
1343 }
1344
  PyObject* Zeros(PyObject* shape, PyObject* dtype) const {
1346 if (PyErr_Occurred()) {
1347 return nullptr;
1348 }
1349 PyObject* arg_list = Py_BuildValue("OO", shape, dtype);
1350 PyObject* result = PyEval_CallObject(zeros_fn_, arg_list);
1351 Py_DECREF(arg_list);
1352 return result;
1353 }
1354
  PyObject* ZerosLike(PyObject* tensor) const {
1356 if (PyErr_Occurred()) {
1357 return nullptr;
1358 }
1359 return PyObject_CallFunctionObjArgs(zeros_like_fn_, tensor, NULL);
1360 }
1361
  PyObject* GraphShape(PyObject* tensor) const {
1363 PyObject* arg_list = Py_BuildValue("(O)", tensor);
1364 PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list);
1365 Py_DECREF(arg_list);
1366 return result;
1367 }
1368
  tensorflow::Status CallBackwardFunction(
1370 const string& op_type, PyBackwardFunction* backward_function,
1371 const std::vector<int64_t>& unneeded_gradients,
1372 tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
1373 absl::Span<PyObject*> result) const final {
1374 PyObject* grads = PyTuple_New(output_gradients.size());
1375 for (int i = 0; i < output_gradients.size(); ++i) {
1376 if (output_gradients[i] == nullptr) {
1377 Py_INCREF(Py_None);
1378 PyTuple_SET_ITEM(grads, i, Py_None);
1379 } else {
1380 PyTuple_SET_ITEM(grads, i,
1381 reinterpret_cast<PyObject*>(output_gradients[i]));
1382 }
1383 }
1384 PyObject* py_result = (*backward_function)(grads, unneeded_gradients);
1385 Py_DECREF(grads);
1386 if (py_result == nullptr) {
1387 return tensorflow::errors::Internal("gradient function threw exceptions");
1388 }
1389 PyObject* seq =
1390 PySequence_Fast(py_result, "expected a sequence of gradients");
1391 if (seq == nullptr) {
1392 return tensorflow::errors::InvalidArgument(
1393 "gradient function did not return a list");
1394 }
1395 int len = PySequence_Fast_GET_SIZE(seq);
1396 if (len != result.size()) {
1397 return tensorflow::errors::Internal(
1398 "Recorded operation '", op_type,
1399 "' returned too few gradients. Expected ", result.size(),
1400 " but received ", len);
1401 }
1402 PyObject** seq_array = PySequence_Fast_ITEMS(seq);
1403 VLOG(1) << "Gradient length is " << len;
1404 for (int i = 0; i < len; ++i) {
1405 PyObject* item = seq_array[i];
1406 if (item == Py_None) {
1407 result[i] = nullptr;
1408 } else {
1409 Py_INCREF(item);
1410 result[i] = item;
1411 }
1412 }
1413 Py_DECREF(seq);
1414 Py_DECREF(py_result);
1415 return ::tensorflow::OkStatus();
1416 }
1417
  void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
1419
  PyTapeTensor TapeTensorFromGradient(PyObject* tensor) const final {
1421 return TapeTensorFromTensor(tensor);
1422 }
1423
1424 private:
1425 PyObject* py_vspace_;
1426
1427 PyObject* num_elements_;
1428 PyObject* aggregate_fn_;
1429 PyObject* zeros_fn_;
1430 PyObject* zeros_like_fn_;
1431 PyObject* ones_fn_;
1432 PyObject* ones_like_fn_;
1433 PyObject* graph_shape_fn_;
1434 };
1435 PyVSpace* py_vspace = nullptr;
1436
1437 bool HasAccumulator();
1438
PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
1440 if (py_vspace != nullptr) {
1441 if (HasAccumulator()) {
1442 // Accumulators reference py_vspace, so we can't swap it out while one is
1443 // active. This is unlikely to ever happen.
1444 MaybeRaiseExceptionFromStatus(
1445 tensorflow::errors::Internal(
1446 "Can't change the vspace implementation while a "
1447 "forward accumulator is active."),
1448 nullptr);
1449 }
1450 delete py_vspace;
1451 }
1452
1453 py_vspace = new PyVSpace(e);
1454 auto status = py_vspace->Initialize();
1455 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
1456 delete py_vspace;
1457 return nullptr;
1458 }
1459
1460 Py_RETURN_NONE;
1461 }
1462
PyObject* PyTapeTensor::GetShape() const {
1464 if (shape_.index() == 0) {
1465 auto& shape = absl::get<0>(shape_);
1466 PyObject* py_shape = PyTuple_New(shape.dims());
1467 for (int i = 0; i < shape.dims(); ++i) {
1468 PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
1469 }
1470
1471 return py_shape;
1472 }
1473
1474 return py_vspace->GraphShape(absl::get<1>(shape_));
1475 }
1476
PyObject* PyTapeTensor::OnesLike() const {
1478 if (shape_.index() == 1) {
1479 PyObject* tensor = absl::get<1>(shape_);
1480 return py_vspace->OnesLike(tensor);
1481 }
1482 PyObject* py_shape = GetShape();
1483 PyObject* dtype_field = GetPyDType();
1484 PyObject* result = py_vspace->Ones(py_shape, dtype_field);
1485 Py_DECREF(dtype_field);
1486 Py_DECREF(py_shape);
1487 return result;
1488 }
1489
PyObject* PyTapeTensor::ZerosLike() const {
1491 if (GetDType() == tensorflow::DT_RESOURCE) {
1492 // Gradient functions for ops which return resource tensors accept
1493 // None. This is the behavior of py_vspace->Zeros, but checking here avoids
1494 // issues with ZerosLike.
1495 Py_RETURN_NONE;
1496 }
1497 if (shape_.index() == 1) {
1498 PyObject* tensor = absl::get<1>(shape_);
1499 return py_vspace->ZerosLike(tensor);
1500 }
1501 PyObject* py_shape = GetShape();
1502 PyObject* dtype_field = GetPyDType();
1503 PyObject* result = py_vspace->Zeros(py_shape, dtype_field);
1504 Py_DECREF(dtype_field);
1505 Py_DECREF(py_shape);
1506 return result;
1507 }
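// Illustrative note (not part of the original source): OnesLike/ZerosLike
// dispatch on which alternative of the shape_ variant is held. A minimal
// sketch, assuming a registered py_vspace and a tensor whose shape was
// captured eagerly:
//
//   PyTapeTensor t(id, tensorflow::DT_FLOAT, tensorflow::TensorShape({2, 3}));
//   PyObject* zeros = t.ZerosLike();  // builds zeros from (shape, dtype)
//
// A tensor recorded as a PyObject* instead (shape_.index() == 1) defers to
// py_vspace->ZerosLike(tensor), which also covers symbolic shapes and tensors
// that carry handle data.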
1508
1509 // Keeps track of all variables that have been accessed during execution.
1510 class VariableWatcher {
1511 public:
1512   VariableWatcher() {}
1513
1514   ~VariableWatcher() {
1515 for (const IdAndVariable& v : watched_variables_) {
1516 Py_DECREF(v.variable);
1517 }
1518 }
1519
1520   int64_t WatchVariable(PyObject* v) {
1521 tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
1522 if (handle == nullptr) {
1523 return -1;
1524 }
1525 int64_t id = FastTensorId(handle.get());
1526
1527 tensorflow::mutex_lock l(watched_variables_mu_);
1528 auto insert_result = watched_variables_.emplace(id, v);
1529
1530 if (insert_result.second) {
1531 // Only increment the reference count if we aren't already watching this
1532 // variable.
1533 Py_INCREF(v);
1534 }
1535
1536 return id;
1537 }
1538
1539   PyObject* GetVariablesAsPyTuple() {
1540 tensorflow::mutex_lock l(watched_variables_mu_);
1541 PyObject* result = PyTuple_New(watched_variables_.size());
1542 Py_ssize_t pos = 0;
1543 for (const IdAndVariable& id_and_variable : watched_variables_) {
1544 PyTuple_SET_ITEM(result, pos++, id_and_variable.variable);
1545 Py_INCREF(id_and_variable.variable);
1546 }
1547 return result;
1548 }
1549
1550 private:
1551   // We store an IdAndVariable in the set since the set needs to be locked
1552   // during insert, but we should not call back into Python during insert, to
1553   // avoid deadlocking with the GIL.
1554 struct IdAndVariable {
1555 int64_t id;
1556 PyObject* variable;
1557
1558     IdAndVariable(int64_t id, PyObject* variable)
1559 : id(id), variable(variable) {}
1560 };
1561 struct CompareById {
1562     bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) const {
1563 return lhs.id < rhs.id;
1564 }
1565 };
1566
1567 tensorflow::mutex watched_variables_mu_;
1568 std::set<IdAndVariable, CompareById> watched_variables_
1569 TF_GUARDED_BY(watched_variables_mu_);
1570 };
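// Illustrative usage sketch (not part of the original source). A watcher
// records each distinct variable once, keyed by the id of its "handle"
// attribute, and later hands the set back to Python as a tuple:
//
//   VariableWatcher watcher;
//   watcher.WatchVariable(v);   // Py_INCREFs v on first sighting
//   watcher.WatchVariable(v);   // duplicate id: no extra reference taken
//   PyObject* vars = watcher.GetVariablesAsPyTuple();  // new references
//   Py_DECREF(vars);
//
// Here `v` stands for any Python object with a `handle` attribute that
// FastTensorId can identify (e.g. a resource variable).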
1571
1572 class GradientTape
1573 : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
1574 PyTapeTensor> {
1575 public:
1576   explicit GradientTape(bool persistent, bool watch_accessed_variables)
1577 : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
1578 PyTapeTensor>(persistent),
1579 watch_accessed_variables_(watch_accessed_variables) {}
1580
1581   virtual ~GradientTape() {}
1582
1583   void VariableAccessed(PyObject* v) {
1584 if (watch_accessed_variables_) {
1585 WatchVariable(v);
1586 }
1587 }
1588
1589   void WatchVariable(PyObject* v) {
1590 int64_t id = variable_watcher_.WatchVariable(v);
1591
1592 if (!PyErr_Occurred()) {
1593 this->Watch(id);
1594 }
1595 }
1596
1597   PyObject* GetVariablesAsPyTuple() {
1598 return variable_watcher_.GetVariablesAsPyTuple();
1599 }
1600
1601 private:
1602 bool watch_accessed_variables_;
1603 VariableWatcher variable_watcher_;
1604 };
1605
1606 typedef tensorflow::eager::ForwardAccumulator<PyObject, PyBackwardFunction,
1607 PyTapeTensor>
1608 ForwardAccumulator;
1609
1610 // Incremented when a GradientTape or accumulator is newly added to a set, and
1611 // used to enforce an ordering between them.
1612 std::atomic_uint_fast64_t tape_nesting_id_counter(0);
1613
1614 typedef struct {
1615 PyObject_HEAD
1616 /* Type-specific fields go here. */
1617 GradientTape* tape;
1618 // A nesting order between GradientTapes and ForwardAccumulators, used to
1619 // ensure that GradientTapes do not watch the products of outer
1620 // ForwardAccumulators.
1621 int64_t nesting_id;
1622 } TFE_Py_Tape;
1623
1624 static void TFE_Py_Tape_Delete(PyObject* tape) {
1625 delete reinterpret_cast<TFE_Py_Tape*>(tape)->tape;
1626 Py_TYPE(tape)->tp_free(tape);
1627 }
1628
1629 static PyTypeObject TFE_Py_Tape_Type = {
1630 PyVarObject_HEAD_INIT(nullptr, 0) "tfe.Tape", /* tp_name */
1631 sizeof(TFE_Py_Tape), /* tp_basicsize */
1632 0, /* tp_itemsize */
1633 &TFE_Py_Tape_Delete, /* tp_dealloc */
1634 #if PY_VERSION_HEX < 0x03080000
1635 nullptr, /* tp_print */
1636 #else
1637 0, /* tp_vectorcall_offset */
1638 #endif
1639 nullptr, /* tp_getattr */
1640 nullptr, /* tp_setattr */
1641 nullptr, /* tp_reserved */
1642 nullptr, /* tp_repr */
1643 nullptr, /* tp_as_number */
1644 nullptr, /* tp_as_sequence */
1645 nullptr, /* tp_as_mapping */
1646 nullptr, /* tp_hash */
1647 nullptr, /* tp_call */
1648 nullptr, /* tp_str */
1649 nullptr, /* tp_getattro */
1650 nullptr, /* tp_setattro */
1651 nullptr, /* tp_as_buffer */
1652 Py_TPFLAGS_DEFAULT, /* tp_flags */
1653 "TFE_Py_Tape objects", /* tp_doc */
1654 };
1655
1656 typedef struct {
1657 PyObject_HEAD
1658 /* Type-specific fields go here. */
1659 ForwardAccumulator* accumulator;
1660 // A nesting order between GradientTapes and ForwardAccumulators, used to
1661 // ensure that GradientTapes do not watch the products of outer
1662 // ForwardAccumulators.
1663 int64_t nesting_id;
1664 } TFE_Py_ForwardAccumulator;
1665
1666 static void TFE_Py_ForwardAccumulatorDelete(PyObject* accumulator) {
1667 delete reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)->accumulator;
1668 Py_TYPE(accumulator)->tp_free(accumulator);
1669 }
1670
1671 static PyTypeObject TFE_Py_ForwardAccumulator_Type = {
1672 PyVarObject_HEAD_INIT(nullptr, 0) "ForwardAccumulator", /* tp_name */
1673 sizeof(TFE_Py_ForwardAccumulator), /* tp_basicsize */
1674 0, /* tp_itemsize */
1675 &TFE_Py_ForwardAccumulatorDelete, /* tp_dealloc */
1676 #if PY_VERSION_HEX < 0x03080000
1677 nullptr, /* tp_print */
1678 #else
1679 0, /* tp_vectorcall_offset */
1680 #endif
1681 nullptr, /* tp_getattr */
1682 nullptr, /* tp_setattr */
1683 nullptr, /* tp_reserved */
1684 nullptr, /* tp_repr */
1685 nullptr, /* tp_as_number */
1686 nullptr, /* tp_as_sequence */
1687 nullptr, /* tp_as_mapping */
1688 nullptr, /* tp_hash */
1689 nullptr, /* tp_call */
1690 nullptr, /* tp_str */
1691 nullptr, /* tp_getattro */
1692 nullptr, /* tp_setattro */
1693 nullptr, /* tp_as_buffer */
1694 Py_TPFLAGS_DEFAULT, /* tp_flags */
1695 "TFE_Py_ForwardAccumulator objects", /* tp_doc */
1696 };
1697
1698 typedef struct {
1699 PyObject_HEAD
1700 /* Type-specific fields go here. */
1701 VariableWatcher* variable_watcher;
1702 } TFE_Py_VariableWatcher;
1703
1704 static void TFE_Py_VariableWatcher_Delete(PyObject* variable_watcher) {
1705 delete reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
1706 ->variable_watcher;
1707 Py_TYPE(variable_watcher)->tp_free(variable_watcher);
1708 }
1709
1710 static PyTypeObject TFE_Py_VariableWatcher_Type = {
1711 PyVarObject_HEAD_INIT(nullptr, 0) "tfe.VariableWatcher", /* tp_name */
1712 sizeof(TFE_Py_VariableWatcher), /* tp_basicsize */
1713 0, /* tp_itemsize */
1714 &TFE_Py_VariableWatcher_Delete, /* tp_dealloc */
1715 #if PY_VERSION_HEX < 0x03080000
1716 nullptr, /* tp_print */
1717 #else
1718 0, /* tp_vectorcall_offset */
1719 #endif
1720 nullptr, /* tp_getattr */
1721 nullptr, /* tp_setattr */
1722 nullptr, /* tp_reserved */
1723 nullptr, /* tp_repr */
1724 nullptr, /* tp_as_number */
1725 nullptr, /* tp_as_sequence */
1726 nullptr, /* tp_as_mapping */
1727 nullptr, /* tp_hash */
1728 nullptr, /* tp_call */
1729 nullptr, /* tp_str */
1730 nullptr, /* tp_getattro */
1731 nullptr, /* tp_setattro */
1732 nullptr, /* tp_as_buffer */
1733 Py_TPFLAGS_DEFAULT, /* tp_flags */
1734 "TFE_Py_VariableWatcher objects", /* tp_doc */
1735 };
1736
1737 // Note: in the current design no mutex is needed here because of the python
1738 // GIL, which is always held when any TFE_Py_* methods are called. We should
1739 // revisit this if/when we decide not to hold the GIL while manipulating the
1740 // tape stack.
1741 tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
1742 thread_local std::unique_ptr<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>
1743 tape_set;
1744 thread_local ThreadLocalDestructionMarker marker;
1745 if (!marker.alive) {
1746 // This thread is being destroyed. It is unsafe to access tape_set.
1747 return nullptr;
1748 }
1749 if (tape_set == nullptr) {
1750 tape_set.reset(new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>);
1751 }
1752 return tape_set.get();
1753 }
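// Illustrative note (not part of the original source): the pattern above (a
// thread_local unique_ptr guarded by a thread_local destruction marker) is
// shared by every per-thread set in this file. A minimal sketch of the same
// idea for a hypothetical container type `ThingSet`:
//
//   ThingSet* GetThingSet() {
//     thread_local std::unique_ptr<ThingSet> things;
//     thread_local ThreadLocalDestructionMarker marker;
//     if (!marker.alive) return nullptr;  // thread is being torn down
//     if (things == nullptr) things.reset(new ThingSet);
//     return things.get();
//   }
//
// Callers must tolerate a nullptr result during thread teardown, as
// TFE_Py_TapeSetRemove and TFE_Py_TapeSetDeleteTrace below do.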
1754
1755 tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>*
1756 GetVariableWatcherSet() {
1757 thread_local std::unique_ptr<
1758 tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>
1759 variable_watcher_set;
1760 thread_local ThreadLocalDestructionMarker marker;
1761 if (!marker.alive) {
1762 // This thread is being destroyed. It is unsafe to access
1763 // variable_watcher_set.
1764 return nullptr;
1765 }
1766 if (variable_watcher_set == nullptr) {
1767 variable_watcher_set.reset(
1768 new tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>);
1769 }
1770 return variable_watcher_set.get();
1771 }
1772
1773 // A linked hash set, where iteration is in insertion order.
1774 //
1775 // Nested accumulators rely on op recording happening in insertion order, so an
1776 // unordered data structure like CompactPointerSet is not suitable. Outer
1777 // accumulators need to observe operations first so they know to watch the inner
1778 // accumulator's jvp computation.
1779 //
1780 // Not thread safe.
1781 class AccumulatorSet {
1782 public:
1783 // Returns true if `element` was newly inserted, false if it already exists.
1784   bool insert(TFE_Py_ForwardAccumulator* element) {
1785 if (map_.find(element) != map_.end()) {
1786 return false;
1787 }
1788 ListType::iterator it = ordered_.insert(ordered_.end(), element);
1789 map_.insert(std::make_pair(element, it));
1790 return true;
1791 }
1792
1793   void erase(TFE_Py_ForwardAccumulator* element) {
1794 MapType::iterator existing = map_.find(element);
1795 if (existing == map_.end()) {
1796 return;
1797 }
1798 ListType::iterator list_position = existing->second;
1799 map_.erase(existing);
1800 ordered_.erase(list_position);
1801 }
1802
1803   bool empty() const { return ordered_.empty(); }
1804
1805   size_t size() const { return ordered_.size(); }
1806
1807 private:
1808 typedef std::list<TFE_Py_ForwardAccumulator*> ListType;
1809 typedef tensorflow::gtl::FlatMap<TFE_Py_ForwardAccumulator*,
1810 ListType::iterator>
1811 MapType;
1812
1813 public:
1814 typedef ListType::const_iterator const_iterator;
1815 typedef ListType::const_reverse_iterator const_reverse_iterator;
1816
1817   const_iterator begin() const { return ordered_.begin(); }
1818   const_iterator end() const { return ordered_.end(); }
1819
1820   const_reverse_iterator rbegin() const { return ordered_.rbegin(); }
1821   const_reverse_iterator rend() const { return ordered_.rend(); }
1822
1823 private:
1824 MapType map_;
1825 ListType ordered_;
1826 };
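// Illustrative note (not part of the original source): AccumulatorSet pairs a
// FlatMap (constant-time membership checks) with a std::list (stable
// insertion order). Accumulators are inserted as they are entered, so
// iteration visits the outermost accumulator first:
//
//   AccumulatorSet accumulators;
//   accumulators.insert(outer);  // entered first
//   accumulators.insert(inner);  // entered second
//   // begin()..end()   -> outer, inner (insertion order)
//   // rbegin()..rend() -> inner, outer (the order TapeSetRecordForwardprop
//   //                     and TFE_Py_PackJVPs below rely on)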
1827
1828 AccumulatorSet* GetAccumulatorSet() {
1829 thread_local std::unique_ptr<AccumulatorSet> accumulator_set;
1830 thread_local ThreadLocalDestructionMarker marker;
1831 if (!marker.alive) {
1832 // This thread is being destroyed. It is unsafe to access accumulator_set.
1833 return nullptr;
1834 }
1835 if (accumulator_set == nullptr) {
1836 accumulator_set.reset(new AccumulatorSet);
1837 }
1838 return accumulator_set.get();
1839 }
1840
1841 inline bool HasAccumulator() { return !GetAccumulatorSet()->empty(); }
1842
1843 inline bool HasGradientTape() { return !GetTapeSet()->empty(); }
1844
1845 inline bool HasAccumulatorOrTape() {
1846 return HasGradientTape() || HasAccumulator();
1847 }
1848
1849 // A safe copy of a set, used for tapes and accumulators. The copy is not
1850 // affected by other python threads changing the set of active tapes.
1851 template <typename ContainerType>
1852 class SafeSetCopy {
1853 public:
1854   explicit SafeSetCopy(const ContainerType& to_copy) : set_copy_(to_copy) {
1855 for (auto* member : set_copy_) {
1856 Py_INCREF(member);
1857 }
1858 }
1859
1860   ~SafeSetCopy() {
1861 for (auto* member : set_copy_) {
1862 Py_DECREF(member);
1863 }
1864 }
1865
1866   typename ContainerType::const_iterator begin() const {
1867 return set_copy_.begin();
1868 }
1869
1870   typename ContainerType::const_iterator end() const { return set_copy_.end(); }
1871
1872   bool empty() const { return set_copy_.empty(); }
1873   size_t size() const { return set_copy_.size(); }
1874
1875 protected:
1876 ContainerType set_copy_;
1877 };
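// Illustrative note (not part of the original source): the copy exists so
// that a backward or forward function running arbitrary Python (which may
// release the GIL or mutate the active tape set) cannot invalidate iteration,
// and the Py_INCREF in the constructor keeps each member alive for the
// lifetime of the copy. A minimal usage sketch, essentially what
// TFE_Py_TapeVariableAccessed below does:
//
//   for (TFE_Py_Tape* tape : SafeTapeSet()) {
//     tape->tape->VariableAccessed(variable);  // safe even if Python code
//   }                                          // removes tapes meanwhile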
1878
1879 class SafeTapeSet
1880 : public SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>> {
1881 public:
1882   SafeTapeSet()
1883 : SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>(
1884 *GetTapeSet()) {}
1885 };
1886
1887 class SafeAccumulatorSet : public SafeSetCopy<AccumulatorSet> {
1888 public:
1889   SafeAccumulatorSet() : SafeSetCopy<AccumulatorSet>(*GetAccumulatorSet()) {}
1890
1891   typename AccumulatorSet::const_reverse_iterator rbegin() const {
1892 return set_copy_.rbegin();
1893 }
1894
1895   typename AccumulatorSet::const_reverse_iterator rend() const {
1896 return set_copy_.rend();
1897 }
1898 };
1899
1900 class SafeVariableWatcherSet
1901 : public SafeSetCopy<
1902 tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> {
1903 public:
1904   SafeVariableWatcherSet()
1905 : SafeSetCopy<
1906 tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>(
1907 *GetVariableWatcherSet()) {}
1908 };
1909
1910 bool* ThreadTapeIsStopped() {
1911 thread_local bool thread_tape_is_stopped{false};
1912 return &thread_tape_is_stopped;
1913 }
1914
1915 void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }
1916
1917 void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }
1918
1919 PyObject* TFE_Py_TapeSetIsStopped() {
1920 if (*ThreadTapeIsStopped()) {
1921 Py_RETURN_TRUE;
1922 }
1923 Py_RETURN_FALSE;
1924 }
1925
1926 PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
1927 PyObject* watch_accessed_variables) {
1928 TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
1929 if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
1930 TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
1931 tape->tape = new GradientTape(persistent == Py_True,
1932 watch_accessed_variables == Py_True);
1933 Py_INCREF(tape);
1934 tape->nesting_id = tape_nesting_id_counter.fetch_add(1);
1935 GetTapeSet()->insert(tape);
1936 return reinterpret_cast<PyObject*>(tape);
1937 }
1938
1939 void TFE_Py_TapeSetAdd(PyObject* tape) {
1940 Py_INCREF(tape);
1941 TFE_Py_Tape* tfe_tape = reinterpret_cast<TFE_Py_Tape*>(tape);
1942 if (!GetTapeSet()->insert(tfe_tape).second) {
1943 // Already exists in the tape set.
1944 Py_DECREF(tape);
1945 } else {
1946 tfe_tape->nesting_id = tape_nesting_id_counter.fetch_add(1);
1947 }
1948 }
1949
1950 PyObject* TFE_Py_TapeSetIsEmpty() {
1951 if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
1952 Py_RETURN_TRUE;
1953 }
1954 Py_RETURN_FALSE;
1955 }
1956
1957 void TFE_Py_TapeSetRemove(PyObject* tape) {
1958 auto* stack = GetTapeSet();
1959 if (stack != nullptr) {
1960 stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
1961 }
1962 // We kept a reference to the tape in the set to ensure it wouldn't get
1963 // deleted under us; cleaning it up here.
1964 Py_DECREF(tape);
1965 }
1966
1967 static std::vector<int64_t> MakeIntList(PyObject* list) {
1968 if (list == Py_None) {
1969 return {};
1970 }
1971 PyObject* seq = PySequence_Fast(list, "expected a sequence");
1972 if (seq == nullptr) {
1973 return {};
1974 }
1975 int len = PySequence_Size(list);
1976 PyObject** seq_array = PySequence_Fast_ITEMS(seq);
1977 std::vector<int64_t> tensor_ids;
1978 tensor_ids.reserve(len);
1979 for (int i = 0; i < len; ++i) {
1980 PyObject* item = seq_array[i];
1981 #if PY_MAJOR_VERSION >= 3
1982 if (PyLong_Check(item)) {
1983 #else
1984 if (PyLong_Check(item) || PyInt_Check(item)) {
1985 #endif
1986 int64_t id = MakeInt(item);
1987 tensor_ids.push_back(id);
1988 } else {
1989 tensor_ids.push_back(-1);
1990 }
1991 }
1992 Py_DECREF(seq);
1993 return tensor_ids;
1994 }
1995
1996 // Fill `tensor_ids` and `dtypes` from `tensors`, none of which may be
1997 // null. Returns true on success and false on a Python exception.
1998 bool TensorShapesAndDtypes(PyObject* tensors, std::vector<int64_t>* tensor_ids,
1999 std::vector<tensorflow::DataType>* dtypes) {
2000 tensorflow::Safe_PyObjectPtr seq(
2001 PySequence_Fast(tensors, "expected a sequence"));
2002 if (seq == nullptr) {
2003 return false;
2004 }
2005 int len = PySequence_Fast_GET_SIZE(seq.get());
2006 PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
2007 tensor_ids->reserve(len);
2008 dtypes->reserve(len);
2009 for (int i = 0; i < len; ++i) {
2010 PyObject* item = seq_array[i];
2011 tensor_ids->push_back(FastTensorId(item));
2012 dtypes->push_back(tensorflow::PyTensor_DataType(item));
2013 }
2014 return true;
2015 }
2016
2017 bool TapeCouldPossiblyRecord(PyObject* tensors) {
2018 if (tensors == Py_None) {
2019 return false;
2020 }
2021 if (*ThreadTapeIsStopped()) {
2022 return false;
2023 }
2024 if (!HasAccumulatorOrTape()) {
2025 return false;
2026 }
2027 return true;
2028 }
2029
2030 bool CouldBackprop() { return !*ThreadTapeIsStopped() && HasGradientTape(); }
2031
2032 bool CouldForwardprop() { return !*ThreadTapeIsStopped() && HasAccumulator(); }
2033
2034 PyObject* TFE_Py_TapeSetShouldRecordBackprop(PyObject* tensors) {
2035 if (!TapeCouldPossiblyRecord(tensors) || !CouldBackprop()) {
2036 Py_RETURN_FALSE;
2037 }
2038 // TODO(apassos) consider not building a list and changing the API to check
2039 // each tensor individually.
2040 std::vector<int64_t> tensor_ids;
2041 std::vector<tensorflow::DataType> dtypes;
2042 if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
2043 return nullptr;
2044 }
2045 auto& tape_set = *GetTapeSet();
2046 for (TFE_Py_Tape* tape : tape_set) {
2047 if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
2048 Py_RETURN_TRUE;
2049 }
2050 }
2051
2052 Py_RETURN_FALSE;
2053 }
2054
2055 PyObject* TFE_Py_ForwardAccumulatorPushState() {
2056 auto& forward_accumulators = *GetAccumulatorSet();
2057 for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
2058 accumulator->accumulator->PushState();
2059 }
2060 Py_RETURN_NONE;
2061 }
2062
2063 PyObject* TFE_Py_ForwardAccumulatorPopState() {
2064 auto& forward_accumulators = *GetAccumulatorSet();
2065 for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
2066 accumulator->accumulator->PopState();
2067 }
2068 Py_RETURN_NONE;
2069 }
2070
2071 PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) {
2072 if (!TapeCouldPossiblyRecord(tensors)) {
2073 return GetPythonObjectFromInt(0);
2074 }
2075 std::vector<int64_t> tensor_ids;
2076 std::vector<tensorflow::DataType> dtypes;
2077 if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
2078 return nullptr;
2079 }
2080
2081 // If there is a persistent tape watching, or if there are multiple tapes
2082 // watching, we'll return immediately indicating that higher-order tape
2083 // gradients are possible.
2084 bool some_tape_watching = false;
2085 if (CouldBackprop()) {
2086 auto& tape_set = *GetTapeSet();
2087 for (TFE_Py_Tape* tape : tape_set) {
2088 if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
2089 if (tape->tape->IsPersistent() || some_tape_watching) {
2090 // Either this is the second tape watching, or this tape is
2091 // persistent: higher-order gradients are possible.
2092 return GetPythonObjectFromInt(2);
2093 }
2094 some_tape_watching = true;
2095 }
2096 }
2097 }
2098 if (CouldForwardprop()) {
2099 auto& forward_accumulators = *GetAccumulatorSet();
2100 for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
2101 if (accumulator->accumulator->ShouldRecord(tensor_ids, dtypes)) {
2102 if (some_tape_watching) {
2103 // This is the second tape watching: higher-order gradients are
2104 // possible. Note that there's no equivalent of persistence for
2105 // forward-mode.
2106 return GetPythonObjectFromInt(2);
2107 }
2108 some_tape_watching = true;
2109 }
2110 }
2111 }
2112 if (some_tape_watching) {
2113 // There's exactly one non-persistent tape. The user can request first-order
2114 // gradients but won't be able to get higher-order tape gradients.
2115 return GetPythonObjectFromInt(1);
2116 } else {
2117 // There are no tapes. The user can't request tape gradients.
2118 return GetPythonObjectFromInt(0);
2119 }
2120 }
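// Illustrative summary (not part of the original source) of the return codes
// above, as seen by the Python caller:
//   0 -> nothing is watching these tensors; no tape gradients are possible.
//   1 -> exactly one non-persistent tape or accumulator is watching; only
//        first-order gradients are available.
//   2 -> a persistent tape is watching, or at least two tapes/accumulators
//        are watching; higher-order gradients may be requested.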
2121
2122 void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
2123 if (!CouldBackprop()) {
2124 return;
2125 }
2126 int64_t tensor_id = FastTensorId(tensor);
2127 if (PyErr_Occurred()) {
2128 return;
2129 }
2130 reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
2131 }
2132
2133 bool ListContainsNone(PyObject* list) {
2134 if (list == Py_None) return true;
2135 tensorflow::Safe_PyObjectPtr seq(
2136 PySequence_Fast(list, "expected a sequence"));
2137 if (seq == nullptr) {
2138 return false;
2139 }
2140
2141 int len = PySequence_Size(list);
2142 PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
2143 for (int i = 0; i < len; ++i) {
2144 PyObject* item = seq_array[i];
2145 if (item == Py_None) return true;
2146 }
2147
2148 return false;
2149 }
2150
2151 // As an optimization, the tape generally keeps only the shape and dtype of
2152 // tensors, and uses this information to generate ones/zeros tensors. However,
2153 // some tensors require OnesLike/ZerosLike because their gradients do not match
2154 // their inference shape/dtype.
2155 bool DTypeNeedsHandleData(tensorflow::DataType dtype) {
2156 return dtype == tensorflow::DT_VARIANT || dtype == tensorflow::DT_RESOURCE;
2157 }
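// Illustrative note (not part of the original source): for DT_VARIANT and
// DT_RESOURCE the tape keeps the Python tensor object itself rather than just
// (shape, dtype), since e.g. a DT_RESOURCE handle is a scalar while its
// gradient takes the shape of the underlying variable. This drives the
// dispatch in TapeTensorFromTensor below:
//
//   if (DTypeNeedsHandleData(dtype)) {
//     return PyTapeTensor(id, dtype, tensor);      // keep the PyObject*
//   }
//   return PyTapeTensor(id, dtype, tensor_shape);  // keep shape + dtype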
2158
2159 static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
2160 if (EagerTensor_CheckExact(tensor)) {
2161 tensorflow::ImmediateExecutionTensorHandle* handle =
2162 tensorflow::unwrap(EagerTensor_Handle(tensor));
2163 int64_t id = PyEagerTensor_ID(tensor);
2164 tensorflow::DataType dtype =
2165 static_cast<tensorflow::DataType>(handle->DataType());
2166 if (DTypeNeedsHandleData(dtype)) {
2167 return PyTapeTensor(id, dtype, tensor);
2168 }
2169
2170 tensorflow::TensorShape tensor_shape;
2171 int num_dims;
2172 tensorflow::Status status = handle->NumDims(&num_dims);
2173 if (status.ok()) {
2174 for (int i = 0; i < num_dims; ++i) {
2175 int64_t dim_size;
2176 status = handle->Dim(i, &dim_size);
2177 if (!status.ok()) break;
2178 tensor_shape.AddDim(dim_size);
2179 }
2180 }
2181
2182 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
2183 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2184 tensorflow::TensorShape({}));
2185 } else {
2186 return PyTapeTensor(id, dtype, tensor_shape);
2187 }
2188 }
2189 int64_t id = FastTensorId(tensor);
2190 if (PyErr_Occurred()) {
2191 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2192 tensorflow::TensorShape({}));
2193 }
2194 PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
2195 PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
2196 Py_DECREF(dtype_object);
2197 tensorflow::DataType dtype =
2198 static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
2199 Py_DECREF(dtype_enum);
2200 if (PyErr_Occurred()) {
2201 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2202 tensorflow::TensorShape({}));
2203 }
2204 static char _shape_tuple[] = "_shape_tuple";
2205 tensorflow::Safe_PyObjectPtr shape_tuple(
2206 PyObject_CallMethod(tensor, _shape_tuple, nullptr));
2207 if (PyErr_Occurred()) {
2208 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2209 tensorflow::TensorShape({}));
2210 }
2211
2212 if (ListContainsNone(shape_tuple.get()) || DTypeNeedsHandleData(dtype)) {
2213 return PyTapeTensor(id, dtype, tensor);
2214 }
2215
2216 auto l = MakeIntList(shape_tuple.get());
2217   // Replace -1 (which represents accidental Nones that can occur in graph mode
2218   // and can cause errors in shape construction) with 0s.
2219 for (auto& c : l) {
2220 if (c < 0) {
2221 c = 0;
2222 }
2223 }
2224 tensorflow::TensorShape shape(l);
2225 return PyTapeTensor(id, dtype, shape);
2226 }
2227
2228 // Populates output_info from output_seq, which must come from PySequence_Fast.
2229 //
2230 // Does not take ownership of output_seq. Returns true on success and false if a
2231 // Python exception has been set.
2232 bool TapeTensorsFromTensorSequence(PyObject* output_seq,
2233 std::vector<PyTapeTensor>* output_info) {
2234 Py_ssize_t output_len = PySequence_Fast_GET_SIZE(output_seq);
2235 PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq);
2236 output_info->reserve(output_len);
2237 for (Py_ssize_t i = 0; i < output_len; ++i) {
2238 output_info->push_back(TapeTensorFromTensor(output_seq_array[i]));
2239 if (PyErr_Occurred() != nullptr) {
2240 return false;
2241 }
2242 }
2243 return true;
2244 }
2245
2246 std::vector<int64_t> MakeTensorIDList(PyObject* tensors) {
2247 PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
2248 if (seq == nullptr) {
2249 return {};
2250 }
2251 int len = PySequence_Fast_GET_SIZE(seq);
2252 PyObject** seq_array = PySequence_Fast_ITEMS(seq);
2253 std::vector<int64_t> list;
2254 list.reserve(len);
2255 for (int i = 0; i < len; ++i) {
2256 PyObject* tensor = seq_array[i];
2257 list.push_back(FastTensorId(tensor));
2258 if (PyErr_Occurred()) {
2259 Py_DECREF(seq);
2260 return list;
2261 }
2262 }
2263 Py_DECREF(seq);
2264 return list;
2265 }
2266
2267 void TFE_Py_TapeVariableAccessed(PyObject* variable) {
2268 if (!CouldBackprop()) {
2269 return;
2270 }
2271 for (TFE_Py_Tape* tape : SafeTapeSet()) {
2272 tape->tape->VariableAccessed(variable);
2273 }
2274 }
2275
2276 void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) {
2277 if (!CouldBackprop()) {
2278 return;
2279 }
2280 reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable);
2281 }
2282
2283 PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
2284 return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
2285 }
2286
2287 PyObject* TFE_Py_VariableWatcherNew() {
2288 TFE_Py_VariableWatcher_Type.tp_new = PyType_GenericNew;
2289 if (PyType_Ready(&TFE_Py_VariableWatcher_Type) < 0) return nullptr;
2290 TFE_Py_VariableWatcher* variable_watcher =
2291 PyObject_NEW(TFE_Py_VariableWatcher, &TFE_Py_VariableWatcher_Type);
2292 variable_watcher->variable_watcher = new VariableWatcher();
2293 Py_INCREF(variable_watcher);
2294 GetVariableWatcherSet()->insert(variable_watcher);
2295 return reinterpret_cast<PyObject*>(variable_watcher);
2296 }
2297
2298 void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher) {
2299 auto* stack = GetVariableWatcherSet();
2300 stack->erase(reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher));
2301 // We kept a reference to the variable watcher in the set to ensure it
2302 // wouldn't get deleted under us; cleaning it up here.
2303 Py_DECREF(variable_watcher);
2304 }
2305
2306 void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable) {
2307 for (TFE_Py_VariableWatcher* variable_watcher : SafeVariableWatcherSet()) {
2308 variable_watcher->variable_watcher->WatchVariable(variable);
2309 }
2310 }
2311
2312 PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher) {
2313 return reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
2314 ->variable_watcher->GetVariablesAsPyTuple();
2315 }
2316
2317 namespace {
2318 std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
2319 PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
2320 if (seq == nullptr) {
2321 return {};
2322 }
2323 int len = PySequence_Fast_GET_SIZE(seq);
2324 PyObject** seq_array = PySequence_Fast_ITEMS(seq);
2325 std::vector<tensorflow::DataType> list;
2326 list.reserve(len);
2327 for (int i = 0; i < len; ++i) {
2328 PyObject* tensor = seq_array[i];
2329 list.push_back(tensorflow::PyTensor_DataType(tensor));
2330 }
2331 Py_DECREF(seq);
2332 return list;
2333 }
2334
2335 PyObject* ForwardAccumulatorDeleteGradient(PyObject* tensor_id,
2336 PyObject* weak_tensor_ref) {
2337 auto* accumulator_set = GetAccumulatorSet();
2338 if (accumulator_set != nullptr) {
2339 int64_t parsed_tensor_id = MakeInt(tensor_id);
2340 for (TFE_Py_ForwardAccumulator* accumulator : *accumulator_set) {
2341 accumulator->accumulator->DeleteGradient(parsed_tensor_id);
2342 }
2343 }
2344 Py_DECREF(weak_tensor_ref);
2345 Py_DECREF(tensor_id);
2346 Py_INCREF(Py_None);
2347 return Py_None;
2348 }
2349
2350 static PyMethodDef forward_accumulator_delete_gradient_method_def = {
2351 "ForwardAccumulatorDeleteGradient", ForwardAccumulatorDeleteGradient,
2352 METH_O, "ForwardAccumulatorDeleteGradient"};
2353
2354 void RegisterForwardAccumulatorCleanup(PyObject* tensor, int64_t tensor_id) {
2355 tensorflow::Safe_PyObjectPtr callback(
2356 PyCFunction_New(&forward_accumulator_delete_gradient_method_def,
2357 PyLong_FromLong(tensor_id)));
2358 // We need to keep a reference to the weakref active if we want our callback
2359 // called. The callback itself now owns the weakref object and the tensor ID
2360 // object.
2361 PyWeakref_NewRef(tensor, callback.get());
2362 }
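// Illustrative note (not part of the original source): the function above is
// the standard CPython weakref-callback idiom. The callback (a PyCFunction
// whose `self` is the tensor id) fires when `tensor` is collected, letting
// every accumulator drop its stored JVP. A minimal sketch of the idiom in
// isolation, with hypothetical names:
//
//   static PyObject* OnDead(PyObject* self, PyObject* weakref) {
//     // `self` is whatever was bound at PyCFunction_New time.
//     Py_DECREF(weakref);
//     Py_RETURN_NONE;
//   }
//   static PyMethodDef kOnDeadDef = {"OnDead", OnDead, METH_O, nullptr};
//   ...
//   PyObject* cb = PyCFunction_New(&kOnDeadDef, /*self=*/state);
//   PyWeakref_NewRef(obj, cb);  // the returned weakref is deliberately kept
//                               // alive so the callback can run, as above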
2363
2364 void TapeSetRecordBackprop(
2365 const string& op_type, const std::vector<PyTapeTensor>& output_info,
2366 const std::vector<int64_t>& input_ids,
2367 const std::vector<tensorflow::DataType>& input_dtypes,
2368 const std::function<PyBackwardFunction*()>& backward_function_getter,
2369 const std::function<void(PyBackwardFunction*)>& backward_function_killer,
2370 tensorflow::uint64 max_gradient_tape_id) {
2371 if (!CouldBackprop()) {
2372 return;
2373 }
2374 for (TFE_Py_Tape* tape : SafeTapeSet()) {
2375 if (tape->nesting_id < max_gradient_tape_id) {
2376 tape->tape->RecordOperation(op_type, output_info, input_ids, input_dtypes,
2377 backward_function_getter,
2378 backward_function_killer);
2379 }
2380 }
2381 }
2382
2383 bool TapeSetRecordForwardprop(
2384 const string& op_type, PyObject* output_seq,
2385 const std::vector<PyTapeTensor>& output_info, PyObject* input_tensors,
2386 const std::vector<int64_t>& input_ids,
2387 const std::vector<tensorflow::DataType>& input_dtypes,
2388 const std::function<PyBackwardFunction*()>& backward_function_getter,
2389 const std::function<void(PyBackwardFunction*)>& backward_function_killer,
2390 const tensorflow::eager::ForwardFunction<PyObject>* forward_function,
2391 PyObject* forwardprop_output_indices,
2392 tensorflow::uint64* max_gradient_tape_id) {
2393 *max_gradient_tape_id = std::numeric_limits<tensorflow::uint64>::max();
2394 if (!CouldForwardprop()) {
2395 return true;
2396 }
2397 auto accumulator_set = SafeAccumulatorSet();
2398 tensorflow::Safe_PyObjectPtr input_seq(
2399 PySequence_Fast(input_tensors, "expected a sequence of tensors"));
2400 if (input_seq == nullptr || PyErr_Occurred()) return false;
2401 Py_ssize_t input_len = PySequence_Fast_GET_SIZE(input_seq.get());
2402 PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq);
2403 for (int i = 0; i < output_info.size(); ++i) {
2404 RegisterForwardAccumulatorCleanup(output_seq_array[i],
2405 output_info[i].GetID());
2406 }
2407 if (forwardprop_output_indices != nullptr &&
2408 forwardprop_output_indices != Py_None) {
2409 tensorflow::Safe_PyObjectPtr indices_fast(PySequence_Fast(
2410 forwardprop_output_indices, "Expected a sequence of indices"));
2411 if (indices_fast == nullptr || PyErr_Occurred()) {
2412 return false;
2413 }
2414 if (PySequence_Fast_GET_SIZE(indices_fast.get()) !=
2415 accumulator_set.size()) {
2416 MaybeRaiseExceptionFromStatus(
2417 tensorflow::errors::Internal(
2418 "Accumulators were added or removed from the active set "
2419 "between packing and unpacking."),
2420 nullptr);
2421 }
2422 PyObject** indices_fast_array = PySequence_Fast_ITEMS(indices_fast.get());
2423 Py_ssize_t accumulator_index = 0;
2424 for (AccumulatorSet::const_reverse_iterator it = accumulator_set.rbegin();
2425 it != accumulator_set.rend(); ++it, ++accumulator_index) {
2426 tensorflow::Safe_PyObjectPtr jvp_index_seq(
2427 PySequence_Fast(indices_fast_array[accumulator_index],
2428 "Expected a sequence of jvp indices."));
2429 if (jvp_index_seq == nullptr || PyErr_Occurred()) {
2430 return false;
2431 }
2432 Py_ssize_t num_jvps = PySequence_Fast_GET_SIZE(jvp_index_seq.get());
2433 PyObject** jvp_index_seq_array =
2434 PySequence_Fast_ITEMS(jvp_index_seq.get());
2435 for (Py_ssize_t jvp_index = 0; jvp_index < num_jvps; ++jvp_index) {
2436 PyObject* tuple = jvp_index_seq_array[jvp_index];
2437 int64_t primal_tensor_id =
2438 output_info[MakeInt(PyTuple_GetItem(tuple, 0))].GetID();
2439 (*it)->accumulator->Watch(
2440 primal_tensor_id,
2441 output_seq_array[MakeInt(PyTuple_GetItem(tuple, 1))]);
2442 }
2443 }
2444 } else {
2445 std::vector<PyTapeTensor> input_info;
2446 input_info.reserve(input_len);
2447 PyObject** input_seq_array = PySequence_Fast_ITEMS(input_seq.get());
2448 for (Py_ssize_t i = 0; i < input_len; ++i) {
2449 input_info.push_back(TapeTensorFromTensor(input_seq_array[i]));
2450 }
2451 for (TFE_Py_ForwardAccumulator* accumulator : accumulator_set) {
2452 tensorflow::Status status = accumulator->accumulator->Accumulate(
2453 op_type, input_info, output_info, input_ids, input_dtypes,
2454 forward_function, backward_function_getter, backward_function_killer);
2455 if (PyErr_Occurred()) return false; // Don't swallow Python exceptions.
2456 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
2457 return false;
2458 }
2459 if (accumulator->accumulator->BusyAccumulating()) {
2460 // Ensure inner accumulators don't see outer accumulators' jvps. This
2461 // mostly happens on its own, with some potentially surprising
2462 // exceptions, so the blanket policy is for consistency.
2463 *max_gradient_tape_id = accumulator->nesting_id;
2464 break;
2465 }
2466 }
2467 }
2468 return true;
2469 }
2470
2471 PyObject* TangentsAsPyTuple(const std::vector<PyObject*>& input_tangents) {
2472 PyObject* py_input_tangents = PyTuple_New(input_tangents.size());
2473 for (int i = 0; i < input_tangents.size(); ++i) {
2474 PyObject* element;
2475 if (input_tangents[i] == nullptr) {
2476 element = Py_None;
2477 } else {
2478 element = input_tangents[i];
2479 }
2480 Py_INCREF(element);
2481 PyTuple_SET_ITEM(py_input_tangents, i, element);
2482 }
2483 return py_input_tangents;
2484 }
2485
2486 tensorflow::Status ParseTangentOutputs(
2487 PyObject* user_output, std::vector<PyObject*>* output_tangents) {
2488 if (user_output == Py_None) {
2489 // No connected gradients.
2490 return ::tensorflow::OkStatus();
2491 }
2492 tensorflow::Safe_PyObjectPtr fast_result(
2493 PySequence_Fast(user_output, "expected a sequence of forward gradients"));
2494 if (fast_result == nullptr) {
2495 return tensorflow::errors::InvalidArgument(
2496 "forward gradient function did not return a sequence.");
2497 }
2498 int len = PySequence_Fast_GET_SIZE(fast_result.get());
2499 PyObject** fast_result_array = PySequence_Fast_ITEMS(fast_result.get());
2500 output_tangents->reserve(len);
2501 for (int i = 0; i < len; ++i) {
2502 PyObject* item = fast_result_array[i];
2503 if (item == Py_None) {
2504 output_tangents->push_back(nullptr);
2505 } else {
2506 Py_INCREF(item);
2507 output_tangents->push_back(item);
2508 }
2509 }
2510 return ::tensorflow::OkStatus();
2511 }
2512
2513 // Calls the registered forward_gradient_function, computing `output_tangents`
2514 // from `input_tangents`. `output_tangents` must not be null.
2515 //
2516 // `op_name`, `attrs`, `inputs`, and `results` describe the operation for which
2517 // the forward function is being called.
2518 tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs,
2519 PyObject* inputs, PyObject* results,
2520 const std::vector<PyObject*>& input_tangents,
2521 std::vector<PyObject*>* output_tangents,
2522 bool use_batch) {
2523 if (forward_gradient_function == nullptr) {
2524 return tensorflow::errors::Internal(
2525 "No forward gradient function registered.");
2526 }
2527 tensorflow::Safe_PyObjectPtr py_input_tangents(
2528 TangentsAsPyTuple(input_tangents));
2529
2530 // Normalize the input sequence to a tuple so it works with function
2531 // caching; otherwise it may be an opaque _InputList object.
2532 tensorflow::Safe_PyObjectPtr input_tuple(PySequence_Tuple(inputs));
2533 PyObject* to_batch = (use_batch) ? Py_True : Py_False;
2534 tensorflow::Safe_PyObjectPtr callback_args(
2535 Py_BuildValue("OOOOOO", op_name, attrs, input_tuple.get(), results,
2536 py_input_tangents.get(), to_batch));
2537 tensorflow::Safe_PyObjectPtr py_result(
2538 PyObject_CallObject(forward_gradient_function, callback_args.get()));
2539 if (py_result == nullptr || PyErr_Occurred()) {
2540 return tensorflow::errors::Internal(
2541 "forward gradient function threw exceptions");
2542 }
2543 return ParseTangentOutputs(py_result.get(), output_tangents);
2544 }
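// Illustrative note (not part of the original source): the registered
// forward_gradient_function is invoked with the six positional arguments
// packed above, i.e. (op_name, attrs, tuple(inputs), results, input_tangents,
// use_batch), and must return either None (no connected gradients) or a
// sequence with one tangent (or None) per op output; ParseTangentOutputs
// enforces that contract.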
2545
2546 // Like CallJVPFunction, but calls a pre-bound forward function.
2547 // These are passed in from a record_gradient argument.
2548 tensorflow::Status CallOpSpecificJVPFunction(
2549 PyObject* op_specific_forward_function,
2550 const std::vector<PyObject*>& input_tangents,
2551 std::vector<PyObject*>* output_tangents) {
2552 tensorflow::Safe_PyObjectPtr py_input_tangents(
2553 TangentsAsPyTuple(input_tangents));
2554
2555 tensorflow::Safe_PyObjectPtr py_result(PyObject_CallObject(
2556 op_specific_forward_function, py_input_tangents.get()));
2557 if (py_result == nullptr || PyErr_Occurred()) {
2558 return tensorflow::errors::Internal(
2559 "forward gradient function threw exceptions");
2560 }
2561 return ParseTangentOutputs(py_result.get(), output_tangents);
2562 }
2563
2564 bool ParseOpTypeString(PyObject* op_type, string* op_type_string) {
2565 if (PyBytes_Check(op_type)) {
2566 *op_type_string = PyBytes_AsString(op_type);
2567 } else if (PyUnicode_Check(op_type)) {
2568 #if PY_MAJOR_VERSION >= 3
2569 *op_type_string = PyUnicode_AsUTF8(op_type);
2570 #else
2571 PyObject* py_str = PyUnicode_AsUTF8String(op_type);
2572 if (py_str == nullptr) {
2573 return false;
2574 }
2575 *op_type_string = PyBytes_AS_STRING(py_str);
2576 Py_DECREF(py_str);
2577 #endif
2578 } else {
2579 PyErr_SetString(PyExc_RuntimeError, "op_type should be a string.");
2580 return false;
2581 }
2582 return true;
2583 }
2584
2585 bool TapeSetRecordOperation(
2586 PyObject* op_type, PyObject* input_tensors, PyObject* output_tensors,
2587 const std::vector<int64_t>& input_ids,
2588 const std::vector<tensorflow::DataType>& input_dtypes,
2589 const std::function<PyBackwardFunction*()>& backward_function_getter,
2590 const std::function<void(PyBackwardFunction*)>& backward_function_killer,
2591 const tensorflow::eager::ForwardFunction<PyObject>* forward_function) {
2592 std::vector<PyTapeTensor> output_info;
2593 tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2594 output_tensors, "expected a sequence of integer tensor ids"));
2595 if (PyErr_Occurred() ||
2596 !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2597 return false;
2598 }
2599 string op_type_str;
2600 if (!ParseOpTypeString(op_type, &op_type_str)) {
2601 return false;
2602 }
2603 tensorflow::uint64 max_gradient_tape_id;
2604 if (!TapeSetRecordForwardprop(
2605 op_type_str, output_seq.get(), output_info, input_tensors, input_ids,
2606 input_dtypes, backward_function_getter, backward_function_killer,
2607 forward_function, nullptr /* No special-cased jvps. */,
2608 &max_gradient_tape_id)) {
2609 return false;
2610 }
2611 TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes,
2612 backward_function_getter, backward_function_killer,
2613 max_gradient_tape_id);
2614 return true;
2615 }
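// Illustrative note (not part of the original source): recording is ordered
// so that forward accumulators observe the op before any backward tape does,
// and the max_gradient_tape_id they report keeps outer GradientTapes from
// also recording an inner accumulator's jvp computation. In outline:
//
//   1. TapeTensorsFromTensorSequence -> shape/dtype info for each output.
//   2. TapeSetRecordForwardprop      -> each accumulator Accumulate()s; a
//                                       busy accumulator sets
//                                       max_gradient_tape_id.
//   3. TapeSetRecordBackprop         -> only tapes with
//                                       nesting_id < max_gradient_tape_id
//                                       call RecordOperation().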
2616 } // namespace
2617
2618 PyObject* TFE_Py_TapeSetRecordOperation(PyObject* op_type,
2619 PyObject* output_tensors,
2620 PyObject* input_tensors,
2621 PyObject* backward_function,
2622 PyObject* forward_function) {
2623 if (!HasAccumulatorOrTape() || *ThreadTapeIsStopped()) {
2624 Py_RETURN_NONE;
2625 }
2626 std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors);
2627 if (PyErr_Occurred()) return nullptr;
2628
2629 std::vector<tensorflow::DataType> input_dtypes =
2630 MakeTensorDtypeList(input_tensors);
2631 if (PyErr_Occurred()) return nullptr;
2632
2633 std::function<PyBackwardFunction*()> backward_function_getter(
2634 [backward_function]() {
2635 Py_INCREF(backward_function);
2636 PyBackwardFunction* function = new PyBackwardFunction(
2637 [backward_function](PyObject* out_grads,
2638 const std::vector<int64_t>& unused) {
2639 return PyObject_CallObject(backward_function, out_grads);
2640 });
2641 return function;
2642 });
2643 std::function<void(PyBackwardFunction*)> backward_function_killer(
2644 [backward_function](PyBackwardFunction* py_backward_function) {
2645 Py_DECREF(backward_function);
2646 delete py_backward_function;
2647 });
2648
2649 if (forward_function == Py_None) {
2650 if (!TapeSetRecordOperation(
2651 op_type, input_tensors, output_tensors, input_ids, input_dtypes,
2652 backward_function_getter, backward_function_killer,
2653 nullptr /* No special-cased forward function */)) {
2654 return nullptr;
2655 }
2656 } else {
2657 tensorflow::eager::ForwardFunction<PyObject> wrapped_forward_function(
2658 [forward_function](const std::vector<PyObject*>& input_tangents,
2659 std::vector<PyObject*>* output_tangents,
2660 bool use_batch = false) {
2661 return CallOpSpecificJVPFunction(forward_function, input_tangents,
2662 output_tangents);
2663 });
2664 if (!TapeSetRecordOperation(
2665 op_type, input_tensors, output_tensors, input_ids, input_dtypes,
2666 backward_function_getter, backward_function_killer,
2667 &wrapped_forward_function)) {
2668 return nullptr;
2669 }
2670 }
2671 Py_RETURN_NONE;
2672 }
2673
2674 PyObject* TFE_Py_TapeSetRecordOperationForwardprop(
2675 PyObject* op_type, PyObject* output_tensors, PyObject* input_tensors,
2676 PyObject* backward_function, PyObject* forwardprop_output_indices) {
2677 if (!HasAccumulator() || *ThreadTapeIsStopped()) {
2678 Py_RETURN_NONE;
2679 }
2680 std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors);
2681 if (PyErr_Occurred()) return nullptr;
2682
2683 std::vector<tensorflow::DataType> input_dtypes =
2684 MakeTensorDtypeList(input_tensors);
2685 if (PyErr_Occurred()) return nullptr;
2686
2687 std::function<PyBackwardFunction*()> backward_function_getter(
2688 [backward_function]() {
2689 Py_INCREF(backward_function);
2690 PyBackwardFunction* function = new PyBackwardFunction(
2691 [backward_function](PyObject* out_grads,
2692 const std::vector<int64_t>& unused) {
2693 return PyObject_CallObject(backward_function, out_grads);
2694 });
2695 return function;
2696 });
2697 std::function<void(PyBackwardFunction*)> backward_function_killer(
2698 [backward_function](PyBackwardFunction* py_backward_function) {
2699 Py_DECREF(backward_function);
2700 delete py_backward_function;
2701 });
2702 std::vector<PyTapeTensor> output_info;
2703 tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2704 output_tensors, "expected a sequence of integer tensor ids"));
2705 if (PyErr_Occurred() ||
2706 !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2707 return nullptr;
2708 }
2709 string op_type_str;
2710 if (!ParseOpTypeString(op_type, &op_type_str)) {
2711 return nullptr;
2712 }
2713 tensorflow::uint64 max_gradient_tape_id;
2714 if (!TapeSetRecordForwardprop(
2715 op_type_str, output_seq.get(), output_info, input_tensors, input_ids,
2716 input_dtypes, backward_function_getter, backward_function_killer,
2717 nullptr /* no special-cased forward function */,
2718 forwardprop_output_indices, &max_gradient_tape_id)) {
2719 return nullptr;
2720 }
2721 Py_RETURN_NONE;
2722 }
2723
2724 PyObject* TFE_Py_TapeSetRecordOperationBackprop(PyObject* op_type,
2725 PyObject* output_tensors,
2726 PyObject* input_tensors,
2727 PyObject* backward_function) {
2728 if (!CouldBackprop()) {
2729 Py_RETURN_NONE;
2730 }
2731 std::vector<int64_t> input_ids = MakeTensorIDList(input_tensors);
2732 if (PyErr_Occurred()) return nullptr;
2733
2734 std::vector<tensorflow::DataType> input_dtypes =
2735 MakeTensorDtypeList(input_tensors);
2736 if (PyErr_Occurred()) return nullptr;
2737
2738 std::function<PyBackwardFunction*()> backward_function_getter(
2739 [backward_function]() {
2740 Py_INCREF(backward_function);
2741 PyBackwardFunction* function = new PyBackwardFunction(
2742 [backward_function](PyObject* out_grads,
2743 const std::vector<int64_t>& unused) {
2744 return PyObject_CallObject(backward_function, out_grads);
2745 });
2746 return function;
2747 });
2748 std::function<void(PyBackwardFunction*)> backward_function_killer(
2749 [backward_function](PyBackwardFunction* py_backward_function) {
2750 Py_DECREF(backward_function);
2751 delete py_backward_function;
2752 });
2753 std::vector<PyTapeTensor> output_info;
2754 tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2755 output_tensors, "expected a sequence of integer tensor ids"));
2756 if (PyErr_Occurred() ||
2757 !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2758 return nullptr;
2759 }
2760 string op_type_str;
2761 if (!ParseOpTypeString(op_type, &op_type_str)) {
2762 return nullptr;
2763 }
2764 TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes,
2765 backward_function_getter, backward_function_killer,
2766 // No filtering based on relative ordering with forward
2767 // accumulators.
2768 std::numeric_limits<tensorflow::uint64>::max());
2769 Py_RETURN_NONE;
2770 }
2771
2772 void TFE_Py_TapeSetDeleteTrace(int64_t tensor_id) {
2773 auto* tape_set = GetTapeSet();
2774 if (tape_set == nullptr) {
2775 // Current thread is being destructed, and the tape set has already
2776 // been cleared.
2777 return;
2778 }
2779 for (TFE_Py_Tape* tape : *tape_set) {
2780 tape->tape->DeleteTrace(tensor_id);
2781 }
2782 }
2783
2784 std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
2785 PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
2786 if (seq == nullptr) {
2787 return {};
2788 }
2789 int len = PySequence_Fast_GET_SIZE(seq);
2790 PyObject** seq_array = PySequence_Fast_ITEMS(seq);
2791 std::vector<PyObject*> list(seq_array, seq_array + len);
2792 Py_DECREF(seq);
2793 return list;
2794 }
2795
2796 PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
2797 PyObject* sources, PyObject* output_gradients,
2798 PyObject* sources_raw,
2799 PyObject* unconnected_gradients,
2800 TF_Status* status) {
2801 TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
2802 if (!tape_obj->tape->IsPersistent()) {
2803 auto* tape_set = GetTapeSet();
2804 if (tape_set->find(tape_obj) != tape_set->end()) {
2805 PyErr_SetString(PyExc_RuntimeError,
2806 "gradient() cannot be invoked within the "
2807 "GradientTape context (i.e., while operations are being "
2808 "recorded). Either move the call to gradient() to be "
2809 "outside the 'with tf.GradientTape' block, or "
2810 "use a persistent tape: "
2811 "'with tf.GradientTape(persistent=true)'");
2812 return nullptr;
2813 }
2814 }
2815
2816 std::vector<int64_t> target_vec = MakeTensorIDList(target);
2817 if (PyErr_Occurred()) {
2818 return nullptr;
2819 }
2820 std::vector<int64_t> sources_vec = MakeTensorIDList(sources);
2821 if (PyErr_Occurred()) {
2822 return nullptr;
2823 }
2824 tensorflow::gtl::FlatSet<int64_t> sources_set(sources_vec.begin(),
2825 sources_vec.end());
2826
2827 tensorflow::Safe_PyObjectPtr seq =
2828 tensorflow::make_safe(PySequence_Fast(target, "expected a sequence"));
2829 int len = PySequence_Fast_GET_SIZE(seq.get());
2830 PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
2831 std::unordered_map<int64_t, PyTapeTensor> source_tensors_that_are_targets;
2832 for (int i = 0; i < len; ++i) {
2833 int64_t target_id = target_vec[i];
2834 if (sources_set.find(target_id) != sources_set.end()) {
2835 auto tensor = seq_array[i];
2836 source_tensors_that_are_targets.insert(
2837 std::make_pair(target_id, TapeTensorFromTensor(tensor)));
2838 }
2839 if (PyErr_Occurred()) {
2840 return nullptr;
2841 }
2842 }
2843 if (PyErr_Occurred()) {
2844 return nullptr;
2845 }
2846
2847 std::vector<PyObject*> outgrad_vec;
2848 if (output_gradients != Py_None) {
2849 outgrad_vec = MakeTensorList(output_gradients);
2850 if (PyErr_Occurred()) {
2851 return nullptr;
2852 }
2853 for (PyObject* tensor : outgrad_vec) {
2854 // Calling the backward function will eat a reference to the tensors in
2855 // outgrad_vec, so we need to increase their reference count.
2856 Py_INCREF(tensor);
2857 }
2858 }
2859 std::vector<PyObject*> result(sources_vec.size());
2860 status->status = tape_obj->tape->ComputeGradient(
2861 *py_vspace, target_vec, sources_vec, source_tensors_that_are_targets,
2862 outgrad_vec, absl::MakeSpan(result));
2863 if (!status->status.ok()) {
2864 if (PyErr_Occurred()) {
2865 // Do not propagate the erroneous status as that would swallow the
2866 // exception which caused the problem.
2867 status->status = ::tensorflow::OkStatus();
2868 }
2869 return nullptr;
2870 }
2871
2872 bool unconnected_gradients_zero =
2873 strcmp(TFE_GetPythonString(unconnected_gradients), "zero") == 0;
2874 std::vector<PyObject*> sources_obj;
2875 if (unconnected_gradients_zero) {
2876 // Uses the "raw" sources here so it can properly make a zeros tensor even
2877 // if there are resource variables as sources.
2878 sources_obj = MakeTensorList(sources_raw);
2879 }
2880
2881 if (!result.empty()) {
2882 PyObject* py_result = PyList_New(result.size());
2883 tensorflow::gtl::FlatSet<PyObject*> seen_results(result.size());
2884 for (int i = 0; i < result.size(); ++i) {
2885 if (result[i] == nullptr) {
2886 if (unconnected_gradients_zero) {
2887 // generate a zeros tensor in the shape of sources[i]
2888 tensorflow::DataType dtype =
2889 tensorflow::PyTensor_DataType(sources_obj[i]);
2890 PyTapeTensor tensor =
2891 PyTapeTensor(sources_vec[i], dtype, sources_obj[i]);
2892 result[i] = tensor.ZerosLike();
2893 } else {
2894 Py_INCREF(Py_None);
2895 result[i] = Py_None;
2896 }
2897 } else if (seen_results.find(result[i]) != seen_results.end()) {
2898 Py_INCREF(result[i]);
2899 }
2900 seen_results.insert(result[i]);
2901 PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]));
2902 }
2903 return py_result;
2904 }
2905 return PyList_New(0);
2906 }
2907
2908 PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch) {
2909 TFE_Py_ForwardAccumulator_Type.tp_new = PyType_GenericNew;
2910 if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr;
2911 TFE_Py_ForwardAccumulator* accumulator =
2912 PyObject_NEW(TFE_Py_ForwardAccumulator, &TFE_Py_ForwardAccumulator_Type);
2913 if (py_vspace == nullptr) {
2914 MaybeRaiseExceptionFromStatus(
2915 tensorflow::errors::Internal(
2916 "ForwardAccumulator requires a PyVSpace to be registered."),
2917 nullptr);
2918 }
2919 accumulator->accumulator = new ForwardAccumulator(*py_vspace, use_batch);
2920 return reinterpret_cast<PyObject*>(accumulator);
2921 }
2922
2923 PyObject* TFE_Py_ForwardAccumulatorSetAdd(PyObject* accumulator) {
2924 TFE_Py_ForwardAccumulator* c_accumulator(
2925 reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
2926 c_accumulator->nesting_id = tape_nesting_id_counter.fetch_add(1);
2927 if (GetAccumulatorSet()->insert(c_accumulator)) {
2928 Py_INCREF(accumulator);
2929 Py_RETURN_NONE;
2930 } else {
2931 MaybeRaiseExceptionFromStatus(
2932 tensorflow::errors::Internal(
2933 "A ForwardAccumulator was added to the active set twice."),
2934 nullptr);
2935 return nullptr;
2936 }
2937 }
2938
2939 void TFE_Py_ForwardAccumulatorSetRemove(PyObject* accumulator) {
2940 auto* accumulator_set = GetAccumulatorSet();
2941 if (accumulator_set != nullptr) {
2942 accumulator_set->erase(
2943 reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
2944 }
2945 Py_DECREF(accumulator);
2946 }
2947
2948 void TFE_Py_ForwardAccumulatorWatch(PyObject* accumulator, PyObject* tensor,
2949 PyObject* tangent) {
2950 int64_t tensor_id = FastTensorId(tensor);
2951 reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
2952 ->accumulator->Watch(tensor_id, tangent);
2953 RegisterForwardAccumulatorCleanup(tensor, tensor_id);
2954 }
2955
2956 // Returns a new reference to the JVP Tensor.
2957 PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator,
2958 PyObject* tensor) {
2959 PyObject* jvp = reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
2960 ->accumulator->FetchJVP(FastTensorId(tensor));
2961 if (jvp == nullptr) {
2962 jvp = Py_None;
2963 }
2964 Py_INCREF(jvp);
2965 return jvp;
2966 }
2967
2968 PyObject* TFE_Py_PackJVPs(PyObject* tensors) {
2969 if (!TapeCouldPossiblyRecord(tensors)) {
2970 tensorflow::Safe_PyObjectPtr empty_tuple(PyTuple_New(0));
2971 tensorflow::Safe_PyObjectPtr empty_list(PyList_New(0));
2972 return PyTuple_Pack(2, empty_tuple.get(), empty_list.get());
2973 }
2974 auto& accumulators = *GetAccumulatorSet();
2975 tensorflow::Safe_PyObjectPtr tensors_fast(
2976 PySequence_Fast(tensors, "Expected a sequence of input Tensors."));
2977 if (tensors_fast == nullptr || PyErr_Occurred()) {
2978 return nullptr;
2979 }
2980 std::vector<int64_t> augmented_input_ids;
2981 int len = PySequence_Fast_GET_SIZE(tensors_fast.get());
2982 PyObject** tensors_fast_array = PySequence_Fast_ITEMS(tensors_fast.get());
2983 for (Py_ssize_t position = 0; position < len; ++position) {
2984 PyObject* input = tensors_fast_array[position];
2985 if (input == Py_None) {
2986 continue;
2987 }
2988 tensorflow::DataType input_dtype(tensorflow::PyTensor_DataType(input));
2989 if (input_dtype == tensorflow::DT_INVALID) {
2990 return nullptr;
2991 }
2992 augmented_input_ids.push_back(FastTensorId(input));
2993 }
2994 if (PyErr_Occurred()) {
2995 return nullptr;
2996 }
2997 // Find the innermost accumulator such that all outer accumulators are
2998 // recording. Any more deeply nested accumulators will not have their JVPs
2999 // saved.
3000 AccumulatorSet::const_iterator innermost_all_recording = accumulators.begin();
3001 for (; innermost_all_recording != accumulators.end();
3002 ++innermost_all_recording) {
3003 if ((*innermost_all_recording)->accumulator->BusyAccumulating()) {
3004 break;
3005 }
3006 }
3007 AccumulatorSet::const_reverse_iterator reverse_innermost_all_recording(
3008 innermost_all_recording);
3009
3010 bool saving_jvps = false;
3011 tensorflow::Safe_PyObjectPtr all_indices(PyTuple_New(accumulators.size()));
3012 std::vector<PyObject*> new_tensors;
3013 Py_ssize_t accumulator_index = 0;
3014 // Start with the innermost accumulators to give outer accumulators a chance
3015 // to find their higher-order JVPs.
3016 for (AccumulatorSet::const_reverse_iterator it = accumulators.rbegin();
3017 it != accumulators.rend(); ++it, ++accumulator_index) {
3018 std::vector<int64_t> new_input_ids;
3019 std::vector<std::pair<int64_t, int64_t>> accumulator_indices;
3020 if (it == reverse_innermost_all_recording) {
3021 saving_jvps = true;
3022 }
3023 if (saving_jvps) {
3024 for (int input_index = 0; input_index < augmented_input_ids.size();
3025 ++input_index) {
3026 int64_t existing_input = augmented_input_ids[input_index];
3027 PyObject* jvp = (*it)->accumulator->FetchJVP(existing_input);
3028 if (jvp != nullptr) {
3029 new_tensors.push_back(jvp);
3030 new_input_ids.push_back(FastTensorId(jvp));
3031 accumulator_indices.emplace_back(
3032 input_index,
3033 augmented_input_ids.size() + new_input_ids.size() - 1);
3034 }
3035 }
3036 }
3037 tensorflow::Safe_PyObjectPtr accumulator_indices_py(
3038 PyTuple_New(accumulator_indices.size()));
3039 for (int i = 0; i < accumulator_indices.size(); ++i) {
3040 tensorflow::Safe_PyObjectPtr from_index(
3041 GetPythonObjectFromInt(accumulator_indices[i].first));
3042 tensorflow::Safe_PyObjectPtr to_index(
3043 GetPythonObjectFromInt(accumulator_indices[i].second));
3044 PyTuple_SetItem(accumulator_indices_py.get(), i,
3045 PyTuple_Pack(2, from_index.get(), to_index.get()));
3046 }
3047 PyTuple_SetItem(all_indices.get(), accumulator_index,
3048 accumulator_indices_py.release());
3049 augmented_input_ids.insert(augmented_input_ids.end(), new_input_ids.begin(),
3050 new_input_ids.end());
3051 }
3052
3053 tensorflow::Safe_PyObjectPtr new_tensors_py(PyList_New(new_tensors.size()));
3054 for (int i = 0; i < new_tensors.size(); ++i) {
3055 PyObject* jvp = new_tensors[i];
3056 Py_INCREF(jvp);
3057 PyList_SET_ITEM(new_tensors_py.get(), i, jvp);
3058 }
3059 return PyTuple_Pack(2, all_indices.get(), new_tensors_py.get());
3060 }
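
// Illustrative sketch of TFE_Py_PackJVPs' return value (a hedged example, not
// taken from a real trace): with a single recording accumulator and inputs
// (x, y) where only x has a registered JVP dx, the result is
//
//   ( ((0, 2),),  # per accumulator: (input_index, jvp_index) pairs
//     [dx] )      # new tensors, appended after the original inputs
//
// i.e. the JVP of input 0 lives at flattened index 2 == len(inputs) + 0.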
3061
3062 namespace {
3063
3064 // Indices for the "args" tuple that's passed to TFE_Py_FastPathExecute_C.
3065 enum FastPathExecuteArgIndex {
3066 FAST_PATH_EXECUTE_ARG_CONTEXT = 0,
3067 FAST_PATH_EXECUTE_ARG_OP_NAME = 1,
3068 FAST_PATH_EXECUTE_ARG_NAME = 2,
3069 FAST_PATH_EXECUTE_ARG_INPUT_START = 3
3070 };
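
// Illustrative sketch of the "args" tuple layout implied by the indices above
// and by the parsing in TFE_Py_FastPathExecute_C below (the op and attribute
// values are hypothetical):
//
//   (context, "MatMul", name,   # indices 0..2
//    a, b,                      # one inferred input per OpDef input_arg
//    "transpose_a", False,      # trailing non-inferred attrs as
//    "transpose_b", False)      # (name, value) pairs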
3071
3072 PyObject* GetPythonObjectFromString(tensorflow::StringPiece s) {
3073 #if PY_MAJOR_VERSION >= 3
3074 return PyUnicode_FromStringAndSize(s.data(), s.size());
3075 #else
3076 return PyBytes_FromStringAndSize(s.data(), s.size());
3077 #endif
3078 }
3079
3080 bool CheckResourceVariable(PyObject* item) {
3081 if (tensorflow::swig::IsResourceVariable(item)) {
3082 tensorflow::Safe_PyObjectPtr handle(
3083 PyObject_GetAttrString(item, "_handle"));
3084 return EagerTensor_CheckExact(handle.get());
3085 }
3086
3087 return false;
3088 }
3089
3090 bool IsNumberType(PyObject* item) {
3091 #if PY_MAJOR_VERSION >= 3
3092 return PyFloat_Check(item) || PyLong_Check(item);
3093 #else
3094 return PyFloat_Check(item) || PyInt_Check(item) || PyLong_Check(item);
3095 #endif
3096 }
3097
3098 bool CheckOneInput(PyObject* item) {
3099 if (EagerTensor_CheckExact(item) || CheckResourceVariable(item) ||
3100 PyArray_Check(item) || IsNumberType(item)) {
3101 return true;
3102 }
3103
3104 // Sequences are not properly handled. Sequences with purely python numeric
3105 // types work, but sequences with mixes of EagerTensors and python numeric
3106 // types don't work.
3107 // TODO(nareshmodi): fix
3108 return false;
3109 }
3110
3111 bool CheckInputsOk(PyObject* seq, int start_index,
3112 const tensorflow::OpDef& op_def) {
3113 for (int i = 0; i < op_def.input_arg_size(); i++) {
3114 PyObject* item = PyTuple_GET_ITEM(seq, i + start_index);
3115 if (!op_def.input_arg(i).number_attr().empty() ||
3116 !op_def.input_arg(i).type_list_attr().empty()) {
3117 // This item should be a seq input.
3118 if (!PySequence_Check(item)) {
3119 VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
3120 << "\", Input \"" << op_def.input_arg(i).name()
3121 << "\" since we expected a sequence, but got "
3122 << item->ob_type->tp_name;
3123 return false;
3124 }
3125 tensorflow::Safe_PyObjectPtr fast_item(
3126 PySequence_Fast(item, "Could not parse sequence."));
3127 if (fast_item.get() == nullptr) {
3128 return false;
3129 }
3130 int len = PySequence_Fast_GET_SIZE(fast_item.get());
3131 PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get());
3132 for (Py_ssize_t j = 0; j < len; j++) {
3133 PyObject* inner_item = fast_item_array[j];
3134 if (!CheckOneInput(inner_item)) {
3135 VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
3136 << "\", Input \"" << op_def.input_arg(i).name()
3137 << "\", Index " << j
3138 << " since we expected an EagerTensor/ResourceVariable, "
3139 "but got "
3140 << inner_item->ob_type->tp_name;
3141 return false;
3142 }
3143 }
3144 } else if (!CheckOneInput(item)) {
3145 VLOG(1)
3146 << "Falling back to slow path for Op \"" << op_def.name()
3147 << "\", Input \"" << op_def.input_arg(i).name()
3148 << "\" since we expected an EagerTensor/ResourceVariable, but got "
3149 << item->ob_type->tp_name;
3150 return false;
3151 }
3152 }
3153
3154 return true;
3155 }
3156
3157 tensorflow::DataType MaybeGetDType(PyObject* item) {
3158 if (EagerTensor_CheckExact(item) || CheckResourceVariable(item)) {
3159 return tensorflow::PyTensor_DataType(item);
3160 }
3161
3162 return tensorflow::DT_INVALID;
3163 }
3164
3165 tensorflow::DataType MaybeGetDTypeForAttr(const string& attr,
3166 FastPathOpExecInfo* op_exec_info) {
3167 auto cached_it = op_exec_info->cached_dtypes.find(attr);
3168 if (cached_it != op_exec_info->cached_dtypes.end()) {
3169 return cached_it->second;
3170 }
3171
3172 auto it = op_exec_info->attr_to_inputs_map->find(attr);
3173 if (it == op_exec_info->attr_to_inputs_map->end()) {
    // No inputs are associated with this attr - this should never happen.
3175 return tensorflow::DT_INVALID;
3176 }
3177
3178 for (const auto& input_info : it->second) {
3179 PyObject* item = PyTuple_GET_ITEM(
3180 op_exec_info->args, FAST_PATH_EXECUTE_ARG_INPUT_START + input_info.i);
3181 if (input_info.is_list) {
3182 tensorflow::Safe_PyObjectPtr fast_item(
3183 PySequence_Fast(item, "Unable to allocate"));
3184 int len = PySequence_Fast_GET_SIZE(fast_item.get());
3185 PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get());
3186 for (int i = 0; i < len; i++) {
3187 auto dtype = MaybeGetDType(fast_item_array[i]);
3188 if (dtype != tensorflow::DT_INVALID) return dtype;
3189 }
3190 } else {
3191 auto dtype = MaybeGetDType(item);
3192 if (dtype != tensorflow::DT_INVALID) return dtype;
3193 }
3194 }
3195
3196 auto default_it = op_exec_info->default_dtypes->find(attr);
3197 if (default_it != op_exec_info->default_dtypes->end()) {
3198 return default_it->second;
3199 }
3200
3201 return tensorflow::DT_INVALID;
3202 }
3203
3204 PyObject* CopySequenceSettingIndicesToNull(
3205 PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) {
3206 tensorflow::Safe_PyObjectPtr fast_seq(
3207 PySequence_Fast(seq, "unable to allocate"));
3208 int len = PySequence_Fast_GET_SIZE(fast_seq.get());
3209 PyObject** fast_seq_array = PySequence_Fast_ITEMS(fast_seq.get());
3210 PyObject* result = PyTuple_New(len);
3211 for (int i = 0; i < len; i++) {
3212 PyObject* item;
3213 if (indices.find(i) != indices.end()) {
3214 item = Py_None;
3215 } else {
3216 item = fast_seq_array[i];
3217 }
3218 Py_INCREF(item);
3219 PyTuple_SET_ITEM(result, i, item);
3220 }
3221 return result;
3222 }
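
// For example, CopySequenceSettingIndicesToNull((a, b, c), {1}) returns the
// new tuple (a, None, c); the caller owns the returned reference.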
3223
3224 PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
3225 PyObject* results,
3226 PyObject* forward_pass_name_scope = nullptr) {
3227 std::vector<int64_t> input_ids = MakeTensorIDList(inputs);
3228 if (PyErr_Occurred()) return nullptr;
3229 std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
3230 if (PyErr_Occurred()) return nullptr;
3231
3232 bool should_record = false;
3233 for (TFE_Py_Tape* tape : SafeTapeSet()) {
3234 if (tape->tape->ShouldRecord(input_ids, input_dtypes)) {
3235 should_record = true;
3236 break;
3237 }
3238 }
3239 if (!should_record) {
3240 for (TFE_Py_ForwardAccumulator* accumulator : SafeAccumulatorSet()) {
3241 if (accumulator->accumulator->ShouldRecord(input_ids, input_dtypes)) {
3242 should_record = true;
3243 break;
3244 }
3245 }
3246 }
3247 if (!should_record) Py_RETURN_NONE;
3248
3249 string c_op_name = TFE_GetPythonString(op_name);
3250
3251 PyObject* op_outputs;
3252 bool op_outputs_tuple_created = false;
3253
3254 if (const auto unused_output_indices =
3255 OpGradientUnusedOutputIndices(c_op_name)) {
3256 if (unused_output_indices->empty()) {
3257 op_outputs = Py_None;
3258 } else {
3259 op_outputs_tuple_created = true;
3260 op_outputs =
3261 CopySequenceSettingIndicesToNull(results, *unused_output_indices);
3262 }
3263 } else {
3264 op_outputs = results;
3265 }
3266
3267 PyObject* op_inputs;
3268 bool op_inputs_tuple_created = false;
3269
3270 if (const auto unused_input_indices =
3271 OpGradientUnusedInputIndices(c_op_name)) {
3272 if (unused_input_indices->empty()) {
3273 op_inputs = Py_None;
3274 } else {
3275 op_inputs_tuple_created = true;
3276 op_inputs =
3277 CopySequenceSettingIndicesToNull(inputs, *unused_input_indices);
3278 }
3279 } else {
3280 op_inputs = inputs;
3281 }
3282
3283 tensorflow::eager::ForwardFunction<PyObject> py_forward_function(
3284 [op_name, attrs, inputs, results](
3285 const std::vector<PyObject*>& input_tangents,
3286 std::vector<PyObject*>* output_tangents, bool use_batch) {
3287 return CallJVPFunction(op_name, attrs, inputs, results, input_tangents,
3288 output_tangents, use_batch);
3289 });
3290 tensorflow::eager::ForwardFunction<PyObject>* forward_function;
3291 if (c_op_name == "While" || c_op_name == "StatelessWhile" ||
3292 c_op_name == "If" || c_op_name == "StatelessIf") {
3293 // Control flow contains non-hashable attributes. Handling them in Python is
3294 // a headache, so instead we'll stay as close to GradientTape's handling as
3295 // possible (a null forward function means the accumulator forwards to a
3296 // tape).
3297 //
3298 // This is safe to do since we'll only see control flow when graph building,
3299 // in which case we can rely on pruning.
3300 forward_function = nullptr;
3301 } else {
3302 forward_function = &py_forward_function;
3303 }
3304
3305 PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));
3306
3307 if (!forward_pass_name_scope) forward_pass_name_scope = Py_None;
3308
3309 TapeSetRecordOperation(
3310 op_name, inputs, results, input_ids, input_dtypes,
3311 [op_name, attrs, num_inputs, op_inputs, op_outputs,
3312 forward_pass_name_scope]() {
3313 Py_INCREF(op_name);
3314 Py_INCREF(attrs);
3315 Py_INCREF(num_inputs);
3316 Py_INCREF(op_inputs);
3317 Py_INCREF(op_outputs);
3318 Py_INCREF(forward_pass_name_scope);
3319 PyBackwardFunction* function = new PyBackwardFunction(
3320 [op_name, attrs, num_inputs, op_inputs, op_outputs,
3321 forward_pass_name_scope](
3322 PyObject* output_grads,
3323 const std::vector<int64_t>& unneeded_gradients) {
3324 if (PyErr_Occurred()) {
3325 return static_cast<PyObject*>(nullptr);
3326 }
3327 tensorflow::Safe_PyObjectPtr skip_input_indices;
3328 if (!unneeded_gradients.empty()) {
3329 skip_input_indices.reset(
3330 PyTuple_New(unneeded_gradients.size()));
3331 for (int i = 0; i < unneeded_gradients.size(); i++) {
3332 PyTuple_SET_ITEM(
3333 skip_input_indices.get(), i,
3334 GetPythonObjectFromInt(unneeded_gradients[i]));
3335 }
3336 } else {
3337 Py_INCREF(Py_None);
3338 skip_input_indices.reset(Py_None);
3339 }
3340 tensorflow::Safe_PyObjectPtr callback_args(Py_BuildValue(
3341 "OOOOOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs,
3342 output_grads, skip_input_indices.get(),
3343 forward_pass_name_scope));
3344
3345 tensorflow::Safe_PyObjectPtr result(
3346 PyObject_CallObject(gradient_function, callback_args.get()));
3347
3348 if (PyErr_Occurred()) return static_cast<PyObject*>(nullptr);
3349
3350 return tensorflow::swig::Flatten(result.get());
3351 });
3352 return function;
3353 },
3354 [op_name, attrs, num_inputs, op_inputs, op_outputs,
3355 forward_pass_name_scope](PyBackwardFunction* backward_function) {
3356 Py_DECREF(op_name);
3357 Py_DECREF(attrs);
3358 Py_DECREF(num_inputs);
3359 Py_DECREF(op_inputs);
3360 Py_DECREF(op_outputs);
3361 Py_DECREF(forward_pass_name_scope);
3362
3363 delete backward_function;
3364 },
3365 forward_function);
3366
3367 Py_DECREF(num_inputs);
3368 if (op_outputs_tuple_created) Py_DECREF(op_outputs);
3369 if (op_inputs_tuple_created) Py_DECREF(op_inputs);
3370
3371 if (PyErr_Occurred()) {
3372 return nullptr;
3373 }
3374
3375 Py_RETURN_NONE;
3376 }
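
// Note on the backward function built in RecordGradient above (this just
// restates the Py_BuildValue call, it is not new behavior): gradient_function
// is invoked with the positional arguments
//
//   (op_name, attrs, num_inputs, op_inputs, op_outputs,
//    output_grads, skip_input_indices, forward_pass_name_scope)
//
// and its flattened return value provides the input gradients.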
3377
3378 void MaybeNotifyVariableAccessed(PyObject* input) {
3379 DCHECK(CheckResourceVariable(input));
3380 DCHECK(PyObject_HasAttrString(input, "_trainable"));
3381
3382 tensorflow::Safe_PyObjectPtr trainable(
3383 PyObject_GetAttrString(input, "_trainable"));
3384 if (trainable.get() == Py_False) return;
3385 TFE_Py_TapeVariableAccessed(input);
3386 TFE_Py_VariableWatcherVariableAccessed(input);
3387 }
3388
3389 bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
3390 PyObject* input, tensorflow::Safe_PyObjectPtr* output,
3391 TF_Status* status) {
3392 MaybeNotifyVariableAccessed(input);
3393
3394 TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status);
3395 auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); });
3396 if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr))
3397 return false;
3398
3399 TFE_OpSetDevice(op, parent_op_exec_info.device_name, status);
3400 if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr))
3401 return false;
3402
3403 // Set dtype
3404 DCHECK(PyObject_HasAttrString(input, "_dtype"));
3405 tensorflow::Safe_PyObjectPtr dtype(PyObject_GetAttrString(input, "_dtype"));
3406 int value;
3407 if (!ParseTypeValue("_dtype", dtype.get(), status, &value)) {
3408 return false;
3409 }
3410 TFE_OpSetAttrType(op, "dtype", static_cast<TF_DataType>(value));
3411
3412 // Get handle
3413 tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(input, "_handle"));
3414 if (!EagerTensor_CheckExact(handle.get())) return false;
3415 TFE_OpAddInput(op, EagerTensor_Handle(handle.get()), status);
3416 if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr))
3417 return false;
3418
3419 int num_retvals = 1;
3420 TFE_TensorHandle* output_handle;
3421 TFE_Execute(op, &output_handle, &num_retvals, status);
3422 if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr))
3423 return false;
3424
3425 // Always create the py object (and correctly DECREF it) from the returned
3426 // value, else the data will leak.
3427 output->reset(EagerTensorFromHandle(output_handle));
3428
3429 // TODO(nareshmodi): Should we run post exec callbacks here?
3430 if (parent_op_exec_info.run_gradient_callback) {
3431 tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(1));
3432 PyTuple_SET_ITEM(inputs.get(), 0, handle.release());
3433
3434 tensorflow::Safe_PyObjectPtr outputs(PyTuple_New(1));
    Py_INCREF(output->get());  // PyTuple_SET_ITEM steals a reference below.
3436 PyTuple_SET_ITEM(outputs.get(), 0, output->get());
3437
3438 tensorflow::Safe_PyObjectPtr op_string(
3439 GetPythonObjectFromString("ReadVariableOp"));
3440 if (!RecordGradient(op_string.get(), inputs.get(), Py_None,
3441 outputs.get())) {
3442 return false;
3443 }
3444 }
3445
3446 return true;
3447 }
3448
// Supports 3 cases at the moment:
//  i) input is an EagerTensor.
//  ii) input is a ResourceVariable - in this case the variable's value is
//  read by executing a ReadVariableOp on its handle.
//  iii) input is an arbitrary python list/tuple (note, this handling doesn't
//  support packing).
//
//  NOTE: dtype_hint_getter is only consulted for case iii); it should return
//  DT_INVALID when no hint is available, in which case the dtype is inferred
//  from the python value itself.
//
//  NOTE: This function sets a python error directly, and returns false.
//  TF_Status is only passed since we don't want to have to reallocate it.
3462 bool ConvertToTensor(
3463 const FastPathOpExecInfo& op_exec_info, PyObject* input,
3464 tensorflow::Safe_PyObjectPtr* output_handle,
3465 // This gets a hint for this particular input.
3466 const std::function<tensorflow::DataType()>& dtype_hint_getter,
3467 // This sets the dtype after conversion is complete.
3468 const std::function<void(const tensorflow::DataType dtype)>& dtype_setter,
3469 TF_Status* status) {
3470 if (EagerTensor_CheckExact(input)) {
3471 Py_INCREF(input);
3472 output_handle->reset(input);
3473 return true;
3474 } else if (CheckResourceVariable(input)) {
3475 return ReadVariableOp(op_exec_info, input, output_handle, status);
3476 }
3477
3478 // The hint comes from a supposedly similarly typed tensor.
3479 tensorflow::DataType dtype_hint = dtype_hint_getter();
3480
3481 TFE_TensorHandle* handle = tensorflow::ConvertToEagerTensor(
3482 op_exec_info.ctx, input, dtype_hint, op_exec_info.device_name);
3483 if (handle == nullptr) {
3484 return tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr);
3485 }
3486
3487 output_handle->reset(EagerTensorFromHandle(handle));
3488 dtype_setter(
3489 static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(handle)));
3490
3491 return true;
3492 }
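
// Illustrative sketch of the dtype-hint plumbing (hedged; the concrete op and
// values are hypothetical): when converting the python float 1.0 for a
// "T"-typed input of an op whose other "T" input is a float32 EagerTensor,
// dtype_hint_getter (see AddInputToOp below) returns DT_FLOAT, the literal is
// converted to a float32 EagerTensor, and dtype_setter caches DT_FLOAT for
// "T" so that later inputs sharing that attr are converted consistently.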
3493
3494 // Adds input and type attr to the op, and to the list of flattened
3495 // inputs/attrs.
3496 bool AddInputToOp(FastPathOpExecInfo* op_exec_info, PyObject* input,
3497 const bool add_type_attr,
3498 const tensorflow::OpDef::ArgDef& input_arg,
3499 std::vector<tensorflow::Safe_PyObjectPtr>* flattened_attrs,
3500 std::vector<tensorflow::Safe_PyObjectPtr>* flattened_inputs,
3501 TFE_Op* op, TF_Status* status) {
3502 // py_eager_tensor's ownership is transferred to flattened_inputs if it is
3503 // required, else the object is destroyed and DECREF'd when the object goes
3504 // out of scope in this function.
3505 tensorflow::Safe_PyObjectPtr py_eager_tensor = nullptr;
3506
3507 if (!ConvertToTensor(
3508 *op_exec_info, input, &py_eager_tensor,
3509 [&]() {
3510 if (input_arg.type() != tensorflow::DataType::DT_INVALID) {
3511 return input_arg.type();
3512 }
3513 return MaybeGetDTypeForAttr(input_arg.type_attr(), op_exec_info);
3514 },
3515 [&](const tensorflow::DataType dtype) {
3516 op_exec_info->cached_dtypes[input_arg.type_attr()] = dtype;
3517 },
3518 status)) {
3519 return false;
3520 }
3521
3522 TFE_TensorHandle* input_handle = EagerTensor_Handle(py_eager_tensor.get());
3523
3524 if (add_type_attr && !input_arg.type_attr().empty()) {
3525 auto dtype = TFE_TensorHandleDataType(input_handle);
3526 TFE_OpSetAttrType(op, input_arg.type_attr().data(), dtype);
3527 if (flattened_attrs != nullptr) {
3528 flattened_attrs->emplace_back(
3529 GetPythonObjectFromString(input_arg.type_attr()));
3530 flattened_attrs->emplace_back(PyLong_FromLong(dtype));
3531 }
3532 }
3533
3534 if (flattened_inputs != nullptr) {
3535 flattened_inputs->emplace_back(std::move(py_eager_tensor));
3536 }
3537
3538 TFE_OpAddInput(op, input_handle, status);
3539 if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3540 return false;
3541 }
3542
3543 return true;
3544 }
3545
3546 const char* GetDeviceName(PyObject* py_device_name) {
3547 if (py_device_name != Py_None) {
3548 return TFE_GetPythonString(py_device_name);
3549 }
3550 return nullptr;
3551 }
3552
3553 bool RaiseIfNotPySequence(PyObject* seq, const string& attr_name) {
3554 if (!PySequence_Check(seq)) {
3555 PyErr_SetString(PyExc_TypeError,
3556 Printf("expected a sequence for attr %s, got %s instead",
3557 attr_name.data(), seq->ob_type->tp_name)
3558 .data());
3559
3560 return false;
3561 }
3562 if (PyArray_Check(seq) &&
3563 PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)) != 1) {
3564 PyErr_SetString(PyExc_ValueError,
3565 Printf("expected a sequence for attr %s, got an ndarray "
3566 "with rank %d instead",
3567 attr_name.data(),
3568 PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)))
3569 .data());
3570 return false;
3571 }
3572 return true;
3573 }
3574
3575 bool RunCallbacks(
3576 const FastPathOpExecInfo& op_exec_info, PyObject* args,
3577 int num_inferred_attrs,
3578 const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_inputs,
3579 const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_attrs,
3580 PyObject* flattened_result) {
3581 DCHECK(op_exec_info.run_callbacks);
3582
3583 tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(flattened_inputs.size()));
3584 for (int i = 0; i < flattened_inputs.size(); i++) {
3585 PyObject* input = flattened_inputs[i].get();
3586 Py_INCREF(input);
3587 PyTuple_SET_ITEM(inputs.get(), i, input);
3588 }
3589
3590 int num_non_inferred_attrs = PyTuple_GET_SIZE(args) - num_inferred_attrs;
3591 int num_attrs = flattened_attrs.size() + num_non_inferred_attrs;
3592 tensorflow::Safe_PyObjectPtr attrs(PyTuple_New(num_attrs));
3593
3594 for (int i = 0; i < num_non_inferred_attrs; i++) {
3595 auto* attr = PyTuple_GET_ITEM(args, num_inferred_attrs + i);
3596 Py_INCREF(attr);
3597 PyTuple_SET_ITEM(attrs.get(), i, attr);
3598 }
3599
3600 for (int i = num_non_inferred_attrs; i < num_attrs; i++) {
3601 PyObject* attr_or_name =
3602 flattened_attrs.at(i - num_non_inferred_attrs).get();
3603 Py_INCREF(attr_or_name);
3604 PyTuple_SET_ITEM(attrs.get(), i, attr_or_name);
3605 }
3606
3607 if (op_exec_info.run_gradient_callback) {
3608 if (!RecordGradient(op_exec_info.op_name, inputs.get(), attrs.get(),
3609 flattened_result)) {
3610 return false;
3611 }
3612 }
3613
3614 if (op_exec_info.run_post_exec_callbacks) {
3615 tensorflow::Safe_PyObjectPtr callback_args(
3616 Py_BuildValue("OOOOO", op_exec_info.op_name, inputs.get(), attrs.get(),
3617 flattened_result, op_exec_info.name));
3618 for (Py_ssize_t i = 0; i < PyList_Size(op_exec_info.callbacks); i++) {
3619 PyObject* callback_fn = PyList_GET_ITEM(op_exec_info.callbacks, i);
3620 if (!PyCallable_Check(callback_fn)) {
3621 PyErr_SetString(
3622 PyExc_TypeError,
3623 Printf("expected a function for "
3624 "post execution callback in index %ld, got %s instead",
3625 i, callback_fn->ob_type->tp_name)
3626 .c_str());
3627 return false;
3628 }
3629 PyObject* callback_result =
3630 PyObject_CallObject(callback_fn, callback_args.get());
3631 if (!callback_result) {
3632 return false;
3633 }
3634 Py_DECREF(callback_result);
3635 }
3636 }
3637
3638 return true;
3639 }
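
// Note on the post-execution callbacks invoked in RunCallbacks above (this
// mirrors the Py_BuildValue call, it is not new behavior): each callback is
// called as
//
//   callback(op_name, inputs, attrs, results, name)
//
// A NULL return propagates the pending python error; any other return value
// is discarded.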
3640
3641 } // namespace
3642
3643 PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
3644 tensorflow::profiler::TraceMe activity(
3645 "TFE_Py_FastPathExecute_C", tensorflow::profiler::TraceMeLevel::kInfo);
3646 Py_ssize_t args_size = PyTuple_GET_SIZE(args);
3647 if (args_size < FAST_PATH_EXECUTE_ARG_INPUT_START) {
3648 PyErr_SetString(
3649 PyExc_ValueError,
3650 Printf("There must be at least %d items in the input tuple.",
3651 FAST_PATH_EXECUTE_ARG_INPUT_START)
3652 .c_str());
3653 return nullptr;
3654 }
3655
3656 FastPathOpExecInfo op_exec_info;
3657
3658 PyObject* py_eager_context =
3659 PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_CONTEXT);
3660
3661 // TODO(edoper): Use interned string here
3662 PyObject* eager_context_handle =
3663 PyObject_GetAttrString(py_eager_context, "_context_handle");
3664
3665 TFE_Context* ctx = reinterpret_cast<TFE_Context*>(
3666 PyCapsule_GetPointer(eager_context_handle, nullptr));
3667 op_exec_info.ctx = ctx;
3668 op_exec_info.args = args;
3669
3670 if (ctx == nullptr) {
    // The context hasn't been initialized. It will be initialized in the
    // slow path.
3672 RaiseFallbackException(
3673 "This function does not handle the case of the path where "
3674 "all inputs are not already EagerTensors.");
3675 return nullptr;
3676 }
3677
3678 auto* tld = tensorflow::GetEagerContextThreadLocalData(py_eager_context);
3679 if (tld == nullptr) {
3680 return nullptr;
3681 }
3682 op_exec_info.device_name = GetDeviceName(tld->device_name.get());
3683 op_exec_info.callbacks = tld->op_callbacks.get();
3684
3685 op_exec_info.op_name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_OP_NAME);
3686 op_exec_info.name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_NAME);
3687
3688 // TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks
3689 // (similar to benchmark_tf_gradient_function_*). Also consider using an
3690 // InlinedVector for flattened_attrs and flattened_inputs if the benchmarks
3691 // point out problems with heap allocs.
3692 op_exec_info.run_gradient_callback =
3693 !*ThreadTapeIsStopped() && HasAccumulatorOrTape();
3694 op_exec_info.run_post_exec_callbacks =
3695 op_exec_info.callbacks != Py_None &&
3696 PyList_Size(op_exec_info.callbacks) > 0;
3697 op_exec_info.run_callbacks = op_exec_info.run_gradient_callback ||
3698 op_exec_info.run_post_exec_callbacks;
3699
3700 TF_Status* status = GetStatus();
3701 const char* op_name = TFE_GetPythonString(op_exec_info.op_name);
3702 if (op_name == nullptr) {
3703 PyErr_SetString(PyExc_TypeError,
3704 Printf("expected a string for op_name, got %s instead",
3705 op_exec_info.op_name->ob_type->tp_name)
3706 .c_str());
3707 return nullptr;
3708 }
3709
3710 TFE_Op* op = GetOp(ctx, op_name, op_exec_info.device_name, status);
3711
3712 auto cleaner = tensorflow::gtl::MakeCleanup([status, ctx, op] {
3713 ReturnStatus(status);
3714 ReturnOp(ctx, op);
3715 });
3716
3717 if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3718 return nullptr;
3719 }
3720
3721 tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace(
3722 tensorflow::StackTrace::kStackTraceInitialSize));
3723
3724 const tensorflow::OpDef* op_def = tensorflow::unwrap(op)->OpDef();
3725 if (op_def == nullptr) return nullptr;
3726
3727 if (args_size <
3728 FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size()) {
3729 PyErr_SetString(
3730 PyExc_ValueError,
3731 Printf("Tuple size smaller than intended. Expected to be at least %d, "
3732 "was %ld",
3733 FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
3734 args_size)
3735 .c_str());
3736 return nullptr;
3737 }
3738
3739 if (!CheckInputsOk(args, FAST_PATH_EXECUTE_ARG_INPUT_START, *op_def)) {
3740 RaiseFallbackException(
3741 "This function does not handle the case of the path where "
3742 "all inputs are not already EagerTensors.");
3743 return nullptr;
3744 }
3745
3746 op_exec_info.attr_to_inputs_map = GetAttrToInputsMapHoldingGIL(*op_def);
3747 op_exec_info.default_dtypes = GetAttrToDefaultsMapHoldingGIL(*op_def);
3748
3749 // Mapping of attr name to size - used to calculate the number of values
3750 // to be expected by the TFE_Execute run.
3751 tensorflow::gtl::FlatMap<string, int64_t> attr_list_sizes;
3752
3753 // Set non-inferred attrs, including setting defaults if the attr is passed in
3754 // as None.
3755 for (int i = FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size();
3756 i < args_size; i += 2) {
3757 PyObject* py_attr_name = PyTuple_GET_ITEM(args, i);
3758 const char* attr_name = TFE_GetPythonString(py_attr_name);
3759 PyObject* py_attr_value = PyTuple_GET_ITEM(args, i + 1);
3760
3761 // Not creating an index since most of the time there are not more than a
3762 // few attrs.
3763 // TODO(nareshmodi): Maybe include the index as part of the
3764 // OpRegistrationData.
3765 for (const auto& attr : op_def->attr()) {
3766 if (tensorflow::StringPiece(attr_name) == attr.name()) {
3767 SetOpAttrWithDefaults(ctx, op, attr, attr_name, py_attr_value,
3768 &attr_list_sizes, status);
3769
3770 if (!status->status.ok()) {
3771 VLOG(1) << "Falling back to slow path for Op \"" << op_def->name()
3772 << "\" since we are unable to set the value for attr \""
3773 << attr.name() << "\" due to: " << TF_Message(status);
3774 RaiseFallbackException(TF_Message(status));
3775 return nullptr;
3776 }
3777
3778 break;
3779 }
3780 }
3781 }
3782
3783 // Flat attrs and inputs as required by the record_gradient call. The attrs
3784 // here only contain inferred attrs (non-inferred attrs are added directly
3785 // from the input args).
3786 // All items in flattened_attrs and flattened_inputs contain
3787 // Safe_PyObjectPtr - any time something steals a reference to this, it must
3788 // INCREF.
3789 // TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work
3790 // directly.
3791 std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_attrs =
3792 nullptr;
3793 std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_inputs =
3794 nullptr;
3795
3796 // TODO(nareshmodi): Encapsulate callbacks information into a struct.
3797 if (op_exec_info.run_callbacks) {
3798 flattened_attrs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
3799 flattened_inputs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
3800 }
3801
3802 // Add inferred attrs and inputs.
3803 // The following code might set duplicate type attrs. This will result in
3804 // the CacheKey for the generated AttrBuilder possibly differing from
3805 // those where the type attrs are correctly set. Inconsistent CacheKeys
  // for ops mean that there might be unnecessarily duplicated kernels.
3807 // TODO(nareshmodi): Fix this.
3808 for (int i = 0; i < op_def->input_arg_size(); i++) {
3809 const auto& input_arg = op_def->input_arg(i);
3810
3811 PyObject* input =
3812 PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_INPUT_START + i);
3813 if (!input_arg.number_attr().empty()) {
3814 // The item is a homogeneous list.
3815 if (!RaiseIfNotPySequence(input, input_arg.number_attr())) return nullptr;
3816 tensorflow::Safe_PyObjectPtr fast_input(
3817 PySequence_Fast(input, "Could not parse sequence."));
3818 if (fast_input.get() == nullptr) {
3819 return nullptr;
3820 }
3821 Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
3822 PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get());
3823
3824 TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len);
3825 if (op_exec_info.run_callbacks) {
3826 flattened_attrs->emplace_back(
3827 GetPythonObjectFromString(input_arg.number_attr()));
3828 flattened_attrs->emplace_back(PyLong_FromLong(len));
3829 }
3830 attr_list_sizes[input_arg.number_attr()] = len;
3831
3832 if (len > 0) {
3833 // First item adds the type attr.
3834 if (!AddInputToOp(&op_exec_info, fast_input_array[0], true, input_arg,
3835 flattened_attrs.get(), flattened_inputs.get(), op,
3836 status)) {
3837 return nullptr;
3838 }
3839
3840 for (Py_ssize_t j = 1; j < len; j++) {
3841 // Since the list is homogeneous, we don't need to re-add the attr.
3842 if (!AddInputToOp(&op_exec_info, fast_input_array[j], false,
3843 input_arg, nullptr /* flattened_attrs */,
3844 flattened_inputs.get(), op, status)) {
3845 return nullptr;
3846 }
3847 }
3848 }
3849 } else if (!input_arg.type_list_attr().empty()) {
3850 // The item is a heterogeneous list.
3851 if (!RaiseIfNotPySequence(input, input_arg.type_list_attr())) {
3852 return nullptr;
3853 }
3854 tensorflow::Safe_PyObjectPtr fast_input(
3855 PySequence_Fast(input, "Could not parse sequence."));
3856 if (fast_input.get() == nullptr) {
3857 return nullptr;
3858 }
3859 const string& attr_name = input_arg.type_list_attr();
3860 Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
3861 PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get());
3862 tensorflow::gtl::InlinedVector<TF_DataType, 4> attr_value(len);
3863 PyObject* py_attr_value = nullptr;
3864 if (op_exec_info.run_callbacks) {
3865 py_attr_value = PyTuple_New(len);
3866 }
3867 for (Py_ssize_t j = 0; j < len; j++) {
3868 PyObject* py_input = fast_input_array[j];
3869 tensorflow::Safe_PyObjectPtr py_eager_tensor;
3870 if (!ConvertToTensor(
3871 op_exec_info, py_input, &py_eager_tensor,
3872 []() { return tensorflow::DT_INVALID; },
3873 [](const tensorflow::DataType dtype) {}, status)) {
3874 return nullptr;
3875 }
3876
3877 TFE_TensorHandle* input_handle =
3878 EagerTensor_Handle(py_eager_tensor.get());
3879
3880 attr_value[j] = TFE_TensorHandleDataType(input_handle);
3881
3882 TFE_OpAddInput(op, input_handle, status);
3883 if (tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3884 return nullptr;
3885 }
3886
3887 if (op_exec_info.run_callbacks) {
3888 flattened_inputs->emplace_back(std::move(py_eager_tensor));
3889
3890 PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j]));
3891 }
3892 }
3893 if (op_exec_info.run_callbacks) {
3894 flattened_attrs->emplace_back(GetPythonObjectFromString(attr_name));
3895 flattened_attrs->emplace_back(py_attr_value);
3896 }
3897 TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(),
3898 attr_value.size());
3899 attr_list_sizes[attr_name] = len;
3900 } else {
3901 // The item is a single item.
3902 if (!AddInputToOp(&op_exec_info, input, true, input_arg,
3903 flattened_attrs.get(), flattened_inputs.get(), op,
3904 status)) {
3905 return nullptr;
3906 }
3907 }
3908 }
3909
3910 int64_t num_outputs = 0;
3911 for (int i = 0; i < op_def->output_arg_size(); i++) {
3912 const auto& output_arg = op_def->output_arg(i);
3913 int64_t delta = 1;
3914 if (!output_arg.number_attr().empty()) {
3915 delta = attr_list_sizes[output_arg.number_attr()];
3916 } else if (!output_arg.type_list_attr().empty()) {
3917 delta = attr_list_sizes[output_arg.type_list_attr()];
3918 }
3919 if (delta < 0) {
3920 RaiseFallbackException(
3921 "Attributes suggest that the size of an output list is less than 0");
3922 return nullptr;
3923 }
3924 num_outputs += delta;
3925 }
3926
  // If the number of retvals does not fit in an int32, we error out.
3928 if (static_cast<int64_t>(static_cast<int32_t>(num_outputs)) != num_outputs) {
3929 PyErr_SetString(
3930 PyExc_ValueError,
3931 Printf("Number of outputs is too big: %ld", num_outputs).c_str());
3932 return nullptr;
3933 }
3934 int num_retvals = num_outputs;
3935
3936 tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);
3937
3938 Py_BEGIN_ALLOW_THREADS;
3939 TFE_Execute(op, retvals.data(), &num_retvals, status);
3940 Py_END_ALLOW_THREADS;
3941
3942 if (!status->status.ok()) {
3943 // Augment the status with the op_name for easier debugging similar to
3944 // TFE_Py_Execute.
3945 status->status = tensorflow::errors::CreateWithUpdatedMessage(
3946 status->status, tensorflow::strings::StrCat(
3947 TF_Message(status), " [Op:",
3948 TFE_GetPythonString(op_exec_info.op_name), "]"));
3949 tensorflow::MaybeRaiseExceptionFromTFStatus(status, nullptr);
3950 return nullptr;
3951 }
3952
3953 tensorflow::Safe_PyObjectPtr flat_result(PyList_New(num_retvals));
3954 for (int i = 0; i < num_retvals; ++i) {
3955 PyList_SET_ITEM(flat_result.get(), i, EagerTensorFromHandle(retvals[i]));
3956 }
3957
3958 if (op_exec_info.run_callbacks) {
3959 if (!RunCallbacks(
3960 op_exec_info, args,
3961 FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
3962 *flattened_inputs, *flattened_attrs, flat_result.get())) {
3963 return nullptr;
3964 }
3965 }
3966
3967 // Unflatten results.
3968 if (op_def->output_arg_size() == 0) {
3969 Py_RETURN_NONE;
3970 }
3971
3972 if (op_def->output_arg_size() == 1) {
3973 if (!op_def->output_arg(0).number_attr().empty() ||
3974 !op_def->output_arg(0).type_list_attr().empty()) {
3975 return flat_result.release();
3976 } else {
3977 auto* result = PyList_GET_ITEM(flat_result.get(), 0);
3978 Py_INCREF(result);
3979 return result;
3980 }
3981 }
3982
  // Unflatten the results into one list entry per output arg; the Python
  // layer converts this into a namedtuple.
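  // Illustrative sketch (hypothetical op): for output args (values: N * T,
  // idx: int32) with N == 2, the flat results [v0, v1, i0] become
  // [[v0, v1], i0].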
3984 PyObject* result = PyList_New(op_def->output_arg_size());
3985 int flat_result_index = 0;
3986 for (int i = 0; i < op_def->output_arg_size(); i++) {
3987 if (!op_def->output_arg(i).number_attr().empty()) {
3988 int list_length = attr_list_sizes[op_def->output_arg(i).number_attr()];
3989 PyObject* inner_list = PyList_New(list_length);
3990 for (int j = 0; j < list_length; j++) {
3991 PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
3992 Py_INCREF(obj);
3993 PyList_SET_ITEM(inner_list, j, obj);
3994 }
3995 PyList_SET_ITEM(result, i, inner_list);
3996 } else if (!op_def->output_arg(i).type_list_attr().empty()) {
3997 int list_length = attr_list_sizes[op_def->output_arg(i).type_list_attr()];
3998 PyObject* inner_list = PyList_New(list_length);
3999 for (int j = 0; j < list_length; j++) {
4000 PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
4001 Py_INCREF(obj);
4002 PyList_SET_ITEM(inner_list, j, obj);
4003 }
4004 PyList_SET_ITEM(result, i, inner_list);
4005 } else {
4006 PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
4007 Py_INCREF(obj);
4008 PyList_SET_ITEM(result, i, obj);
4009 }
4010 }
4011 return result;
4012 }
4013
4014 PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
4015 PyObject* attrs, PyObject* results,
4016 PyObject* forward_pass_name_scope) {
4017 if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
4018 Py_RETURN_NONE;
4019 }
4020
4021 return RecordGradient(op_name, inputs, attrs, results,
4022 forward_pass_name_scope);
4023 }
4024
// Prints incoming messages directly to Python's stdout using Python's C API.
// This is necessary in Jupyter notebooks and Colab, where messages written to
// the C stdout don't go to the notebook cell outputs, but calls to Python's
// stdout do.
4029 void PrintToPythonStdout(const char* msg) {
4030 if (Py_IsInitialized()) {
4031 PyGILState_STATE py_threadstate;
4032 py_threadstate = PyGILState_Ensure();
4033
4034 string string_msg = msg;
4035 // PySys_WriteStdout truncates strings over 1000 bytes, so
4036 // we write the message in chunks small enough to not be truncated.
4037 int CHUNK_SIZE = 900;
4038 auto len = string_msg.length();
4039 for (int i = 0; i < len; i += CHUNK_SIZE) {
4040 PySys_WriteStdout("%s", string_msg.substr(i, CHUNK_SIZE).c_str());
4041 }
4042
4043 // Force flushing to make sure print newlines aren't interleaved in
4044 // some colab environments
4045 PyRun_SimpleString("import sys; sys.stdout.flush()");
4046
4047 PyGILState_Release(py_threadstate);
4048 }
4049 }
4050
// Registers PrintToPythonStdout as a log listener so that log output shows up
// in Colab and Jupyter notebook cells.
4053 void TFE_Py_EnableInteractivePythonLogging() {
4054 static bool enabled_interactive_logging = false;
4055 if (!enabled_interactive_logging) {
4056 enabled_interactive_logging = true;
4057 TF_RegisterLogListener(PrintToPythonStdout);
4058 }
4059 }
4060
4061 namespace {
4062 // TODO(mdan): Clean this. Maybe by decoupling context lifetime from Python GC?
4063 // Weak reference to the Python Context (see tensorflow/python/eager/context.py)
4064 // object currently active. This object is opaque and wrapped inside a Python
4065 // Capsule. However, the EagerContext object it holds is tracked by the
4066 // global_c_eager_context object.
4067 // Also see common_runtime/eager/context.cc.
4068 PyObject* global_py_eager_context = nullptr;
4069 } // namespace
4070
4071 PyObject* TFE_Py_SetEagerContext(PyObject* py_context) {
4072 Py_XDECREF(global_py_eager_context);
4073 global_py_eager_context = PyWeakref_NewRef(py_context, nullptr);
4074 if (global_py_eager_context == nullptr) {
4075 return nullptr;
4076 }
4077 Py_RETURN_NONE;
4078 }
4079
4080 PyObject* GetPyEagerContext() {
4081 if (global_py_eager_context == nullptr) {
4082 PyErr_SetString(PyExc_RuntimeError, "Python eager context is not set");
4083 return nullptr;
4084 }
4085 PyObject* py_context = PyWeakref_GET_OBJECT(global_py_eager_context);
4086 if (py_context == Py_None) {
4087 PyErr_SetString(PyExc_RuntimeError,
4088 "Python eager context has been destroyed");
4089 return nullptr;
4090 }
4091 Py_INCREF(py_context);
4092 return py_context;
4093 }
4094
4095 namespace {
4096
4097 // Default values for thread_local_data fields.
4098 struct EagerContextThreadLocalDataDefaults {
4099 tensorflow::Safe_PyObjectPtr is_eager;
4100 tensorflow::Safe_PyObjectPtr device_spec;
4101 };
4102
4103 // Maps each py_eager_context object to its thread_local_data.
4104 //
4105 // Note: we need to use the python Context object as the key here (and not
4106 // its handle object), because the handle object isn't created until the
4107 // context is initialized; but thread_local_data is potentially accessed
4108 // before then.
4109 using EagerContextThreadLocalDataMap = absl::flat_hash_map<
4110 PyObject*, std::unique_ptr<tensorflow::EagerContextThreadLocalData>>;
4111 thread_local EagerContextThreadLocalDataMap*
4112 eager_context_thread_local_data_map = nullptr;
4113
4114 // Maps each py_eager_context object to default values.
4115 using EagerContextThreadLocalDataDefaultsMap =
4116 absl::flat_hash_map<PyObject*, EagerContextThreadLocalDataDefaults>;
4117 EagerContextThreadLocalDataDefaultsMap*
4118 eager_context_thread_local_data_defaults = nullptr;
4119
4120 } // namespace
4121
4122 namespace tensorflow {
4123
4124 void MakeEagerContextThreadLocalData(PyObject* py_eager_context,
4125 PyObject* is_eager,
4126 PyObject* device_spec) {
4127 DCheckPyGilState();
4128 if (eager_context_thread_local_data_defaults == nullptr) {
4129 absl::LeakCheckDisabler disabler;
4130 eager_context_thread_local_data_defaults =
4131 new EagerContextThreadLocalDataDefaultsMap();
4132 }
4133 if (eager_context_thread_local_data_defaults->count(py_eager_context) > 0) {
4134 PyErr_SetString(PyExc_AssertionError,
4135 "MakeEagerContextThreadLocalData may not be called "
4136 "twice on the same eager Context object.");
4137 }
4138
4139 auto& defaults =
4140 (*eager_context_thread_local_data_defaults)[py_eager_context];
4141 Py_INCREF(is_eager);
4142 defaults.is_eager.reset(is_eager);
4143 Py_INCREF(device_spec);
4144 defaults.device_spec.reset(device_spec);
4145 }
4146
4147 EagerContextThreadLocalData* GetEagerContextThreadLocalData(
4148 PyObject* py_eager_context) {
4149 if (eager_context_thread_local_data_defaults == nullptr) {
4150 PyErr_SetString(PyExc_AssertionError,
4151 "MakeEagerContextThreadLocalData must be called "
4152 "before GetEagerContextThreadLocalData.");
4153 return nullptr;
4154 }
4155 auto defaults =
4156 eager_context_thread_local_data_defaults->find(py_eager_context);
4157 if (defaults == eager_context_thread_local_data_defaults->end()) {
4158 PyErr_SetString(PyExc_AssertionError,
4159 "MakeEagerContextThreadLocalData must be called "
4160 "before GetEagerContextThreadLocalData.");
4161 return nullptr;
4162 }
4163
4164 if (eager_context_thread_local_data_map == nullptr) {
4165 absl::LeakCheckDisabler disabler;
4166 eager_context_thread_local_data_map = new EagerContextThreadLocalDataMap();
4167 }
4168 auto& thread_local_data =
4169 (*eager_context_thread_local_data_map)[py_eager_context];
4170
4171 if (!thread_local_data) {
4172 thread_local_data.reset(new EagerContextThreadLocalData());
4173
4174 Safe_PyObjectPtr is_eager(
4175 PyObject_CallFunctionObjArgs(defaults->second.is_eager.get(), nullptr));
4176 if (!is_eager) return nullptr;
4177 thread_local_data->is_eager = PyObject_IsTrue(is_eager.get());
4178
4179 #if PY_MAJOR_VERSION >= 3
4180 PyObject* scope_name = PyUnicode_FromString("");
4181 #else
4182 PyObject* scope_name = PyString_FromString("");
4183 #endif
4184 thread_local_data->scope_name.reset(scope_name);
4185
4186 #if PY_MAJOR_VERSION >= 3
4187 PyObject* device_name = PyUnicode_FromString("");
4188 #else
4189 PyObject* device_name = PyString_FromString("");
4190 #endif
4191 thread_local_data->device_name.reset(device_name);
4192
4193 Py_INCREF(defaults->second.device_spec.get());
4194 thread_local_data->device_spec.reset(defaults->second.device_spec.get());
4195
4196 Py_INCREF(Py_None);
4197 thread_local_data->function_call_options.reset(Py_None);
4198
4199 Py_INCREF(Py_None);
4200 thread_local_data->executor.reset(Py_None);
4201
4202 thread_local_data->op_callbacks.reset(PyList_New(0));
4203 }
4204 return thread_local_data.get();
4205 }
4206
4207 void DestroyEagerContextThreadLocalData(PyObject* py_eager_context) {
4208 DCheckPyGilState();
4209 if (eager_context_thread_local_data_defaults) {
4210 eager_context_thread_local_data_defaults->erase(py_eager_context);
4211 }
4212 if (eager_context_thread_local_data_map) {
4213 eager_context_thread_local_data_map->erase(py_eager_context);
4214 }
4215 }
4216
4217 } // namespace tensorflow
4218