1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <atomic>
17 #include <cstring>
18 #include <unordered_map>
19
20 #include "absl/debugging/leak_check.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/types/variant.h"
23 #include "tensorflow/c/c_api.h"
24 #include "tensorflow/c/c_api_internal.h"
25 #include "tensorflow/c/eager/c_api.h"
26 #include "tensorflow/c/eager/c_api_internal.h"
27 #include "tensorflow/c/eager/tape.h"
28 #include "tensorflow/c/eager/tfe_context_internal.h"
29 #include "tensorflow/c/eager/tfe_op_internal.h"
30 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
31 #include "tensorflow/c/tf_status.h"
32 #include "tensorflow/core/framework/types.pb.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/gtl/cleanup.h"
35 #include "tensorflow/core/lib/gtl/compactptrset.h"
36 #include "tensorflow/core/lib/gtl/flatmap.h"
37 #include "tensorflow/core/lib/gtl/flatset.h"
38 #include "tensorflow/core/lib/strings/strcat.h"
39 #include "tensorflow/core/lib/strings/stringprintf.h"
40 #include "tensorflow/core/platform/casts.h"
41 #include "tensorflow/core/platform/mutex.h"
42 #include "tensorflow/core/platform/protobuf.h"
43 #include "tensorflow/core/platform/status.h"
44 #include "tensorflow/core/platform/types.h"
45 #include "tensorflow/core/profiler/lib/traceme.h"
46 #include "tensorflow/core/util/managed_stack_trace.h"
47 #include "tensorflow/python/eager/pywrap_gradient_exclusions.h"
48 #include "tensorflow/python/eager/pywrap_tensor.h"
49 #include "tensorflow/python/eager/pywrap_tfe.h"
50 #include "tensorflow/python/lib/core/py_util.h"
51 #include "tensorflow/python/lib/core/safe_ptr.h"
52 #include "tensorflow/python/util/stack_trace.h"
53 #include "tensorflow/python/util/util.h"
54
55 using tensorflow::Status;
56 using tensorflow::string;
57 using tensorflow::strings::Printf;
58
59 namespace {
60 // NOTE: Items are retrieved from and returned to these unique_ptrs, and they
61 // act as arenas. This is important if the same thread requests 2 items without
62 // releasing one.
63 // The following sequence of events on the same thread will still succeed:
64 // - GetOp <- Returns existing.
65 // - GetOp <- Allocates and returns a new pointer.
66 // - ReleaseOp <- Sets the item in the unique_ptr.
67 // - ReleaseOp <- Sets the item in the unique_ptr, deleting the old one.
68 // This occurs when a PyFunc kernel is run. This behavior makes it safe in that
69 // case, as well as in the case where Python reuses the same underlying
70 // C++ thread for two Python threads.
71 struct OpDeleter {
72   void operator()(TFE_Op* op) const { TFE_DeleteOp(op); }
73 };
74 thread_local std::unordered_map<TFE_Context*,
75 std::unique_ptr<TFE_Op, OpDeleter>>
76 thread_local_eager_operation_map; // NOLINT
77 thread_local std::unique_ptr<TF_Status> thread_local_tf_status = // NOLINT
78 nullptr;
79
80 std::unique_ptr<TFE_Op, OpDeleter> ReleaseThreadLocalOp(TFE_Context* ctx) {
81 auto it = thread_local_eager_operation_map.find(ctx);
82 if (it == thread_local_eager_operation_map.end()) {
83 return nullptr;
84 }
85 return std::move(it->second);
86 }
87
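// Returns the cached thread-local op for `ctx` (allocating a new one if none
// is cached) and resets it to `op_or_function_name` on `raw_device_name`.
// On failure the op is destroyed and nullptr is returned.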
88 TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
89 const char* raw_device_name, TF_Status* status) {
90 auto op = ReleaseThreadLocalOp(ctx);
91 if (!op) {
92 op.reset(tensorflow::wrap(tensorflow::unwrap(ctx)->CreateOperation()));
93 }
94 status->status =
95 tensorflow::unwrap(op.get())->Reset(op_or_function_name, raw_device_name);
96 if (!status->status.ok()) {
97 op.reset();
98 }
99 return op.release();
100 }
101
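// Clears `op` and returns it to the thread-local cache so a later GetOp call
// for the same context can reuse the allocation.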
102 void ReturnOp(TFE_Context* ctx, TFE_Op* op) {
103 if (op) {
104 tensorflow::unwrap(op)->Clear();
105 thread_local_eager_operation_map[ctx].reset(op);
106 }
107 }
108
109 TF_Status* ReleaseThreadLocalStatus() {
110 if (thread_local_tf_status == nullptr) {
111 return nullptr;
112 }
113 return thread_local_tf_status.release();
114 }
115
116 struct InputInfo {
117   InputInfo(int i, bool is_list) : i(i), is_list(is_list) {}
118
119 int i;
120 bool is_list = false;
121 };
122
123 // Takes in output gradients, returns input gradients.
124 typedef std::function<PyObject*(PyObject*,
125 const std::vector<tensorflow::int64>&)>
126 PyBackwardFunction;
127
128 using AttrToInputsMap =
129 tensorflow::gtl::FlatMap<string,
130 tensorflow::gtl::InlinedVector<InputInfo, 4>>;
131
132 tensorflow::gtl::FlatMap<string, AttrToInputsMap*>* GetAllAttrToInputsMaps() {
133 static auto* all_attr_to_input_maps =
134 new tensorflow::gtl::FlatMap<string, AttrToInputsMap*>;
135 return all_attr_to_input_maps;
136 }
137
138 // This function doesn't use a lock, since we depend on the GIL directly.
139 AttrToInputsMap* GetAttrToInputsMapHoldingGIL(const tensorflow::OpDef& op_def) {
140 #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4
141 DCHECK(PyGILState_Check())
142 << "This function needs to hold the GIL when called.";
143 #endif
144 auto* all_attr_to_input_maps = GetAllAttrToInputsMaps();
145 auto* output =
146 tensorflow::gtl::FindPtrOrNull(*all_attr_to_input_maps, op_def.name());
147 if (output != nullptr) {
148 return output;
149 }
150
151 std::unique_ptr<AttrToInputsMap> m(new AttrToInputsMap);
152
153 // Store a list of InputIndex -> List of corresponding inputs.
154 for (int i = 0; i < op_def.input_arg_size(); i++) {
155 if (!op_def.input_arg(i).type_attr().empty()) {
156 auto it = m->find(op_def.input_arg(i).type_attr());
157 if (it == m->end()) {
158 it = m->insert({op_def.input_arg(i).type_attr(), {}}).first;
159 }
160 it->second.emplace_back(i, !op_def.input_arg(i).number_attr().empty());
161 }
162 }
163
164 auto* retval = m.get();
165 (*all_attr_to_input_maps)[op_def.name()] = m.release();
166
167 return retval;
168 }
169
170 // This function doesn't use a lock, since we depend on the GIL directly.
171 tensorflow::gtl::FlatMap<
172 string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>*
173 GetAllAttrToDefaultsMaps() {
174 static auto* all_attr_to_defaults_maps = new tensorflow::gtl::FlatMap<
175 string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>;
176 return all_attr_to_defaults_maps;
177 }
178
179 tensorflow::gtl::FlatMap<string, tensorflow::DataType>*
180 GetAttrToDefaultsMapHoldingGIL(const tensorflow::OpDef& op_def) {
181 #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4
182 DCHECK(PyGILState_Check())
183 << "This function needs to hold the GIL when called.";
184 #endif
185 auto* all_attr_to_defaults_maps = GetAllAttrToDefaultsMaps();
186 auto* output =
187 tensorflow::gtl::FindPtrOrNull(*all_attr_to_defaults_maps, op_def.name());
188 if (output != nullptr) {
189 return output;
190 }
191
192 auto* new_map = new tensorflow::gtl::FlatMap<string, tensorflow::DataType>;
193
194 for (const auto& attr : op_def.attr()) {
195 if (attr.type() == "type" && attr.has_default_value()) {
196 new_map->insert({attr.name(), attr.default_value().type()});
197 }
198 }
199
200 (*all_attr_to_defaults_maps)[op_def.name()] = new_map;
201
202 return new_map;
203 }
204
205 struct FastPathOpExecInfo {
206 TFE_Context* ctx;
207 const char* device_name;
208
209 bool run_callbacks;
210 bool run_post_exec_callbacks;
211 bool run_gradient_callback;
212
213 // The op name of the main op being executed.
214 PyObject* name;
215 // The op type name of the main op being executed.
216 PyObject* op_name;
217 PyObject* callbacks;
218
219 // All the args passed into the FastPathOpExecInfo.
220 PyObject* args;
221
222 // DTypes can come from another input that has the same attr. So build that
223 // map.
224 const AttrToInputsMap* attr_to_inputs_map;
225 const tensorflow::gtl::FlatMap<string, tensorflow::DataType>* default_dtypes;
226 tensorflow::gtl::FlatMap<string, tensorflow::DataType> cached_dtypes;
227 };
228
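// Generates Parse*Value helpers that check the Python type of `py_value`,
// convert it to the corresponding C type, and report an invalid-argument
// status (returning false) on a type mismatch.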
229 #define PARSE_VALUE(fn_name, type, check_fn, parse_fn) \
230 bool fn_name(const string& key, PyObject* py_value, TF_Status* status, \
231 type* value) { \
232 if (check_fn(py_value)) { \
233 *value = static_cast<type>(parse_fn(py_value)); \
234 return true; \
235 } else { \
236 TF_SetStatus(status, TF_INVALID_ARGUMENT, \
237 tensorflow::strings::StrCat( \
238 "Expecting " #type " value for attr ", key, ", got ", \
239 py_value->ob_type->tp_name) \
240 .c_str()); \
241 return false; \
242 } \
243 }
244
245 #if PY_MAJOR_VERSION >= 3
246 PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
247 PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLongLong)
248 #else
249 PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
250 #endif
251 PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)
252 #undef PARSE_VALUE
253
254 #if PY_MAJOR_VERSION < 3
255 bool ParseInt64Value(const string& key, PyObject* py_value, TF_Status* status,
256 int64_t* value) {
257 if (PyInt_Check(py_value)) {
258 *value = static_cast<int64_t>(PyInt_AsLong(py_value));
259 return true;
260 } else if (PyLong_Check(py_value)) {
261 *value = static_cast<int64_t>(PyLong_AsLong(py_value));
262 return true;
263 }
264 TF_SetStatus(
265 status, TF_INVALID_ARGUMENT,
266 tensorflow::strings::StrCat("Expecting int or long value for attr ", key,
267 ", got ", py_value->ob_type->tp_name)
268 .c_str());
269 return false;
270 }
271 #endif
272
273 Py_ssize_t TensorShapeNumDims(PyObject* value) {
274 const auto size = PySequence_Size(value);
275 if (size == -1) {
276 // TensorShape.__len__ raises an error when the shape is unknown; that error
277 // needs to be cleared here.
278 // TODO(nareshmodi): ensure that this is actually a TensorShape.
279 PyErr_Clear();
280 }
281 return size;
282 }
283
284 bool IsInteger(PyObject* py_value) {
285 #if PY_MAJOR_VERSION >= 3
286 return PyLong_Check(py_value);
287 #else
288 return PyInt_Check(py_value) || PyLong_Check(py_value);
289 #endif
290 }
291
292 // This function considers a Dimension._value of None to be valid, and sets the
293 // value to be -1 in that case.
294 bool ParseDimensionValue(const string& key, PyObject* py_value,
295 TF_Status* status, int64_t* value) {
296 if (IsInteger(py_value)) {
297 return ParseInt64Value(key, py_value, status, value);
298 }
299
300 tensorflow::Safe_PyObjectPtr dimension_value(
301 PyObject_GetAttrString(py_value, "_value"));
302 if (dimension_value == nullptr) {
303 PyErr_Clear();
304 TF_SetStatus(
305 status, TF_INVALID_ARGUMENT,
306 tensorflow::strings::StrCat("Expecting a Dimension for attr ", key,
307 ", got ", py_value->ob_type->tp_name)
308 .c_str());
309 return false;
310 }
311
312 if (dimension_value.get() == Py_None) {
313 *value = -1;
314 return true;
315 }
316
317 return ParseInt64Value(key, dimension_value.get(), status, value);
318 }
319
320 bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status,
321 tensorflow::StringPiece* value) {
322 if (PyBytes_Check(py_value)) {
323 Py_ssize_t size = 0;
324 char* buf = nullptr;
325 if (PyBytes_AsStringAndSize(py_value, &buf, &size) < 0) return false;
326 *value = tensorflow::StringPiece(buf, size);
327 return true;
328 }
329 #if PY_MAJOR_VERSION >= 3
330 if (PyUnicode_Check(py_value)) {
331 Py_ssize_t size = 0;
332 const char* buf = PyUnicode_AsUTF8AndSize(py_value, &size);
333 if (buf == nullptr) return false;
334 *value = tensorflow::StringPiece(buf, size);
335 return true;
336 }
337 #endif
338 TF_SetStatus(
339 status, TF_INVALID_ARGUMENT,
340 tensorflow::strings::StrCat("Expecting a string value for attr ", key,
341 ", got ", py_value->ob_type->tp_name)
342 .c_str());
343 return false;
344 }
345
346 bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status,
347 unsigned char* value) {
348 *value = PyObject_IsTrue(py_value);
349 return true;
350 }
351
352 // The passed in py_value is expected to be an object of the python type
353 // dtypes.DType or an int.
354 bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status,
355 int* value) {
356 if (IsInteger(py_value)) {
357 return ParseIntValue(key, py_value, status, value);
358 }
359
360 tensorflow::Safe_PyObjectPtr py_type_enum(
361 PyObject_GetAttrString(py_value, "_type_enum"));
362 if (py_type_enum == nullptr) {
363 PyErr_Clear();
364 TF_SetStatus(
365 status, TF_INVALID_ARGUMENT,
366 tensorflow::strings::StrCat("Expecting a DType.dtype for attr ", key,
367 ", got ", py_value->ob_type->tp_name)
368 .c_str());
369 return false;
370 }
371
372 return ParseIntValue(key, py_type_enum.get(), status, value);
373 }
374
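// Parses the Python sequence `py_list` and sets the list-valued attr `key` on
// `op`, dispatching on the expected TF_AttrType. Records the list length in
// `attr_list_sizes` when provided; returns false and sets `status` on failure.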
375 bool SetOpAttrList(
376 TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_list,
377 TF_AttrType type,
378 tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
379 TF_Status* status) {
380 if (!PySequence_Check(py_list)) {
381 TF_SetStatus(
382 status, TF_INVALID_ARGUMENT,
383 tensorflow::strings::StrCat("Expecting sequence value for attr ", key,
384 ", got ", py_list->ob_type->tp_name)
385 .c_str());
386 return false;
387 }
388 const int num_values = PySequence_Size(py_list);
389 if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values;
390
391 #define PARSE_LIST(c_type, parse_fn) \
392 std::unique_ptr<c_type[]> values(new c_type[num_values]); \
393 for (int i = 0; i < num_values; ++i) { \
394 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)); \
395 if (!parse_fn(key, py_value.get(), status, &values[i])) return false; \
396 }
397
398 if (type == TF_ATTR_STRING) {
399 std::unique_ptr<const void*[]> values(new const void*[num_values]);
400 std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
401 for (int i = 0; i < num_values; ++i) {
402 tensorflow::StringPiece value;
403 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
404 if (!ParseStringValue(key, py_value.get(), status, &value)) return false;
405 values[i] = value.data();
406 lengths[i] = value.size();
407 }
408 TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
409 } else if (type == TF_ATTR_INT) {
410 PARSE_LIST(int64_t, ParseInt64Value);
411 TFE_OpSetAttrIntList(op, key, values.get(), num_values);
412 } else if (type == TF_ATTR_FLOAT) {
413 PARSE_LIST(float, ParseFloatValue);
414 TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
415 } else if (type == TF_ATTR_BOOL) {
416 PARSE_LIST(unsigned char, ParseBoolValue);
417 TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
418 } else if (type == TF_ATTR_TYPE) {
419 PARSE_LIST(int, ParseTypeValue);
420 TFE_OpSetAttrTypeList(op, key,
421 reinterpret_cast<const TF_DataType*>(values.get()),
422 num_values);
423 } else if (type == TF_ATTR_SHAPE) {
424 // Make one pass through the input counting the total number of
425 // dims across all the input lists.
426 int total_dims = 0;
427 for (int i = 0; i < num_values; ++i) {
428 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
429 if (py_value.get() != Py_None) {
430 if (!PySequence_Check(py_value.get())) {
431 TF_SetStatus(
432 status, TF_INVALID_ARGUMENT,
433 tensorflow::strings::StrCat(
434 "Expecting None or sequence value for element", i,
435 " of attr ", key, ", got ", py_value->ob_type->tp_name)
436 .c_str());
437 return false;
438 }
439 const auto size = TensorShapeNumDims(py_value.get());
440 if (size >= 0) {
441 total_dims += size;
442 }
443 }
444 }
445 // Allocate a buffer that can fit all of the dims together.
446 std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
447 // Copy the input dims into the buffer and set dims to point to
448 // the start of each list's dims.
449 std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
450 std::unique_ptr<int[]> num_dims(new int[num_values]);
451 int64_t* offset = buffer.get();
452 for (int i = 0; i < num_values; ++i) {
453 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
454 if (py_value.get() == Py_None) {
455 dims[i] = nullptr;
456 num_dims[i] = -1;
457 } else {
458 const auto size = TensorShapeNumDims(py_value.get());
459 if (size == -1) {
460 dims[i] = nullptr;
461 num_dims[i] = -1;
462 continue;
463 }
464 dims[i] = offset;
465 num_dims[i] = size;
466 for (int j = 0; j < size; ++j) {
467 tensorflow::Safe_PyObjectPtr inner_py_value(
468 PySequence_ITEM(py_value.get(), j));
469 if (inner_py_value.get() == Py_None) {
470 *offset = -1;
471 } else if (!ParseDimensionValue(key, inner_py_value.get(), status,
472 offset)) {
473 return false;
474 }
475 ++offset;
476 }
477 }
478 }
479 TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
480 status);
481 if (!status->status.ok()) return false;
482 } else if (type == TF_ATTR_FUNC) {
483 std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
484 for (int i = 0; i < num_values; ++i) {
485 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
486 // Allow:
487 // (1) String function name, OR
488 // (2) A Python object with a .name attribute
489 // (A crude test for being a
490 // tensorflow.python.framework.function._DefinedFunction)
491 // (which is what the various "defun" or "Defun" decorators do).
492 // And in the future also allow an object that can encapsulate
493 // the function name and its attribute values.
494 tensorflow::StringPiece func_name;
495 if (!ParseStringValue(key, py_value.get(), status, &func_name)) {
496 PyObject* name_attr = PyObject_GetAttrString(py_value.get(), "name");
497 if (name_attr == nullptr ||
498 !ParseStringValue(key, name_attr, status, &func_name)) {
499 TF_SetStatus(
500 status, TF_INVALID_ARGUMENT,
501 tensorflow::strings::StrCat(
502 "unable to set function value attribute from a ",
503 py_value.get()->ob_type->tp_name,
504 " object. If you think this is an error, please file an "
505 "issue at "
506 "https://github.com/tensorflow/tensorflow/issues/new")
507 .c_str());
508 return false;
509 }
510 }
511 funcs[i] = TFE_NewOp(ctx, func_name.data(), status);
512 if (!status->status.ok()) return false;
513 }
514 TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
515 if (!status->status.ok()) return false;
516 } else {
517 TF_SetStatus(status, TF_UNIMPLEMENTED,
518 tensorflow::strings::StrCat("Attr ", key,
519 " has unhandled list type ", type)
520 .c_str());
521 return false;
522 }
523 #undef PARSE_LIST
524 return true;
525 }
526
527 TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
528 TF_Status* status) {
529 TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
530 for (const auto& attr : func.attr()) {
531 if (!status->status.ok()) return nullptr;
532 SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
533 if (!status->status.ok()) return nullptr;
534 }
535 return func_op;
536 }
537
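// Sets the list-valued attr `key` on `op` from the attr's default value in the
// OpDef, mirroring SetOpAttrList but reading the values from
// `attr.default_value()` instead of a Python object.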
538 void SetOpAttrListDefault(
539 TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
540 const char* key, TF_AttrType type,
541 tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
542 TF_Status* status) {
543 if (type == TF_ATTR_STRING) {
544 int num_values = attr.default_value().list().s_size();
545 std::unique_ptr<const void*[]> values(new const void*[num_values]);
546 std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
547 (*attr_list_sizes)[key] = num_values;
548 for (int i = 0; i < num_values; i++) {
549 const string& v = attr.default_value().list().s(i);
550 values[i] = v.data();
551 lengths[i] = v.size();
552 }
553 TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
554 } else if (type == TF_ATTR_INT) {
555 int num_values = attr.default_value().list().i_size();
556 std::unique_ptr<int64_t[]> values(new int64_t[num_values]);
557 (*attr_list_sizes)[key] = num_values;
558 for (int i = 0; i < num_values; i++) {
559 values[i] = attr.default_value().list().i(i);
560 }
561 TFE_OpSetAttrIntList(op, key, values.get(), num_values);
562 } else if (type == TF_ATTR_FLOAT) {
563 int num_values = attr.default_value().list().f_size();
564 std::unique_ptr<float[]> values(new float[num_values]);
565 (*attr_list_sizes)[key] = num_values;
566 for (int i = 0; i < num_values; i++) {
567 values[i] = attr.default_value().list().f(i);
568 }
569 TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
570 } else if (type == TF_ATTR_BOOL) {
571 int num_values = attr.default_value().list().b_size();
572 std::unique_ptr<unsigned char[]> values(new unsigned char[num_values]);
573 (*attr_list_sizes)[key] = num_values;
574 for (int i = 0; i < num_values; i++) {
575 values[i] = attr.default_value().list().b(i);
576 }
577 TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
578 } else if (type == TF_ATTR_TYPE) {
579 int num_values = attr.default_value().list().type_size();
580 std::unique_ptr<int[]> values(new int[num_values]);
581 (*attr_list_sizes)[key] = num_values;
582 for (int i = 0; i < num_values; i++) {
583 values[i] = attr.default_value().list().type(i);
584 }
585 TFE_OpSetAttrTypeList(op, key,
586 reinterpret_cast<const TF_DataType*>(values.get()),
587 attr.default_value().list().type_size());
588 } else if (type == TF_ATTR_SHAPE) {
589 int num_values = attr.default_value().list().shape_size();
590 (*attr_list_sizes)[key] = num_values;
591 int total_dims = 0;
592 for (int i = 0; i < num_values; ++i) {
593 if (!attr.default_value().list().shape(i).unknown_rank()) {
594 total_dims += attr.default_value().list().shape(i).dim_size();
595 }
596 }
597 // Allocate a buffer that can fit all of the dims together.
598 std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
599 // Copy the input dims into the buffer and set dims to point to
600 // the start of each list's dims.
601 std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
602 std::unique_ptr<int[]> num_dims(new int[num_values]);
603 int64_t* offset = buffer.get();
604 for (int i = 0; i < num_values; ++i) {
605 const auto& shape = attr.default_value().list().shape(i);
606 if (shape.unknown_rank()) {
607 dims[i] = nullptr;
608 num_dims[i] = -1;
609 } else {
610 for (int j = 0; j < shape.dim_size(); j++) {
611 *offset = shape.dim(j).size();
612 ++offset;
613 }
614 }
615 }
616 TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
617 status);
618 } else if (type == TF_ATTR_FUNC) {
619 int num_values = attr.default_value().list().func_size();
620 (*attr_list_sizes)[key] = num_values;
621 std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
622 for (int i = 0; i < num_values; i++) {
623 funcs[i] = GetFunc(ctx, attr.default_value().list().func(i), status);
624 }
625 TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
626 } else {
627 TF_SetStatus(status, TF_UNIMPLEMENTED,
628 "Lists of tensors are not yet implemented for default valued "
629 "attributes for an operation.");
630 }
631 }
632
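// Parses a single Python value and sets the scalar attr `key` on `op`. Int
// attrs are also recorded in `attr_list_sizes` (when provided), since they may
// determine the length of an output list.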
633 bool SetOpAttrScalar(
634 TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_value,
635 TF_AttrType type,
636 tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
637 TF_Status* status) {
638 if (type == TF_ATTR_STRING) {
639 tensorflow::StringPiece value;
640 if (!ParseStringValue(key, py_value, status, &value)) return false;
641 TFE_OpSetAttrString(op, key, value.data(), value.size());
642 } else if (type == TF_ATTR_INT) {
643 int64_t value;
644 if (!ParseInt64Value(key, py_value, status, &value)) return false;
645 TFE_OpSetAttrInt(op, key, value);
646 // attr_list_sizes is set for all int attributes (since at this point we
647 // don't know whether the attribute will later be used to calculate the
648 // size of an output list).
649 if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = value;
650 } else if (type == TF_ATTR_FLOAT) {
651 float value;
652 if (!ParseFloatValue(key, py_value, status, &value)) return false;
653 TFE_OpSetAttrFloat(op, key, value);
654 } else if (type == TF_ATTR_BOOL) {
655 unsigned char value;
656 if (!ParseBoolValue(key, py_value, status, &value)) return false;
657 TFE_OpSetAttrBool(op, key, value);
658 } else if (type == TF_ATTR_TYPE) {
659 int value;
660 if (!ParseTypeValue(key, py_value, status, &value)) return false;
661 TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value));
662 } else if (type == TF_ATTR_SHAPE) {
663 if (py_value == Py_None) {
664 TFE_OpSetAttrShape(op, key, nullptr, -1, status);
665 } else {
666 if (!PySequence_Check(py_value)) {
667 TF_SetStatus(status, TF_INVALID_ARGUMENT,
668 tensorflow::strings::StrCat(
669 "Expecting None or sequence value for attr", key,
670 ", got ", py_value->ob_type->tp_name)
671 .c_str());
672 return false;
673 }
674 const auto num_dims = TensorShapeNumDims(py_value);
675 if (num_dims == -1) {
676 TFE_OpSetAttrShape(op, key, nullptr, -1, status);
677 return true;
678 }
679 std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
680 for (int i = 0; i < num_dims; ++i) {
681 tensorflow::Safe_PyObjectPtr inner_py_value(
682 PySequence_ITEM(py_value, i));
683 if (inner_py_value.get() == Py_None) {
684 dims[i] = -1;
685 } else if (!ParseDimensionValue(key, inner_py_value.get(), status,
686 &dims[i])) {
687 return false;
688 }
689 }
690 TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status);
691 }
692 if (!status->status.ok()) return false;
693 } else if (type == TF_ATTR_FUNC) {
694 // Allow:
695 // (1) String function name, OR
696 // (2) A Python object with a .name attribute
697 // (A crude test for being a
698 // tensorflow.python.framework.function._DefinedFunction)
699 // (which is what the various "defun" or "Defun" decorators do).
700 // And in the future also allow an object that can encapsulate
701 // the function name and its attribute values.
702 tensorflow::StringPiece func_name;
703 if (!ParseStringValue(key, py_value, status, &func_name)) {
704 PyObject* name_attr = PyObject_GetAttrString(py_value, "name");
705 if (name_attr == nullptr ||
706 !ParseStringValue(key, name_attr, status, &func_name)) {
707 TF_SetStatus(
708 status, TF_INVALID_ARGUMENT,
709 tensorflow::strings::StrCat(
710 "unable to set function value attribute from a ",
711 py_value->ob_type->tp_name,
712 " object. If you think this is an error, please file an issue "
713 "at https://github.com/tensorflow/tensorflow/issues/new")
714 .c_str());
715 return false;
716 }
717 }
718 TF_SetStatus(status, TF_OK, "");
719 TFE_OpSetAttrFunctionName(op, key, func_name.data(), func_name.size());
720 } else {
721 TF_SetStatus(
722 status, TF_UNIMPLEMENTED,
723 tensorflow::strings::StrCat("Attr ", key, " has unhandled type ", type)
724 .c_str());
725 return false;
726 }
727 return true;
728 }
729
730 void SetOpAttrScalarDefault(
731 TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value,
732 const char* attr_name,
733 tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
734 TF_Status* status) {
735 SetOpAttrValueScalar(ctx, op, default_value, attr_name, status);
736 if (default_value.value_case() == tensorflow::AttrValue::kI) {
737 (*attr_list_sizes)[attr_name] = default_value.i();
738 }
739 }
740
741 // start_index is the index at which the Tuple/List attrs will start getting
742 // processed.
743 void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index,
744 TF_Status* out_status) {
745 if (attrs == Py_None) return;
746 Py_ssize_t len = PyTuple_GET_SIZE(attrs) - start_index;
747 if ((len & 1) != 0) {
748 TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
749 "Expecting attrs tuple to have even length.");
750 return;
751 }
752 // Parse attrs
753 for (Py_ssize_t i = 0; i < len; i += 2) {
754 PyObject* py_key = PyTuple_GET_ITEM(attrs, start_index + i);
755 PyObject* py_value = PyTuple_GET_ITEM(attrs, start_index + i + 1);
756 #if PY_MAJOR_VERSION >= 3
757 const char* key = PyBytes_Check(py_key) ? PyBytes_AsString(py_key)
758 : PyUnicode_AsUTF8(py_key);
759 #else
760 const char* key = PyBytes_AsString(py_key);
761 #endif
762 unsigned char is_list = 0;
763 const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status);
764 if (!out_status->status.ok()) return;
765 if (is_list != 0) {
766 if (!SetOpAttrList(ctx, op, key, py_value, type, nullptr, out_status))
767 return;
768 } else {
769 if (!SetOpAttrScalar(ctx, op, key, py_value, type, nullptr, out_status))
770 return;
771 }
772 }
773 }
774
775 // This function will set the op attrs required. If an attr has the value of
776 // None, then it will read the AttrDef to get the default value and set that
777 // instead. Any failure in this function will simply fall back to the slow
778 // path.
779 void SetOpAttrWithDefaults(
780 TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
781 const char* attr_name, PyObject* attr_value,
782 tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
783 TF_Status* status) {
784 unsigned char is_list = 0;
785 const TF_AttrType type = TFE_OpGetAttrType(op, attr_name, &is_list, status);
786 if (!status->status.ok()) return;
787 if (attr_value == Py_None) {
788 if (is_list != 0) {
789 SetOpAttrListDefault(ctx, op, attr, attr_name, type, attr_list_sizes,
790 status);
791 } else {
792 SetOpAttrScalarDefault(ctx, op, attr.default_value(), attr_name,
793 attr_list_sizes, status);
794 }
795 } else {
796 if (is_list != 0) {
797 SetOpAttrList(ctx, op, attr_name, attr_value, type, attr_list_sizes,
798 status);
799 } else {
800 SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes,
801 status);
802 }
803 }
804 }
805
806 PyObject* GetPythonObjectFromInt(int num) {
807 #if PY_MAJOR_VERSION >= 3
808 return PyLong_FromLong(num);
809 #else
810 return PyInt_FromLong(num);
811 #endif
812 }
813
814 // Python subclass of Exception that is raised when a Status is not OK.
815 tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
816 PyObject* exception_class TF_GUARDED_BY(exception_class_mutex) = nullptr;
817
818 // Python subclass of Exception that is created to signal fallback.
819 PyObject* fallback_exception_class = nullptr;
820
821 // Python function that returns input gradients given output gradients.
822 PyObject* gradient_function = nullptr;
823
824 // Python function that returns output gradients given input gradients.
825 PyObject* forward_gradient_function = nullptr;
826
827 static std::atomic<int64_t> _uid;
828
829 } // namespace
830
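// Returns the cached thread-local TF_Status (reset to TF_OK) if one is
// available, otherwise allocates a new status. Paired with ReturnStatus below.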
831 TF_Status* GetStatus() {
832 TF_Status* maybe_status = ReleaseThreadLocalStatus();
833 if (maybe_status) {
834 TF_SetStatus(maybe_status, TF_OK, "");
835 return maybe_status;
836 } else {
837 return TF_NewStatus();
838 }
839 }
840
841 void ReturnStatus(TF_Status* status) {
842 TF_SetStatus(status, TF_OK, "");
843 thread_local_tf_status.reset(status);
844 }
845
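// Convenience wrapper around TFE_Py_ExecuteCancelable with no cancellation
// manager.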
846 void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
847 const char* op_name, TFE_InputTensorHandles* inputs,
848 PyObject* attrs, TFE_OutputTensorHandles* outputs,
849 TF_Status* out_status) {
850 TFE_Py_ExecuteCancelable(ctx, device_name, op_name, inputs, attrs,
851 /*cancellation_manager=*/nullptr, outputs,
852 out_status);
853 }
854
855 void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name,
856 const char* op_name,
857 TFE_InputTensorHandles* inputs, PyObject* attrs,
858 TFE_CancellationManager* cancellation_manager,
859 TFE_OutputTensorHandles* outputs,
860 TF_Status* out_status) {
861 tensorflow::profiler::TraceMe activity(
862 "TFE_Py_ExecuteCancelable", tensorflow::profiler::TraceMeLevel::kInfo);
863
864 TFE_Op* op = GetOp(ctx, op_name, device_name, out_status);
865
866 auto cleaner = tensorflow::gtl::MakeCleanup([ctx, op] { ReturnOp(ctx, op); });
867 if (!out_status->status.ok()) return;
868
869 tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace(
870 tensorflow::StackTrace::kStackTraceInitialSize));
871
872 for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) {
873 TFE_OpAddInput(op, inputs->at(i), out_status);
874 }
875 if (cancellation_manager && out_status->status.ok()) {
876 TFE_OpSetCancellationManager(op, cancellation_manager, out_status);
877 }
878 if (out_status->status.ok()) {
879 SetOpAttrs(ctx, op, attrs, 0, out_status);
880 }
881 Py_BEGIN_ALLOW_THREADS;
882
883 int num_outputs = outputs->size();
884
885 if (out_status->status.ok()) {
886 TFE_Execute(op, outputs->data(), &num_outputs, out_status);
887 }
888
889 if (out_status->status.ok()) {
890 outputs->resize(num_outputs);
891 } else {
892 TF_SetStatus(out_status, TF_GetCode(out_status),
893 tensorflow::strings::StrCat(TF_Message(out_status),
894 " [Op:", op_name, "]")
895 .c_str());
896 }
897
898 Py_END_ALLOW_THREADS;
899 }
900
901 PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
902 tensorflow::mutex_lock l(exception_class_mutex);
903 if (exception_class != nullptr) {
904 Py_DECREF(exception_class);
905 }
906 if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
907 exception_class = nullptr;
908 PyErr_SetString(PyExc_TypeError,
909 "TFE_Py_RegisterExceptionClass: "
910 "Registered class should be subclass of Exception.");
911 return nullptr;
912 }
913
914 Py_INCREF(e);
915 exception_class = e;
916 Py_RETURN_NONE;
917 }
918
919 PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) {
920 if (fallback_exception_class != nullptr) {
921 Py_DECREF(fallback_exception_class);
922 }
923 if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
924 fallback_exception_class = nullptr;
925 PyErr_SetString(PyExc_TypeError,
926 "TFE_Py_RegisterFallbackExceptionClass: "
927 "Registered class should be subclass of Exception.");
928 return nullptr;
929 } else {
930 Py_INCREF(e);
931 fallback_exception_class = e;
932 Py_RETURN_NONE;
933 }
934 }
935
936 PyObject* TFE_Py_RegisterGradientFunction(PyObject* e) {
937 if (gradient_function != nullptr) {
938 Py_DECREF(gradient_function);
939 }
940 if (!PyCallable_Check(e)) {
941 gradient_function = nullptr;
942 PyErr_SetString(PyExc_TypeError,
943 "TFE_Py_RegisterGradientFunction: "
944 "Registered object should be function.");
945 return nullptr;
946 } else {
947 Py_INCREF(e);
948 gradient_function = e;
949 Py_RETURN_NONE;
950 }
951 }
952
953 PyObject* TFE_Py_RegisterJVPFunction(PyObject* e) {
954 if (forward_gradient_function != nullptr) {
955 Py_DECREF(forward_gradient_function);
956 }
957 if (!PyCallable_Check(e)) {
958 forward_gradient_function = nullptr;
959 PyErr_SetString(PyExc_TypeError,
960 "TFE_Py_RegisterJVPFunction: "
961 "Registered object should be function.");
962 return nullptr;
963 } else {
964 Py_INCREF(e);
965 forward_gradient_function = e;
966 Py_RETURN_NONE;
967 }
968 }
969
970 void RaiseFallbackException(const char* message) {
971 if (fallback_exception_class != nullptr) {
972 PyErr_SetString(fallback_exception_class, message);
973 return;
974 }
975
976 PyErr_SetString(
977 PyExc_RuntimeError,
978 tensorflow::strings::StrCat(
979 "Fallback exception type not set, attempting to fallback due to ",
980 message)
981 .data());
982 }
983
984 // Format and return `status`' error message with the attached stack trace if
985 // available. `status` must have an error.
986 std::string FormatErrorStatusStackTrace(const tensorflow::Status& status) {
987 tensorflow::DCheckPyGilState();
988 DCHECK(!status.ok());
989
990 if (status.stack_trace().empty()) return status.error_message();
991
992 const std::vector<tensorflow::StackFrame>& stack_trace = status.stack_trace();
993
994 PyObject* linecache = PyImport_ImportModule("linecache");
995 PyObject* getline =
996 PyObject_GetAttr(linecache, PyUnicode_FromString("getline"));
997 DCHECK(getline);
998
999 std::ostringstream result;
1000 result << "Exception originated from\n\n";
1001
1002 for (const tensorflow::StackFrame& stack_frame : stack_trace) {
1003 PyObject* line_str_obj = PyObject_CallFunction(
1004 getline, const_cast<char*>("si"), stack_frame.file_name.c_str(),
1005 stack_frame.line_number);
1006 tensorflow::StringPiece line_str = TFE_GetPythonString(line_str_obj);
1007 tensorflow::str_util::RemoveWhitespaceContext(&line_str);
1008 result << " File \"" << stack_frame.file_name << "\", line "
1009 << stack_frame.line_number << ", in " << stack_frame.function_name
1010 << '\n';
1011
1012 if (!line_str.empty()) result << " " << line_str << '\n';
1013 Py_XDECREF(line_str_obj);
1014 }
1015
1016 Py_DecRef(getline);
1017 Py_DecRef(linecache);
1018
1019 result << '\n' << status.error_message();
1020 return result.str();
1021 }
1022
1023 int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
1024 if (status->status.ok()) return 0;
1025 const char* msg = TF_Message(status);
1026 if (exception == nullptr) {
1027 tensorflow::mutex_lock l(exception_class_mutex);
1028 if (exception_class != nullptr) {
1029 tensorflow::Safe_PyObjectPtr val(Py_BuildValue(
1030 "si", FormatErrorStatusStackTrace(status->status).c_str(),
1031 TF_GetCode(status)));
1032 if (PyErr_Occurred()) {
1033 // NOTE: This hides the actual error (i.e. the reason `status` was not
1034 // TF_OK), but there is nothing we can do at this point since we can't
1035 // generate a reasonable error from the status.
1036 // Consider adding a message explaining this.
1037 return -1;
1038 }
1039 PyErr_SetObject(exception_class, val.get());
1040 return -1;
1041 } else {
1042 exception = PyExc_RuntimeError;
1043 }
1044 }
1045 // Maybe update an already-set exception.
1046 PyErr_SetString(exception, msg);
1047 return -1;
1048 }
1049
1050 int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
1051 PyObject* exception) {
1052 if (status.ok()) return 0;
1053 const char* msg = status.error_message().c_str();
1054 if (exception == nullptr) {
1055 tensorflow::mutex_lock l(exception_class_mutex);
1056 if (exception_class != nullptr) {
1057 tensorflow::Safe_PyObjectPtr val(Py_BuildValue(
1058 "si", FormatErrorStatusStackTrace(status).c_str(), status.code()));
1059 PyErr_SetObject(exception_class, val.get());
1060 return -1;
1061 } else {
1062 exception = PyExc_RuntimeError;
1063 }
1064 }
1065 // Maybe update an already-set exception.
1066 PyErr_SetString(exception, msg);
1067 return -1;
1068 }
1069
1070 const char* TFE_GetPythonString(PyObject* o) {
1071 #if PY_MAJOR_VERSION >= 3
1072 if (PyBytes_Check(o)) {
1073 return PyBytes_AsString(o);
1074 } else {
1075 return PyUnicode_AsUTF8(o);
1076 }
1077 #else
1078 return PyBytes_AsString(o);
1079 #endif
1080 }
1081
1082 int64_t get_uid() { return _uid++; }
1083
1084 PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }
1085
1086 void TFE_DeleteContextCapsule(PyObject* context) {
1087 TFE_Context* ctx =
1088 reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr));
1089 auto op = ReleaseThreadLocalOp(ctx);
1090 op.reset();
1091 TFE_DeleteContext(ctx);
1092 }
1093
1094 static tensorflow::int64 MakeInt(PyObject* integer) {
1095 #if PY_MAJOR_VERSION >= 3
1096 return PyLong_AsLong(integer);
1097 #else
1098 return PyInt_AsLong(integer);
1099 #endif
1100 }
1101
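// Returns the unique id of `tensor`, using the fast path for EagerTensors and
// falling back to the Python `_id` attribute otherwise; returns -1 if that
// attribute is missing.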
1102 static tensorflow::int64 FastTensorId(PyObject* tensor) {
1103 if (EagerTensor_CheckExact(tensor)) {
1104 return PyEagerTensor_ID(tensor);
1105 }
1106 PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
1107 if (id_field == nullptr) {
1108 return -1;
1109 }
1110 int64_t id = MakeInt(id_field);
1111 Py_DECREF(id_field);
1112 return id;
1113 }
1114
1115 namespace tensorflow {
1116 DataType PyTensor_DataType(PyObject* tensor) {
1117 if (EagerTensor_CheckExact(tensor)) {
1118 return PyEagerTensor_Dtype(tensor);
1119 } else {
1120 #if PY_MAJOR_VERSION < 3
1121 // Python 2.x:
1122 static PyObject* dtype_attr = PyString_InternFromString("dtype");
1123 static PyObject* type_enum_attr = PyString_InternFromString("_type_enum");
1124 #else
1125 // Python 3.x:
1126 static PyObject* dtype_attr = PyUnicode_InternFromString("dtype");
1127 static PyObject* type_enum_attr = PyUnicode_InternFromString("_type_enum");
1128 #endif
1129 Safe_PyObjectPtr dtype_field(PyObject_GetAttr(tensor, dtype_attr));
1130 if (!dtype_field) {
1131 return DT_INVALID;
1132 }
1133
1134 Safe_PyObjectPtr enum_field(
1135 PyObject_GetAttr(dtype_field.get(), type_enum_attr));
1136 if (!enum_field) {
1137 return DT_INVALID;
1138 }
1139
1140 return static_cast<DataType>(MakeInt(enum_field.get()));
1141 }
1142 }
1143 } // namespace tensorflow
1144
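// Lightweight record of a tensor watched by the tape: its id, dtype, and
// either a concrete TensorShape or the tensor object itself (the latter is
// used for partially-defined shapes and variant-dtype tensors).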
1145 class PyTapeTensor {
1146 public:
1147   PyTapeTensor(int64_t id, tensorflow::DataType dtype,
1148 const tensorflow::TensorShape& shape)
1149 : id_(id), dtype_(dtype), shape_(shape) {}
1150   PyTapeTensor(int64_t id, tensorflow::DataType dtype, PyObject* shape)
1151 : id_(id), dtype_(dtype), shape_(shape) {
1152 Py_INCREF(absl::get<1>(shape_));
1153 }
1154   PyTapeTensor(const PyTapeTensor& other) {
1155 id_ = other.id_;
1156 dtype_ = other.dtype_;
1157 shape_ = other.shape_;
1158 if (shape_.index() == 1) {
1159 Py_INCREF(absl::get<1>(shape_));
1160 }
1161 }
1162
1163   ~PyTapeTensor() {
1164 if (shape_.index() == 1) {
1165 Py_DECREF(absl::get<1>(shape_));
1166 }
1167 }
1168 PyObject* GetShape() const;
1169   PyObject* GetPyDType() const { return PyLong_FromLong(dtype_); }
1170   tensorflow::int64 GetID() const { return id_; }
1171   tensorflow::DataType GetDType() const { return dtype_; }
1172
1173 PyObject* OnesLike() const;
1174 PyObject* ZerosLike() const;
1175
1176 private:
1177 tensorflow::int64 id_;
1178 tensorflow::DataType dtype_;
1179
1180 // Note that if shape_.index() == 1, meaning shape_ contains a PyObject, that
1181 // PyObject is the tensor itself. This is used to support tf.shape(tensor) for
1182 // partially-defined shapes and tf.zeros_like(tensor) for variant-dtype
1183 // tensors.
1184 absl::variant<tensorflow::TensorShape, PyObject*> shape_;
1185 };
1186
1187 static PyTapeTensor TapeTensorFromTensor(PyObject* tensor);
1188
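// Bridges the C++ tape's VSpace interface to the Python callbacks registered
// via TFE_Py_RegisterVSpace (num_elements_fn, aggregate_fn, zeros_fn,
// ones_like_fn, graph_shape_fn, etc.).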
1189 class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
1190 PyTapeTensor> {
1191 public:
1192   explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
1193 Py_INCREF(py_vspace_);
1194 }
1195
1196   tensorflow::Status Initialize() {
1197 num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
1198 if (num_elements_ == nullptr) {
1199 return tensorflow::errors::InvalidArgument("invalid vspace");
1200 }
1201 aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
1202 if (aggregate_fn_ == nullptr) {
1203 return tensorflow::errors::InvalidArgument("invalid vspace");
1204 }
1205 zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn");
1206 if (zeros_fn_ == nullptr) {
1207 return tensorflow::errors::InvalidArgument("invalid vspace");
1208 }
1209 zeros_like_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_like_fn");
1210 if (zeros_like_fn_ == nullptr) {
1211 return tensorflow::errors::InvalidArgument("invalid vspace");
1212 }
1213 ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn");
1214 if (ones_fn_ == nullptr) {
1215 return tensorflow::errors::InvalidArgument("invalid vspace");
1216 }
1217 ones_like_fn_ = PyObject_GetAttrString(py_vspace_, "ones_like_fn");
1218 if (ones_like_fn_ == nullptr) {
1219 return tensorflow::errors::InvalidArgument("invalid vspace");
1220 }
1221 graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn");
1222 if (graph_shape_fn_ == nullptr) {
1223 return tensorflow::errors::InvalidArgument("invalid vspace");
1224 }
1225 return tensorflow::Status::OK();
1226 }
1227
1228   ~PyVSpace() override {
1229 Py_XDECREF(num_elements_);
1230 Py_XDECREF(aggregate_fn_);
1231 Py_XDECREF(zeros_fn_);
1232 Py_XDECREF(zeros_like_fn_);
1233 Py_XDECREF(ones_fn_);
1234 Py_XDECREF(ones_like_fn_);
1235 Py_XDECREF(graph_shape_fn_);
1236
1237 Py_DECREF(py_vspace_);
1238 }
1239
1240   tensorflow::int64 NumElements(PyObject* tensor) const final {
1241 if (EagerTensor_CheckExact(tensor)) {
1242 return PyEagerTensor_NumElements(tensor);
1243 }
1244 PyObject* arglist =
1245 Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
1246 PyObject* result = PyEval_CallObject(num_elements_, arglist);
1247 Py_DECREF(arglist);
1248 if (result == nullptr) {
1249 // The caller detects whether a python exception has been raised.
1250 return -1;
1251 }
1252 int64_t r = MakeInt(result);
1253 Py_DECREF(result);
1254 return r;
1255 }
1256
1257   PyObject* AggregateGradients(
1258 tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
1259 PyObject* list = PyList_New(gradient_tensors.size());
1260 for (int i = 0; i < gradient_tensors.size(); ++i) {
1261 // Note: stealing a reference to the gradient tensors.
1262 CHECK(gradient_tensors[i] != nullptr);
1263 CHECK(gradient_tensors[i] != Py_None);
1264 PyList_SET_ITEM(list, i,
1265 reinterpret_cast<PyObject*>(gradient_tensors[i]));
1266 }
1267 PyObject* arglist = Py_BuildValue("(O)", list);
1268 CHECK(arglist != nullptr);
1269 PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
1270 Py_DECREF(arglist);
1271 Py_DECREF(list);
1272 return result;
1273 }
1274
1275   tensorflow::int64 TensorId(PyObject* tensor) const final {
1276 return FastTensorId(tensor);
1277 }
1278
1279   void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
1280
1281   PyObject* Ones(PyObject* shape, PyObject* dtype) const {
1282 if (PyErr_Occurred()) {
1283 return nullptr;
1284 }
1285 PyObject* arg_list = Py_BuildValue("OO", shape, dtype);
1286 PyObject* result = PyEval_CallObject(ones_fn_, arg_list);
1287 Py_DECREF(arg_list);
1288 return result;
1289 }
1290
1291   PyObject* OnesLike(PyObject* tensor) const {
1292 if (PyErr_Occurred()) {
1293 return nullptr;
1294 }
1295 return PyObject_CallFunctionObjArgs(ones_like_fn_, tensor, NULL);
1296 }
1297
1298 // Builds a tensor filled with ones with the same shape and dtype as `t`.
1299   Status BuildOnesLike(const PyTapeTensor& t,
1300 PyObject** result) const override {
1301 *result = t.OnesLike();
1302 return Status::OK();
1303 }
1304
1305   PyObject* Zeros(PyObject* shape, PyObject* dtype) const {
1306 if (PyErr_Occurred()) {
1307 return nullptr;
1308 }
1309 PyObject* arg_list = Py_BuildValue("OO", shape, dtype);
1310 PyObject* result = PyEval_CallObject(zeros_fn_, arg_list);
1311 Py_DECREF(arg_list);
1312 return result;
1313 }
1314
1315   PyObject* ZerosLike(PyObject* tensor) const {
1316 if (PyErr_Occurred()) {
1317 return nullptr;
1318 }
1319 return PyObject_CallFunctionObjArgs(zeros_like_fn_, tensor, NULL);
1320 }
1321
1322   PyObject* GraphShape(PyObject* tensor) const {
1323 PyObject* arg_list = Py_BuildValue("(O)", tensor);
1324 PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list);
1325 Py_DECREF(arg_list);
1326 return result;
1327 }
1328
1329   tensorflow::Status CallBackwardFunction(
1330 const string& op_type, PyBackwardFunction* backward_function,
1331 const std::vector<tensorflow::int64>& unneeded_gradients,
1332 tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
1333 absl::Span<PyObject*> result) const final {
1334 PyObject* grads = PyTuple_New(output_gradients.size());
1335 for (int i = 0; i < output_gradients.size(); ++i) {
1336 if (output_gradients[i] == nullptr) {
1337 Py_INCREF(Py_None);
1338 PyTuple_SET_ITEM(grads, i, Py_None);
1339 } else {
1340 PyTuple_SET_ITEM(grads, i,
1341 reinterpret_cast<PyObject*>(output_gradients[i]));
1342 }
1343 }
1344 PyObject* py_result = (*backward_function)(grads, unneeded_gradients);
1345 Py_DECREF(grads);
1346 if (py_result == nullptr) {
1347 return tensorflow::errors::Internal("gradient function threw exceptions");
1348 }
1349 PyObject* seq =
1350 PySequence_Fast(py_result, "expected a sequence of gradients");
1351 if (seq == nullptr) {
1352 return tensorflow::errors::InvalidArgument(
1353 "gradient function did not return a list");
1354 }
1355 int len = PySequence_Fast_GET_SIZE(seq);
1356 if (len != result.size()) {
1357 return tensorflow::errors::Internal(
1358 "Recorded operation '", op_type,
1359 "' returned too few gradients. Expected ", result.size(),
1360 " but received ", len);
1361 }
1362 PyObject** seq_array = PySequence_Fast_ITEMS(seq);
1363 VLOG(1) << "Gradient length is " << len;
1364 for (int i = 0; i < len; ++i) {
1365 PyObject* item = seq_array[i];
1366 if (item == Py_None) {
1367 result[i] = nullptr;
1368 } else {
1369 Py_INCREF(item);
1370 result[i] = item;
1371 }
1372 }
1373 Py_DECREF(seq);
1374 Py_DECREF(py_result);
1375 return tensorflow::Status::OK();
1376 }
1377
1378   void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
1379
1380   PyTapeTensor TapeTensorFromGradient(PyObject* tensor) const final {
1381 return TapeTensorFromTensor(tensor);
1382 }
1383
1384 private:
1385 PyObject* py_vspace_;
1386
1387 PyObject* num_elements_;
1388 PyObject* aggregate_fn_;
1389 PyObject* zeros_fn_;
1390 PyObject* zeros_like_fn_;
1391 PyObject* ones_fn_;
1392 PyObject* ones_like_fn_;
1393 PyObject* graph_shape_fn_;
1394 };
1395 PyVSpace* py_vspace = nullptr;
1396
1397 bool HasAccumulator();
1398
1399 PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
1400 if (py_vspace != nullptr) {
1401 if (HasAccumulator()) {
1402 // Accumulators reference py_vspace, so we can't swap it out while one is
1403 // active. This is unlikely to ever happen.
1404 MaybeRaiseExceptionFromStatus(
1405 tensorflow::errors::Internal(
1406 "Can't change the vspace implementation while a "
1407 "forward accumulator is active."),
1408 nullptr);
1409 }
1410 delete py_vspace;
1411 }
1412
1413 py_vspace = new PyVSpace(e);
1414 auto status = py_vspace->Initialize();
1415 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
1416 delete py_vspace;
1417 return nullptr;
1418 }
1419
1420 Py_RETURN_NONE;
1421 }
1422
1423 PyObject* PyTapeTensor::GetShape() const {
1424 if (shape_.index() == 0) {
1425 auto& shape = absl::get<0>(shape_);
1426 PyObject* py_shape = PyTuple_New(shape.dims());
1427 for (int i = 0; i < shape.dims(); ++i) {
1428 PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
1429 }
1430
1431 return py_shape;
1432 }
1433
1434 return py_vspace->GraphShape(absl::get<1>(shape_));
1435 }
1436
1437 PyObject* PyTapeTensor::OnesLike() const {
1438 if (shape_.index() == 1) {
1439 PyObject* tensor = absl::get<1>(shape_);
1440 return py_vspace->OnesLike(tensor);
1441 }
1442 PyObject* py_shape = GetShape();
1443 PyObject* dtype_field = GetPyDType();
1444 PyObject* result = py_vspace->Ones(py_shape, dtype_field);
1445 Py_DECREF(dtype_field);
1446 Py_DECREF(py_shape);
1447 return result;
1448 }
1449
1450 PyObject* PyTapeTensor::ZerosLike() const {
1451 if (GetDType() == tensorflow::DT_RESOURCE) {
1452 // Gradient functions for ops which return resource tensors accept
1453 // None. This is the behavior of py_vspace->Zeros, but checking here avoids
1454 // issues with ZerosLike.
1455 Py_RETURN_NONE;
1456 }
1457 if (shape_.index() == 1) {
1458 PyObject* tensor = absl::get<1>(shape_);
1459 return py_vspace->ZerosLike(tensor);
1460 }
1461 PyObject* py_shape = GetShape();
1462 PyObject* dtype_field = GetPyDType();
1463 PyObject* result = py_vspace->Zeros(py_shape, dtype_field);
1464 Py_DECREF(dtype_field);
1465 Py_DECREF(py_shape);
1466 return result;
1467 }
1468
1469 // Keeps track of all variables that have been accessed during execution.
1470 class VariableWatcher {
1471 public:
1472   VariableWatcher() {}
1473
1474   ~VariableWatcher() {
1475 for (const IdAndVariable& v : watched_variables_) {
1476 Py_DECREF(v.variable);
1477 }
1478 }
1479
1480   tensorflow::int64 WatchVariable(PyObject* v) {
1481 tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
1482 if (handle == nullptr) {
1483 return -1;
1484 }
1485 int64_t id = FastTensorId(handle.get());
1486
1487 tensorflow::mutex_lock l(watched_variables_mu_);
1488 auto insert_result = watched_variables_.emplace(id, v);
1489
1490 if (insert_result.second) {
1491 // Only increment the reference count if we aren't already watching this
1492 // variable.
1493 Py_INCREF(v);
1494 }
1495
1496 return id;
1497 }
1498
1499   PyObject* GetVariablesAsPyTuple() {
1500 tensorflow::mutex_lock l(watched_variables_mu_);
1501 PyObject* result = PyTuple_New(watched_variables_.size());
1502 Py_ssize_t pos = 0;
1503 for (const IdAndVariable& id_and_variable : watched_variables_) {
1504 PyTuple_SET_ITEM(result, pos++, id_and_variable.variable);
1505 Py_INCREF(id_and_variable.variable);
1506 }
1507 return result;
1508 }
1509
1510 private:
1511 // We store an IdAndVariable in the map since the map needs to be locked
1512 // during insert, but should not call back into python during insert to avoid
1513 // deadlocking with the GIL.
1514 struct IdAndVariable {
1515 tensorflow::int64 id;
1516 PyObject* variable;
1517
1518 IdAndVariable(int64_t id, PyObject* variable)
1519 : id(id), variable(variable) {}
1520 };
1521 struct CompareById {
1522 bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) const {
1523 return lhs.id < rhs.id;
1524 }
1525 };
1526
1527 tensorflow::mutex watched_variables_mu_;
1528 std::set<IdAndVariable, CompareById> watched_variables_
1529 TF_GUARDED_BY(watched_variables_mu_);
1530 };
1531
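// A GradientTape specialized for Python objects which can optionally watch
// every variable accessed while the tape is active.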
1532 class GradientTape
1533 : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
1534 PyTapeTensor> {
1535 public:
1536 explicit GradientTape(bool persistent, bool watch_accessed_variables)
1537 : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
1538 PyTapeTensor>(persistent),
1539 watch_accessed_variables_(watch_accessed_variables) {}
1540
1541 virtual ~GradientTape() {}
1542
1543 void VariableAccessed(PyObject* v) {
1544 if (watch_accessed_variables_) {
1545 WatchVariable(v);
1546 }
1547 }
1548
1549 void WatchVariable(PyObject* v) {
1550 int64_t id = variable_watcher_.WatchVariable(v);
1551
1552 if (!PyErr_Occurred()) {
1553 this->Watch(id);
1554 }
1555 }
1556
1557 PyObject* GetVariablesAsPyTuple() {
1558 return variable_watcher_.GetVariablesAsPyTuple();
1559 }
1560
1561 private:
1562 bool watch_accessed_variables_;
1563 VariableWatcher variable_watcher_;
1564 };
1565
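// Forward-mode accumulator (JVP computation) specialized for Python objects.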
1566 typedef tensorflow::eager::ForwardAccumulator<PyObject, PyBackwardFunction,
1567 PyTapeTensor>
1568 ForwardAccumulator;
1569
1570 // Incremented when a GradientTape or accumulator is newly added to a set, and
1571 // used to enforce an ordering between them.
1572 std::atomic_uint_fast64_t tape_nesting_id_counter(0);
1573
1574 typedef struct {
1575 PyObject_HEAD
1576 /* Type-specific fields go here. */
1577 GradientTape* tape;
1578 // A nesting order between GradientTapes and ForwardAccumulators, used to
1579 // ensure that GradientTapes do not watch the products of outer
1580 // ForwardAccumulators.
1581 tensorflow::int64 nesting_id;
1582 } TFE_Py_Tape;
1583
1584 static void TFE_Py_Tape_Delete(PyObject* tape) {
1585 delete reinterpret_cast<TFE_Py_Tape*>(tape)->tape;
1586 Py_TYPE(tape)->tp_free(tape);
1587 }
1588
1589 static PyTypeObject TFE_Py_Tape_Type = {
1590 PyVarObject_HEAD_INIT(nullptr, 0) "tfe.Tape", /* tp_name */
1591 sizeof(TFE_Py_Tape), /* tp_basicsize */
1592 0, /* tp_itemsize */
1593 &TFE_Py_Tape_Delete, /* tp_dealloc */
1594 #if PY_VERSION_HEX < 0x03080000
1595 nullptr, /* tp_print */
1596 #else
1597 0, /* tp_vectorcall_offset */
1598 #endif
1599 nullptr, /* tp_getattr */
1600 nullptr, /* tp_setattr */
1601 nullptr, /* tp_reserved */
1602 nullptr, /* tp_repr */
1603 nullptr, /* tp_as_number */
1604 nullptr, /* tp_as_sequence */
1605 nullptr, /* tp_as_mapping */
1606 nullptr, /* tp_hash */
1607 nullptr, /* tp_call */
1608 nullptr, /* tp_str */
1609 nullptr, /* tp_getattro */
1610 nullptr, /* tp_setattro */
1611 nullptr, /* tp_as_buffer */
1612 Py_TPFLAGS_DEFAULT, /* tp_flags */
1613 "TFE_Py_Tape objects", /* tp_doc */
1614 };
1615
1616 typedef struct {
1617 PyObject_HEAD
1618 /* Type-specific fields go here. */
1619 ForwardAccumulator* accumulator;
1620 // A nesting order between GradientTapes and ForwardAccumulators, used to
1621 // ensure that GradientTapes do not watch the products of outer
1622 // ForwardAccumulators.
1623 tensorflow::int64 nesting_id;
1624 } TFE_Py_ForwardAccumulator;
1625
1626 static void TFE_Py_ForwardAccumulatorDelete(PyObject* accumulator) {
1627 delete reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)->accumulator;
1628 Py_TYPE(accumulator)->tp_free(accumulator);
1629 }
1630
1631 static PyTypeObject TFE_Py_ForwardAccumulator_Type = {
1632 PyVarObject_HEAD_INIT(nullptr, 0) "ForwardAccumulator", /* tp_name */
1633 sizeof(TFE_Py_ForwardAccumulator), /* tp_basicsize */
1634 0, /* tp_itemsize */
1635 &TFE_Py_ForwardAccumulatorDelete, /* tp_dealloc */
1636 #if PY_VERSION_HEX < 0x03080000
1637 nullptr, /* tp_print */
1638 #else
1639 0, /* tp_vectorcall_offset */
1640 #endif
1641 nullptr, /* tp_getattr */
1642 nullptr, /* tp_setattr */
1643 nullptr, /* tp_reserved */
1644 nullptr, /* tp_repr */
1645 nullptr, /* tp_as_number */
1646 nullptr, /* tp_as_sequence */
1647 nullptr, /* tp_as_mapping */
1648 nullptr, /* tp_hash */
1649 nullptr, /* tp_call */
1650 nullptr, /* tp_str */
1651 nullptr, /* tp_getattro */
1652 nullptr, /* tp_setattro */
1653 nullptr, /* tp_as_buffer */
1654 Py_TPFLAGS_DEFAULT, /* tp_flags */
1655 "TFE_Py_ForwardAccumulator objects", /* tp_doc */
1656 };
1657
1658 typedef struct {
1659 PyObject_HEAD
1660 /* Type-specific fields go here. */
1661 VariableWatcher* variable_watcher;
1662 } TFE_Py_VariableWatcher;
1663
1664 static void TFE_Py_VariableWatcher_Delete(PyObject* variable_watcher) {
1665 delete reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
1666 ->variable_watcher;
1667 Py_TYPE(variable_watcher)->tp_free(variable_watcher);
1668 }
1669
1670 static PyTypeObject TFE_Py_VariableWatcher_Type = {
1671 PyVarObject_HEAD_INIT(nullptr, 0) "tfe.VariableWatcher", /* tp_name */
1672 sizeof(TFE_Py_VariableWatcher), /* tp_basicsize */
1673 0, /* tp_itemsize */
1674 &TFE_Py_VariableWatcher_Delete, /* tp_dealloc */
1675 #if PY_VERSION_HEX < 0x03080000
1676 nullptr, /* tp_print */
1677 #else
1678 0, /* tp_vectorcall_offset */
1679 #endif
1680 nullptr, /* tp_getattr */
1681 nullptr, /* tp_setattr */
1682 nullptr, /* tp_reserved */
1683 nullptr, /* tp_repr */
1684 nullptr, /* tp_as_number */
1685 nullptr, /* tp_as_sequence */
1686 nullptr, /* tp_as_mapping */
1687 nullptr, /* tp_hash */
1688 nullptr, /* tp_call */
1689 nullptr, /* tp_str */
1690 nullptr, /* tp_getattro */
1691 nullptr, /* tp_setattro */
1692 nullptr, /* tp_as_buffer */
1693 Py_TPFLAGS_DEFAULT, /* tp_flags */
1694 "TFE_Py_VariableWatcher objects", /* tp_doc */
1695 };
1696
1697 // Note: in the current design no mutex is needed here because of the python
1698 // GIL, which is always held when any TFE_Py_* methods are called. We should
1699 // revisit this if/when we decide not to hold the GIL while manipulating the tape
1700 // stack.
1701 tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
1702 thread_local std::unique_ptr<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>
1703 tape_set = nullptr;
1704 if (tape_set == nullptr) {
1705 tape_set.reset(new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>);
1706 }
1707 return tape_set.get();
1708 }
1709
1710 tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>*
1711 GetVariableWatcherSet() {
1712 thread_local std::unique_ptr<
1713 tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>
1714 variable_watcher_set = nullptr;
1715 if (variable_watcher_set == nullptr) {
1716 variable_watcher_set.reset(
1717 new tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>);
1718 }
1719 return variable_watcher_set.get();
1720 }
1721
1722 // A linked hash set, where iteration is in insertion order.
1723 //
1724 // Nested accumulators rely on op recording happening in insertion order, so an
1725 // unordered data structure like CompactPointerSet is not suitable. Outer
1726 // accumulators need to observe operations first so they know to watch the inner
1727 // accumulator's jvp computation.
1728 //
1729 // Not thread safe.
1730 class AccumulatorSet {
1731 public:
1732 // Returns true if `element` was newly inserted, false if it already exists.
1733 bool insert(TFE_Py_ForwardAccumulator* element) {
1734 if (map_.find(element) != map_.end()) {
1735 return false;
1736 }
1737 ListType::iterator it = ordered_.insert(ordered_.end(), element);
1738 map_.insert(std::make_pair(element, it));
1739 return true;
1740 }
1741
1742 void erase(TFE_Py_ForwardAccumulator* element) {
1743 MapType::iterator existing = map_.find(element);
1744 if (existing == map_.end()) {
1745 return;
1746 }
1747 ListType::iterator list_position = existing->second;
1748 map_.erase(existing);
1749 ordered_.erase(list_position);
1750 }
1751
1752 bool empty() const { return ordered_.empty(); }
1753
1754 size_t size() const { return ordered_.size(); }
1755
1756 private:
1757 typedef std::list<TFE_Py_ForwardAccumulator*> ListType;
1758 typedef tensorflow::gtl::FlatMap<TFE_Py_ForwardAccumulator*,
1759 ListType::iterator>
1760 MapType;
1761
1762 public:
1763 typedef ListType::const_iterator const_iterator;
1764 typedef ListType::const_reverse_iterator const_reverse_iterator;
1765
1766 const_iterator begin() const { return ordered_.begin(); }
1767 const_iterator end() const { return ordered_.end(); }
1768
1769 const_reverse_iterator rbegin() const { return ordered_.rbegin(); }
1770 const_reverse_iterator rend() const { return ordered_.rend(); }
1771
1772 private:
1773 MapType map_;
1774 ListType ordered_;
1775 };
1776
1777 AccumulatorSet* GetAccumulatorSet() {
1778 thread_local std::unique_ptr<AccumulatorSet> accumulator_set{nullptr};
1779 if (accumulator_set == nullptr) {
1780 accumulator_set.reset(new AccumulatorSet);
1781 }
1782 return accumulator_set.get();
1783 }
1784
1785 inline bool HasAccumulator() { return !GetAccumulatorSet()->empty(); }
1786
1787 inline bool HasGradientTape() { return !GetTapeSet()->empty(); }
1788
1789 inline bool HasAccumulatorOrTape() {
1790 return HasGradientTape() || HasAccumulator();
1791 }
1792
1793 // A safe copy of a set, used for tapes and accumulators. The copy is not
1794 // affected by other python threads changing the set of active tapes.
1795 template <typename ContainerType>
1796 class SafeSetCopy {
1797 public:
1798 explicit SafeSetCopy(const ContainerType& to_copy) : set_copy_(to_copy) {
1799 for (auto* member : set_copy_) {
1800 Py_INCREF(member);
1801 }
1802 }
1803
1804 ~SafeSetCopy() {
1805 for (auto* member : set_copy_) {
1806 Py_DECREF(member);
1807 }
1808 }
1809
1810 typename ContainerType::const_iterator begin() const {
1811 return set_copy_.begin();
1812 }
1813
1814 typename ContainerType::const_iterator end() const { return set_copy_.end(); }
1815
1816 bool empty() const { return set_copy_.empty(); }
1817 size_t size() const { return set_copy_.size(); }
1818
1819 protected:
1820 ContainerType set_copy_;
1821 };
1822
1823 class SafeTapeSet
1824 : public SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>> {
1825 public:
1826 SafeTapeSet()
1827 : SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>(
1828 *GetTapeSet()) {}
1829 };
1830
1831 class SafeAccumulatorSet : public SafeSetCopy<AccumulatorSet> {
1832 public:
1833 SafeAccumulatorSet() : SafeSetCopy<AccumulatorSet>(*GetAccumulatorSet()) {}
1834
1835 typename AccumulatorSet::const_reverse_iterator rbegin() const {
1836 return set_copy_.rbegin();
1837 }
1838
1839 typename AccumulatorSet::const_reverse_iterator rend() const {
1840 return set_copy_.rend();
1841 }
1842 };
1843
1844 class SafeVariableWatcherSet
1845 : public SafeSetCopy<
1846 tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> {
1847 public:
1848 SafeVariableWatcherSet()
1849 : SafeSetCopy<
1850 tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>(
1851 *GetVariableWatcherSet()) {}
1852 };
1853
1854 bool* ThreadTapeIsStopped() {
1855 thread_local bool thread_tape_is_stopped{false};
1856 return &thread_tape_is_stopped;
1857 }
1858
1859 void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }
1860
1861 void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }
1862
1863 PyObject* TFE_Py_TapeSetIsStopped() {
1864 if (*ThreadTapeIsStopped()) {
1865 Py_RETURN_TRUE;
1866 }
1867 Py_RETURN_FALSE;
1868 }
1869
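// Creates a new tape, adds it to the thread-local active set (which holds a
// reference until TFE_Py_TapeSetRemove), and returns it to the caller.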
1870 PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
1871 PyObject* watch_accessed_variables) {
1872 TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
1873 if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
1874 TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
1875 tape->tape = new GradientTape(persistent == Py_True,
1876 watch_accessed_variables == Py_True);
1877 Py_INCREF(tape);
1878 tape->nesting_id = tape_nesting_id_counter.fetch_add(1);
1879 GetTapeSet()->insert(tape);
1880 return reinterpret_cast<PyObject*>(tape);
1881 }
1882
1883 void TFE_Py_TapeSetAdd(PyObject* tape) {
1884 Py_INCREF(tape);
1885 TFE_Py_Tape* tfe_tape = reinterpret_cast<TFE_Py_Tape*>(tape);
1886 if (!GetTapeSet()->insert(tfe_tape).second) {
1887 // Already exists in the tape set.
1888 Py_DECREF(tape);
1889 } else {
1890 tfe_tape->nesting_id = tape_nesting_id_counter.fetch_add(1);
1891 }
1892 }
1893
1894 PyObject* TFE_Py_TapeSetIsEmpty() {
1895 if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
1896 Py_RETURN_TRUE;
1897 }
1898 Py_RETURN_FALSE;
1899 }
1900
1901 void TFE_Py_TapeSetRemove(PyObject* tape) {
1902 auto* stack = GetTapeSet();
1903 stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
1904 // We kept a reference to the tape in the set to ensure it wouldn't get
1905 // deleted under us; cleaning it up here.
1906 Py_DECREF(tape);
1907 }
1908
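// Converts a Python sequence of integers into a vector. Non-integer entries
// are recorded as -1; a None argument yields an empty vector.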
1909 static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
1910 if (list == Py_None) {
1911 return {};
1912 }
1913 PyObject* seq = PySequence_Fast(list, "expected a sequence");
1914 if (seq == nullptr) {
1915 return {};
1916 }
1917 int len = PySequence_Size(list);
1918 PyObject** seq_array = PySequence_Fast_ITEMS(seq);
1919 std::vector<tensorflow::int64> tensor_ids;
1920 tensor_ids.reserve(len);
1921 for (int i = 0; i < len; ++i) {
1922 PyObject* item = seq_array[i];
1923 #if PY_MAJOR_VERSION >= 3
1924 if (PyLong_Check(item)) {
1925 #else
1926 if (PyLong_Check(item) || PyInt_Check(item)) {
1927 #endif
1928 int64_t id = MakeInt(item);
1929 tensor_ids.push_back(id);
1930 } else {
1931 tensor_ids.push_back(-1);
1932 }
1933 }
1934 Py_DECREF(seq);
1935 return tensor_ids;
1936 }
1937
1938 // Fill `tensor_ids` and `dtypes` from `tensors`, none of which may be
1939 // null. Returns true on success and false on a Python exception.
1940 bool TensorShapesAndDtypes(PyObject* tensors,
1941 std::vector<tensorflow::int64>* tensor_ids,
1942 std::vector<tensorflow::DataType>* dtypes) {
1943 tensorflow::Safe_PyObjectPtr seq(
1944 PySequence_Fast(tensors, "expected a sequence"));
1945 if (seq == nullptr) {
1946 return false;
1947 }
1948 int len = PySequence_Fast_GET_SIZE(seq.get());
1949 PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
1950 tensor_ids->reserve(len);
1951 dtypes->reserve(len);
1952 for (int i = 0; i < len; ++i) {
1953 PyObject* item = seq_array[i];
1954 tensor_ids->push_back(FastTensorId(item));
1955 dtypes->push_back(tensorflow::PyTensor_DataType(item));
1956 }
1957 return true;
1958 }
1959
1960 bool TapeCouldPossiblyRecord(PyObject* tensors) {
1961 if (tensors == Py_None) {
1962 return false;
1963 }
1964 if (*ThreadTapeIsStopped()) {
1965 return false;
1966 }
1967 if (!HasAccumulatorOrTape()) {
1968 return false;
1969 }
1970 return true;
1971 }
1972
1973 bool CouldBackprop() { return !*ThreadTapeIsStopped() && HasGradientTape(); }
1974
1975 bool CouldForwardprop() { return !*ThreadTapeIsStopped() && HasAccumulator(); }
1976
1977 PyObject* TFE_Py_TapeSetShouldRecordBackprop(PyObject* tensors) {
1978 if (!TapeCouldPossiblyRecord(tensors) || !CouldBackprop()) {
1979 Py_RETURN_FALSE;
1980 }
1981 // TODO(apassos) consider not building a list and changing the API to check
1982 // each tensor individually.
1983 std::vector<tensorflow::int64> tensor_ids;
1984 std::vector<tensorflow::DataType> dtypes;
1985 if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
1986 return nullptr;
1987 }
1988 auto tape_set = *GetTapeSet();
1989 for (TFE_Py_Tape* tape : tape_set) {
1990 if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
1991 Py_RETURN_TRUE;
1992 }
1993 }
1994
1995 Py_RETURN_FALSE;
1996 }
1997
1998 PyObject* TFE_Py_ForwardAccumulatorPushState() {
1999 auto forward_accumulators = *GetAccumulatorSet();
2000 for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
2001 accumulator->accumulator->PushState();
2002 }
2003 Py_RETURN_NONE;
2004 }
2005
2006 PyObject* TFE_Py_ForwardAccumulatorPopState() {
2007 auto forward_accumulators = *GetAccumulatorSet();
2008 for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
2009 accumulator->accumulator->PopState();
2010 }
2011 Py_RETURN_NONE;
2012 }
2013
2014 PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) {
2015 if (!TapeCouldPossiblyRecord(tensors)) {
2016 return GetPythonObjectFromInt(0);
2017 }
2018 std::vector<tensorflow::int64> tensor_ids;
2019 std::vector<tensorflow::DataType> dtypes;
2020 if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
2021 return nullptr;
2022 }
2023
2024 // If there is a persistent tape watching, or if there are multiple tapes
2025 // watching, we'll return immediately indicating that higher-order tape
2026 // gradients are possible.
2027 bool some_tape_watching = false;
2028 if (CouldBackprop()) {
2029 auto tape_set = *GetTapeSet();
2030 for (TFE_Py_Tape* tape : tape_set) {
2031 if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
2032 if (tape->tape->IsPersistent() || some_tape_watching) {
2033 // Either this is the second tape watching, or this tape is
2034 // persistent: higher-order gradients are possible.
2035 return GetPythonObjectFromInt(2);
2036 }
2037 some_tape_watching = true;
2038 }
2039 }
2040 }
2041 if (CouldForwardprop()) {
2042 auto forward_accumulators = *GetAccumulatorSet();
2043 for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
2044 if (accumulator->accumulator->ShouldRecord(tensor_ids, dtypes)) {
2045 if (some_tape_watching) {
2046 // This is the second tape watching: higher-order gradients are
2047 // possible. Note that there's no equivalent of persistence for
2048 // forward-mode.
2049 return GetPythonObjectFromInt(2);
2050 }
2051 some_tape_watching = true;
2052 }
2053 }
2054 }
2055 if (some_tape_watching) {
2056 // There's exactly one non-persistent tape or accumulator watching. The user can
2057 // request first-order gradients but won't be able to get higher-order tape gradients.
2058 return GetPythonObjectFromInt(1);
2059 } else {
2060 // There are no tapes. The user can't request tape gradients.
2061 return GetPythonObjectFromInt(0);
2062 }
2063 }
2064
2065 void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
2066 if (!CouldBackprop()) {
2067 return;
2068 }
2069 int64_t tensor_id = FastTensorId(tensor);
2070 if (PyErr_Occurred()) {
2071 return;
2072 }
2073 reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
2074 }
2075
2076 bool ListContainsNone(PyObject* list) {
2077 if (list == Py_None) return true;
2078 tensorflow::Safe_PyObjectPtr seq(
2079 PySequence_Fast(list, "expected a sequence"));
2080 if (seq == nullptr) {
2081 return false;
2082 }
2083
2084 int len = PySequence_Size(list);
2085 PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
2086 for (int i = 0; i < len; ++i) {
2087 PyObject* item = seq_array[i];
2088 if (item == Py_None) return true;
2089 }
2090
2091 return false;
2092 }
2093
2094 // As an optimization, the tape generally keeps only the shape and dtype of
2095 // tensors, and uses this information to generate ones/zeros tensors. However,
2096 // some tensors require OnesLike/ZerosLike because their gradients do not match
2097 // their inference shape/dtype.
2098 bool DTypeNeedsHandleData(tensorflow::DataType dtype) {
2099 return dtype == tensorflow::DT_VARIANT || dtype == tensorflow::DT_RESOURCE;
2100 }
2101
2102 static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
2103 if (EagerTensor_CheckExact(tensor)) {
2104 tensorflow::ImmediateExecutionTensorHandle* handle =
2105 tensorflow::unwrap(EagerTensor_Handle(tensor));
2106 int64_t id = PyEagerTensor_ID(tensor);
2107 tensorflow::DataType dtype =
2108 static_cast<tensorflow::DataType>(handle->DataType());
2109 if (DTypeNeedsHandleData(dtype)) {
2110 return PyTapeTensor(id, dtype, tensor);
2111 }
2112
2113 tensorflow::TensorShape tensor_shape;
2114 int num_dims;
2115 tensorflow::Status status = handle->NumDims(&num_dims);
2116 if (status.ok()) {
2117 for (int i = 0; i < num_dims; ++i) {
2118 int64_t dim_size;
2119 status = handle->Dim(i, &dim_size);
2120 if (!status.ok()) break;
2121 tensor_shape.AddDim(dim_size);
2122 }
2123 }
2124
2125 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
2126 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2127 tensorflow::TensorShape({}));
2128 } else {
2129 return PyTapeTensor(id, dtype, tensor_shape);
2130 }
2131 }
2132 int64_t id = FastTensorId(tensor);
2133 if (PyErr_Occurred()) {
2134 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2135 tensorflow::TensorShape({}));
2136 }
2137 PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
2138 PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
2139 Py_DECREF(dtype_object);
2140 tensorflow::DataType dtype =
2141 static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
2142 Py_DECREF(dtype_enum);
2143 if (PyErr_Occurred()) {
2144 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2145 tensorflow::TensorShape({}));
2146 }
2147 static char _shape_tuple[] = "_shape_tuple";
2148 tensorflow::Safe_PyObjectPtr shape_tuple(
2149 PyObject_CallMethod(tensor, _shape_tuple, nullptr));
2150 if (PyErr_Occurred()) {
2151 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2152 tensorflow::TensorShape({}));
2153 }
2154
2155 if (ListContainsNone(shape_tuple.get()) || DTypeNeedsHandleData(dtype)) {
2156 return PyTapeTensor(id, dtype, tensor);
2157 }
2158
2159 auto l = MakeIntList(shape_tuple.get());
2160 // Replace -1 (which represents accidental Nones that can occur in graph mode
2161 // and can cause errors in shape construction) with 0.
2162 for (auto& c : l) {
2163 if (c < 0) {
2164 c = 0;
2165 }
2166 }
2167 tensorflow::TensorShape shape(l);
2168 return PyTapeTensor(id, dtype, shape);
2169 }
2170
2171 // Populates output_info from output_seq, which must come from PySequence_Fast.
2172 //
2173 // Does not take ownership of output_seq. Returns true on success and false if a
2174 // Python exception has been set.
2175 bool TapeTensorsFromTensorSequence(PyObject* output_seq,
2176 std::vector<PyTapeTensor>* output_info) {
2177 Py_ssize_t output_len = PySequence_Fast_GET_SIZE(output_seq);
2178 PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq);
2179 output_info->reserve(output_len);
2180 for (Py_ssize_t i = 0; i < output_len; ++i) {
2181 output_info->push_back(TapeTensorFromTensor(output_seq_array[i]));
2182 if (PyErr_Occurred() != nullptr) {
2183 return false;
2184 }
2185 }
2186 return true;
2187 }
2188
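// Returns the ids of the tensors in `tensors`, which must be a sequence. If a
// Python exception is raised, the partially-built list is returned.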
2189 std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
2190 PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
2191 if (seq == nullptr) {
2192 return {};
2193 }
2194 int len = PySequence_Fast_GET_SIZE(seq);
2195 PyObject** seq_array = PySequence_Fast_ITEMS(seq);
2196 std::vector<tensorflow::int64> list;
2197 list.reserve(len);
2198 for (int i = 0; i < len; ++i) {
2199 PyObject* tensor = seq_array[i];
2200 list.push_back(FastTensorId(tensor));
2201 if (PyErr_Occurred()) {
2202 Py_DECREF(seq);
2203 return list;
2204 }
2205 }
2206 Py_DECREF(seq);
2207 return list;
2208 }
2209
2210 void TFE_Py_TapeVariableAccessed(PyObject* variable) {
2211 if (!CouldBackprop()) {
2212 return;
2213 }
2214 for (TFE_Py_Tape* tape : SafeTapeSet()) {
2215 tape->tape->VariableAccessed(variable);
2216 }
2217 }
2218
2219 void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) {
2220 if (!CouldBackprop()) {
2221 return;
2222 }
2223 reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable);
2224 }
2225
2226 PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
2227 return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
2228 }
2229
2230 PyObject* TFE_Py_VariableWatcherNew() {
2231 TFE_Py_VariableWatcher_Type.tp_new = PyType_GenericNew;
2232 if (PyType_Ready(&TFE_Py_VariableWatcher_Type) < 0) return nullptr;
2233 TFE_Py_VariableWatcher* variable_watcher =
2234 PyObject_NEW(TFE_Py_VariableWatcher, &TFE_Py_VariableWatcher_Type);
2235 variable_watcher->variable_watcher = new VariableWatcher();
2236 Py_INCREF(variable_watcher);
2237 GetVariableWatcherSet()->insert(variable_watcher);
2238 return reinterpret_cast<PyObject*>(variable_watcher);
2239 }
2240
2241 void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher) {
2242 auto* stack = GetVariableWatcherSet();
2243 stack->erase(reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher));
2244 // We kept a reference to the variable watcher in the set to ensure it
2245 // wouldn't get deleted under us; cleaning it up here.
2246 Py_DECREF(variable_watcher);
2247 }
2248
2249 void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable) {
2250 for (TFE_Py_VariableWatcher* variable_watcher : SafeVariableWatcherSet()) {
2251 variable_watcher->variable_watcher->WatchVariable(variable);
2252 }
2253 }
2254
2255 PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher) {
2256 return reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
2257 ->variable_watcher->GetVariablesAsPyTuple();
2258 }
2259
2260 namespace {
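// Returns the dtype of each tensor in `tensors`, which must be a sequence.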
2261 std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
2262 PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
2263 if (seq == nullptr) {
2264 return {};
2265 }
2266 int len = PySequence_Fast_GET_SIZE(seq);
2267 PyObject** seq_array = PySequence_Fast_ITEMS(seq);
2268 std::vector<tensorflow::DataType> list;
2269 list.reserve(len);
2270 for (int i = 0; i < len; ++i) {
2271 PyObject* tensor = seq_array[i];
2272 list.push_back(tensorflow::PyTensor_DataType(tensor));
2273 }
2274 Py_DECREF(seq);
2275 return list;
2276 }
2277
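// Weakref callback invoked when a watched tensor is garbage collected: drops
// the corresponding JVP from every active accumulator. Consumes the references
// to `tensor_id` and `weak_tensor_ref` that the callback owns.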
2278 PyObject* ForwardAccumulatorDeleteGradient(PyObject* tensor_id,
2279 PyObject* weak_tensor_ref) {
2280 int64_t parsed_tensor_id = MakeInt(tensor_id);
2281 for (TFE_Py_ForwardAccumulator* accumulator : *GetAccumulatorSet()) {
2282 accumulator->accumulator->DeleteGradient(parsed_tensor_id);
2283 }
2284 Py_DECREF(weak_tensor_ref);
2285 Py_DECREF(tensor_id);
2286 Py_INCREF(Py_None);
2287 return Py_None;
2288 }
2289
2290 static PyMethodDef forward_accumulator_delete_gradient_method_def = {
2291 "ForwardAccumulatorDeleteGradient", ForwardAccumulatorDeleteGradient,
2292 METH_O, "ForwardAccumulatorDeleteGradient"};
2293
2294 void RegisterForwardAccumulatorCleanup(PyObject* tensor, int64_t tensor_id) {
2295 tensorflow::Safe_PyObjectPtr callback(
2296 PyCFunction_New(&forward_accumulator_delete_gradient_method_def,
2297 PyLong_FromLong(tensor_id)));
2298 // We need to keep a reference to the weakref active if we want our callback
2299 // called. The callback itself now owns the weakref object and the tensor ID
2300 // object.
2301 PyWeakref_NewRef(tensor, callback.get());
2302 }
2303
2304 void TapeSetRecordBackprop(
2305 const string& op_type, const std::vector<PyTapeTensor>& output_info,
2306 const std::vector<tensorflow::int64>& input_ids,
2307 const std::vector<tensorflow::DataType>& input_dtypes,
2308 const std::function<PyBackwardFunction*()>& backward_function_getter,
2309 const std::function<void(PyBackwardFunction*)>& backward_function_killer,
2310 tensorflow::uint64 max_gradient_tape_id) {
2311 if (!CouldBackprop()) {
2312 return;
2313 }
2314 for (TFE_Py_Tape* tape : SafeTapeSet()) {
2315 if (tape->nesting_id < max_gradient_tape_id) {
2316 tape->tape->RecordOperation(op_type, output_info, input_ids, input_dtypes,
2317 backward_function_getter,
2318 backward_function_killer);
2319 }
2320 }
2321 }
2322
2323 bool TapeSetRecordForwardprop(
2324 const string& op_type, PyObject* output_seq,
2325 const std::vector<PyTapeTensor>& output_info, PyObject* input_tensors,
2326 const std::vector<tensorflow::int64>& input_ids,
2327 const std::vector<tensorflow::DataType>& input_dtypes,
2328 const std::function<PyBackwardFunction*()>& backward_function_getter,
2329 const std::function<void(PyBackwardFunction*)>& backward_function_killer,
2330 const tensorflow::eager::ForwardFunction<PyObject>* forward_function,
2331 PyObject* forwardprop_output_indices,
2332 tensorflow::uint64* max_gradient_tape_id) {
2333 *max_gradient_tape_id = std::numeric_limits<tensorflow::uint64>::max();
2334 if (!CouldForwardprop()) {
2335 return true;
2336 }
2337 auto accumulator_set = SafeAccumulatorSet();
2338 tensorflow::Safe_PyObjectPtr input_seq(
2339 PySequence_Fast(input_tensors, "expected a sequence of tensors"));
2340 if (input_seq == nullptr || PyErr_Occurred()) return false;
2341 Py_ssize_t input_len = PySequence_Fast_GET_SIZE(input_seq.get());
2342 PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq);
2343 for (int i = 0; i < output_info.size(); ++i) {
2344 RegisterForwardAccumulatorCleanup(output_seq_array[i],
2345 output_info[i].GetID());
2346 }
2347 if (forwardprop_output_indices != nullptr &&
2348 forwardprop_output_indices != Py_None) {
2349 tensorflow::Safe_PyObjectPtr indices_fast(PySequence_Fast(
2350 forwardprop_output_indices, "Expected a sequence of indices"));
2351 if (indices_fast == nullptr || PyErr_Occurred()) {
2352 return false;
2353 }
2354 if (PySequence_Fast_GET_SIZE(indices_fast.get()) !=
2355 accumulator_set.size()) {
2356 MaybeRaiseExceptionFromStatus(
2357 tensorflow::errors::Internal(
2358 "Accumulators were added or removed from the active set "
2359 "between packing and unpacking."),
2360 nullptr);
2361 }
2362 PyObject** indices_fast_array = PySequence_Fast_ITEMS(indices_fast.get());
2363 Py_ssize_t accumulator_index = 0;
2364 for (AccumulatorSet::const_reverse_iterator it = accumulator_set.rbegin();
2365 it != accumulator_set.rend(); ++it, ++accumulator_index) {
2366 tensorflow::Safe_PyObjectPtr jvp_index_seq(
2367 PySequence_Fast(indices_fast_array[accumulator_index],
2368 "Expected a sequence of jvp indices."));
2369 if (jvp_index_seq == nullptr || PyErr_Occurred()) {
2370 return false;
2371 }
2372 Py_ssize_t num_jvps = PySequence_Fast_GET_SIZE(jvp_index_seq.get());
2373 PyObject** jvp_index_seq_array =
2374 PySequence_Fast_ITEMS(jvp_index_seq.get());
2375 for (Py_ssize_t jvp_index = 0; jvp_index < num_jvps; ++jvp_index) {
2376 PyObject* tuple = jvp_index_seq_array[jvp_index];
2377 int64_t primal_tensor_id =
2378 output_info[MakeInt(PyTuple_GetItem(tuple, 0))].GetID();
2379 (*it)->accumulator->Watch(
2380 primal_tensor_id,
2381 output_seq_array[MakeInt(PyTuple_GetItem(tuple, 1))]);
2382 }
2383 }
2384 } else {
2385 std::vector<PyTapeTensor> input_info;
2386 input_info.reserve(input_len);
2387 PyObject** input_seq_array = PySequence_Fast_ITEMS(input_seq.get());
2388 for (Py_ssize_t i = 0; i < input_len; ++i) {
2389 input_info.push_back(TapeTensorFromTensor(input_seq_array[i]));
2390 }
2391 for (TFE_Py_ForwardAccumulator* accumulator : accumulator_set) {
2392 tensorflow::Status status = accumulator->accumulator->Accumulate(
2393 op_type, input_info, output_info, input_ids, input_dtypes,
2394 forward_function, backward_function_getter, backward_function_killer);
2395 if (PyErr_Occurred()) return false; // Don't swallow Python exceptions.
2396 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
2397 return false;
2398 }
2399 if (accumulator->accumulator->BusyAccumulating()) {
2400 // Ensure inner accumulators don't see outer accumulators' jvps. This
2401 // mostly happens on its own, with some potentially surprising
2402 // exceptions, so the blanket policy is for consistency.
2403 *max_gradient_tape_id = accumulator->nesting_id;
2404 break;
2405 }
2406 }
2407 }
2408 return true;
2409 }
2410
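// Packs `input_tangents` into a new Python tuple, substituting None for null
// entries.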
2411 PyObject* TangentsAsPyTuple(const std::vector<PyObject*>& input_tangents) {
2412 PyObject* py_input_tangents = PyTuple_New(input_tangents.size());
2413 for (int i = 0; i < input_tangents.size(); ++i) {
2414 PyObject* element;
2415 if (input_tangents[i] == nullptr) {
2416 element = Py_None;
2417 } else {
2418 element = input_tangents[i];
2419 }
2420 Py_INCREF(element);
2421 PyTuple_SET_ITEM(py_input_tangents, i, element);
2422 }
2423 return py_input_tangents;
2424 }
2425
2426 tensorflow::Status ParseTangentOutputs(
2427 PyObject* user_output, std::vector<PyObject*>* output_tangents) {
2428 if (user_output == Py_None) {
2429 // No connected gradients.
2430 return tensorflow::Status::OK();
2431 }
2432 tensorflow::Safe_PyObjectPtr fast_result(
2433 PySequence_Fast(user_output, "expected a sequence of forward gradients"));
2434 if (fast_result == nullptr) {
2435 return tensorflow::errors::InvalidArgument(
2436 "forward gradient function did not return a sequence.");
2437 }
2438 int len = PySequence_Fast_GET_SIZE(fast_result.get());
2439 PyObject** fast_result_array = PySequence_Fast_ITEMS(fast_result.get());
2440 output_tangents->reserve(len);
2441 for (int i = 0; i < len; ++i) {
2442 PyObject* item = fast_result_array[i];
2443 if (item == Py_None) {
2444 output_tangents->push_back(nullptr);
2445 } else {
2446 Py_INCREF(item);
2447 output_tangents->push_back(item);
2448 }
2449 }
2450 return tensorflow::Status::OK();
2451 }
2452
2453 // Calls the registered forward_gradient_function, computing `output_tangents`
2454 // from `input_tangents`. `output_tangents` must not be null.
2455 //
2456 // `op_name`, `attrs`, `inputs`, and `results` describe the operation for which
2457 // the forward function is being called.
2458 tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs,
2459 PyObject* inputs, PyObject* results,
2460 const std::vector<PyObject*>& input_tangents,
2461 std::vector<PyObject*>* output_tangents,
2462 bool use_batch) {
2463 if (forward_gradient_function == nullptr) {
2464 return tensorflow::errors::Internal(
2465 "No forward gradient function registered.");
2466 }
2467 tensorflow::Safe_PyObjectPtr py_input_tangents(
2468 TangentsAsPyTuple(input_tangents));
2469
2470 // Normalize the input sequence to a tuple so it works with function
2471 // caching; otherwise it may be an opaque _InputList object.
2472 tensorflow::Safe_PyObjectPtr input_tuple(PySequence_Tuple(inputs));
2473 PyObject* to_batch = (use_batch) ? Py_True : Py_False;
2474 tensorflow::Safe_PyObjectPtr callback_args(
2475 Py_BuildValue("OOOOOO", op_name, attrs, input_tuple.get(), results,
2476 py_input_tangents.get(), to_batch));
2477 tensorflow::Safe_PyObjectPtr py_result(
2478 PyObject_CallObject(forward_gradient_function, callback_args.get()));
2479 if (py_result == nullptr || PyErr_Occurred()) {
2480 return tensorflow::errors::Internal(
2481 "forward gradient function threw exceptions");
2482 }
2483 return ParseTangentOutputs(py_result.get(), output_tangents);
2484 }
2485
2486 // Like CallJVPFunction, but calls a pre-bound forward function.
2487 // These are passed in from a record_gradient argument.
2488 tensorflow::Status CallOpSpecificJVPFunction(
2489 PyObject* op_specific_forward_function,
2490 const std::vector<PyObject*>& input_tangents,
2491 std::vector<PyObject*>* output_tangents) {
2492 tensorflow::Safe_PyObjectPtr py_input_tangents(
2493 TangentsAsPyTuple(input_tangents));
2494
2495 tensorflow::Safe_PyObjectPtr py_result(PyObject_CallObject(
2496 op_specific_forward_function, py_input_tangents.get()));
2497 if (py_result == nullptr || PyErr_Occurred()) {
2498 return tensorflow::errors::Internal(
2499 "forward gradient function threw exceptions");
2500 }
2501 return ParseTangentOutputs(py_result.get(), output_tangents);
2502 }
2503
2504 bool ParseOpTypeString(PyObject* op_type, string* op_type_string) {
2505 if (PyBytes_Check(op_type)) {
2506 *op_type_string = PyBytes_AsString(op_type);
2507 } else if (PyUnicode_Check(op_type)) {
2508 #if PY_MAJOR_VERSION >= 3
2509 *op_type_string = PyUnicode_AsUTF8(op_type);
2510 #else
2511 PyObject* py_str = PyUnicode_AsUTF8String(op_type);
2512 if (py_str == nullptr) {
2513 return false;
2514 }
2515 *op_type_string = PyBytes_AS_STRING(py_str);
2516 Py_DECREF(py_str);
2517 #endif
2518 } else {
2519 PyErr_SetString(PyExc_RuntimeError, "op_type should be a string.");
2520 return false;
2521 }
2522 return true;
2523 }
2524
2525 bool TapeSetRecordOperation(
2526 PyObject* op_type, PyObject* input_tensors, PyObject* output_tensors,
2527 const std::vector<tensorflow::int64>& input_ids,
2528 const std::vector<tensorflow::DataType>& input_dtypes,
2529 const std::function<PyBackwardFunction*()>& backward_function_getter,
2530 const std::function<void(PyBackwardFunction*)>& backward_function_killer,
2531 const tensorflow::eager::ForwardFunction<PyObject>* forward_function) {
2532 std::vector<PyTapeTensor> output_info;
2533 tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2534 output_tensors, "expected a sequence of integer tensor ids"));
2535 if (PyErr_Occurred() ||
2536 !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2537 return false;
2538 }
2539 string op_type_str;
2540 if (!ParseOpTypeString(op_type, &op_type_str)) {
2541 return false;
2542 }
2543 tensorflow::uint64 max_gradient_tape_id;
2544 if (!TapeSetRecordForwardprop(
2545 op_type_str, output_seq.get(), output_info, input_tensors, input_ids,
2546 input_dtypes, backward_function_getter, backward_function_killer,
2547 forward_function, nullptr /* No special-cased jvps. */,
2548 &max_gradient_tape_id)) {
2549 return false;
2550 }
2551 TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes,
2552 backward_function_getter, backward_function_killer,
2553 max_gradient_tape_id);
2554 return true;
2555 }
2556 } // namespace
2557
2558 PyObject* TFE_Py_TapeSetRecordOperation(PyObject* op_type,
2559 PyObject* output_tensors,
2560 PyObject* input_tensors,
2561 PyObject* backward_function,
2562 PyObject* forward_function) {
2563 if (!HasAccumulatorOrTape() || *ThreadTapeIsStopped()) {
2564 Py_RETURN_NONE;
2565 }
2566 std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
2567 if (PyErr_Occurred()) return nullptr;
2568
2569 std::vector<tensorflow::DataType> input_dtypes =
2570 MakeTensorDtypeList(input_tensors);
2571 if (PyErr_Occurred()) return nullptr;
2572
2573 std::function<PyBackwardFunction*()> backward_function_getter(
2574 [backward_function]() {
2575 Py_INCREF(backward_function);
2576 PyBackwardFunction* function = new PyBackwardFunction(
2577 [backward_function](PyObject* out_grads,
2578 const std::vector<tensorflow::int64>& unused) {
2579 return PyObject_CallObject(backward_function, out_grads);
2580 });
2581 return function;
2582 });
2583 std::function<void(PyBackwardFunction*)> backward_function_killer(
2584 [backward_function](PyBackwardFunction* py_backward_function) {
2585 Py_DECREF(backward_function);
2586 delete py_backward_function;
2587 });
2588
2589 if (forward_function == Py_None) {
2590 if (!TapeSetRecordOperation(
2591 op_type, input_tensors, output_tensors, input_ids, input_dtypes,
2592 backward_function_getter, backward_function_killer,
2593 nullptr /* No special-cased forward function */)) {
2594 return nullptr;
2595 }
2596 } else {
2597 tensorflow::eager::ForwardFunction<PyObject> wrapped_forward_function(
2598 [forward_function](const std::vector<PyObject*>& input_tangents,
2599 std::vector<PyObject*>* output_tangents,
2600 bool use_batch = false) {
2601 return CallOpSpecificJVPFunction(forward_function, input_tangents,
2602 output_tangents);
2603 });
2604 if (!TapeSetRecordOperation(
2605 op_type, input_tensors, output_tensors, input_ids, input_dtypes,
2606 backward_function_getter, backward_function_killer,
2607 &wrapped_forward_function)) {
2608 return nullptr;
2609 }
2610 }
2611 Py_RETURN_NONE;
2612 }
2613
2614 PyObject* TFE_Py_TapeSetRecordOperationForwardprop(
2615 PyObject* op_type, PyObject* output_tensors, PyObject* input_tensors,
2616 PyObject* backward_function, PyObject* forwardprop_output_indices) {
2617 if (!HasAccumulator() || *ThreadTapeIsStopped()) {
2618 Py_RETURN_NONE;
2619 }
2620 std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
2621 if (PyErr_Occurred()) return nullptr;
2622
2623 std::vector<tensorflow::DataType> input_dtypes =
2624 MakeTensorDtypeList(input_tensors);
2625 if (PyErr_Occurred()) return nullptr;
2626
2627 std::function<PyBackwardFunction*()> backward_function_getter(
2628 [backward_function]() {
2629 Py_INCREF(backward_function);
2630 PyBackwardFunction* function = new PyBackwardFunction(
2631 [backward_function](PyObject* out_grads,
2632 const std::vector<tensorflow::int64>& unused) {
2633 return PyObject_CallObject(backward_function, out_grads);
2634 });
2635 return function;
2636 });
2637 std::function<void(PyBackwardFunction*)> backward_function_killer(
2638 [backward_function](PyBackwardFunction* py_backward_function) {
2639 Py_DECREF(backward_function);
2640 delete py_backward_function;
2641 });
2642 std::vector<PyTapeTensor> output_info;
2643 tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2644 output_tensors, "expected a sequence of integer tensor ids"));
2645 if (PyErr_Occurred() ||
2646 !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2647 return nullptr;
2648 }
2649 string op_type_str;
2650 if (!ParseOpTypeString(op_type, &op_type_str)) {
2651 return nullptr;
2652 }
2653 tensorflow::uint64 max_gradient_tape_id;
2654 if (!TapeSetRecordForwardprop(
2655 op_type_str, output_seq.get(), output_info, input_tensors, input_ids,
2656 input_dtypes, backward_function_getter, backward_function_killer,
2657 nullptr /* no special-cased forward function */,
2658 forwardprop_output_indices, &max_gradient_tape_id)) {
2659 return nullptr;
2660 }
2661 Py_RETURN_NONE;
2662 }
2663
2664 PyObject* TFE_Py_TapeSetRecordOperationBackprop(PyObject* op_type,
2665 PyObject* output_tensors,
2666 PyObject* input_tensors,
2667 PyObject* backward_function) {
2668 if (!CouldBackprop()) {
2669 Py_RETURN_NONE;
2670 }
2671 std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
2672 if (PyErr_Occurred()) return nullptr;
2673
2674 std::vector<tensorflow::DataType> input_dtypes =
2675 MakeTensorDtypeList(input_tensors);
2676 if (PyErr_Occurred()) return nullptr;
2677
2678 std::function<PyBackwardFunction*()> backward_function_getter(
2679 [backward_function]() {
2680 Py_INCREF(backward_function);
2681 PyBackwardFunction* function = new PyBackwardFunction(
2682 [backward_function](PyObject* out_grads,
2683 const std::vector<tensorflow::int64>& unused) {
2684 return PyObject_CallObject(backward_function, out_grads);
2685 });
2686 return function;
2687 });
2688 std::function<void(PyBackwardFunction*)> backward_function_killer(
2689 [backward_function](PyBackwardFunction* py_backward_function) {
2690 Py_DECREF(backward_function);
2691 delete py_backward_function;
2692 });
2693 std::vector<PyTapeTensor> output_info;
2694 tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2695 output_tensors, "expected a sequence of integer tensor ids"));
2696 if (PyErr_Occurred() ||
2697 !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2698 return nullptr;
2699 }
2700 string op_type_str;
2701 if (!ParseOpTypeString(op_type, &op_type_str)) {
2702 return nullptr;
2703 }
2704 TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes,
2705 backward_function_getter, backward_function_killer,
2706 // No filtering based on relative ordering with forward
2707 // accumulators.
2708 std::numeric_limits<tensorflow::uint64>::max());
2709 Py_RETURN_NONE;
2710 }
2711
2712 void TFE_Py_TapeSetDeleteTrace(int64_t tensor_id) {
2713 for (TFE_Py_Tape* tape : *GetTapeSet()) {
2714 tape->tape->DeleteTrace(tensor_id);
2715 }
2716 }
2717
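// Returns borrowed references to the tensors in `tensors`, which must be a
// sequence; returns an empty vector on error.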
2718 std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
2719 PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
2720 if (seq == nullptr) {
2721 return {};
2722 }
2723 int len = PySequence_Fast_GET_SIZE(seq);
2724 PyObject** seq_array = PySequence_Fast_ITEMS(seq);
2725 std::vector<PyObject*> list(seq_array, seq_array + len);
2726 Py_DECREF(seq);
2727 return list;
2728 }
2729
2730 PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
2731 PyObject* sources, PyObject* output_gradients,
2732 PyObject* sources_raw,
2733 PyObject* unconnected_gradients,
2734 TF_Status* status) {
2735 TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
2736 if (!tape_obj->tape->IsPersistent()) {
2737 auto* tape_set = GetTapeSet();
2738 if (tape_set->find(tape_obj) != tape_set->end()) {
2739 PyErr_SetString(PyExc_RuntimeError,
2740 "gradient() cannot be invoked within the "
2741 "GradientTape context (i.e., while operations are being "
2742 "recorded). Either move the call to gradient() to be "
2743 "outside the 'with tf.GradientTape' block, or "
2744 "use a persistent tape: "
2745 "'with tf.GradientTape(persistent=true)'");
2746 return nullptr;
2747 }
2748 }
2749
2750 std::vector<tensorflow::int64> target_vec = MakeTensorIDList(target);
2751 if (PyErr_Occurred()) {
2752 return nullptr;
2753 }
2754 std::vector<tensorflow::int64> sources_vec = MakeTensorIDList(sources);
2755 if (PyErr_Occurred()) {
2756 return nullptr;
2757 }
2758 tensorflow::gtl::FlatSet<tensorflow::int64> sources_set(sources_vec.begin(),
2759 sources_vec.end());
2760
2761 tensorflow::Safe_PyObjectPtr seq =
2762 tensorflow::make_safe(PySequence_Fast(target, "expected a sequence"));
2763 int len = PySequence_Fast_GET_SIZE(seq.get());
2764 PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
2765 std::unordered_map<tensorflow::int64, PyTapeTensor>
2766 source_tensors_that_are_targets;
2767 for (int i = 0; i < len; ++i) {
2768 int64_t target_id = target_vec[i];
2769 if (sources_set.find(target_id) != sources_set.end()) {
2770 auto tensor = seq_array[i];
2771 source_tensors_that_are_targets.insert(
2772 std::make_pair(target_id, TapeTensorFromTensor(tensor)));
2773 }
2774 if (PyErr_Occurred()) {
2775 return nullptr;
2776 }
2777 }
2778 if (PyErr_Occurred()) {
2779 return nullptr;
2780 }
2781
2782 std::vector<PyObject*> outgrad_vec;
2783 if (output_gradients != Py_None) {
2784 outgrad_vec = MakeTensorList(output_gradients);
2785 if (PyErr_Occurred()) {
2786 return nullptr;
2787 }
2788 for (PyObject* tensor : outgrad_vec) {
2789 // Calling the backward function will eat a reference to the tensors in
2790 // outgrad_vec, so we need to increase their reference count.
2791 Py_INCREF(tensor);
2792 }
2793 }
2794 std::vector<PyObject*> result(sources_vec.size());
2795 status->status = tape_obj->tape->ComputeGradient(
2796 *py_vspace, target_vec, sources_vec, source_tensors_that_are_targets,
2797 outgrad_vec, absl::MakeSpan(result));
2798 if (!status->status.ok()) {
2799 if (PyErr_Occurred()) {
2800 // Do not propagate the erroneous status as that would swallow the
2801 // exception which caused the problem.
2802 status->status = tensorflow::Status::OK();
2803 }
2804 return nullptr;
2805 }
2806
2807 bool unconnected_gradients_zero =
2808 strcmp(TFE_GetPythonString(unconnected_gradients), "zero") == 0;
2809 std::vector<PyObject*> sources_obj;
2810 if (unconnected_gradients_zero) {
2811 // Uses the "raw" sources here so it can properly make a zeros tensor even
2812 // if there are resource variables as sources.
2813 sources_obj = MakeTensorList(sources_raw);
2814 }
2815
2816 if (!result.empty()) {
2817 PyObject* py_result = PyList_New(result.size());
2818 tensorflow::gtl::FlatSet<PyObject*> seen_results(result.size());
2819 for (int i = 0; i < result.size(); ++i) {
2820 if (result[i] == nullptr) {
2821 if (unconnected_gradients_zero) {
2822 // generate a zeros tensor in the shape of sources[i]
2823 tensorflow::DataType dtype =
2824 tensorflow::PyTensor_DataType(sources_obj[i]);
2825 PyTapeTensor tensor =
2826 PyTapeTensor(sources_vec[i], dtype, sources_obj[i]);
2827 result[i] = tensor.ZerosLike();
2828 } else {
2829 Py_INCREF(Py_None);
2830 result[i] = Py_None;
2831 }
2832 } else if (seen_results.find(result[i]) != seen_results.end()) {
2833 Py_INCREF(result[i]);
2834 }
2835 seen_results.insert(result[i]);
2836 PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]));
2837 }
2838 return py_result;
2839 }
2840 return PyList_New(0);
2841 }
2842
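// Creates a new forward accumulator. A vspace must already have been
// registered via TFE_Py_RegisterVSpace.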
2843 PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch) {
2844 TFE_Py_ForwardAccumulator_Type.tp_new = PyType_GenericNew;
2845 if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr;
2846 TFE_Py_ForwardAccumulator* accumulator =
2847 PyObject_NEW(TFE_Py_ForwardAccumulator, &TFE_Py_ForwardAccumulator_Type);
2848 if (py_vspace == nullptr) {
2849 MaybeRaiseExceptionFromStatus(
2850 tensorflow::errors::Internal(
2851 "ForwardAccumulator requires a PyVSpace to be registered."),
2852 nullptr);
2853 }
2854 accumulator->accumulator = new ForwardAccumulator(*py_vspace, use_batch);
2855 return reinterpret_cast<PyObject*>(accumulator);
2856 }
2857
2858 PyObject* TFE_Py_ForwardAccumulatorSetAdd(PyObject* accumulator) {
2859 TFE_Py_ForwardAccumulator* c_accumulator(
2860 reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
2861 c_accumulator->nesting_id = tape_nesting_id_counter.fetch_add(1);
2862 if (GetAccumulatorSet()->insert(c_accumulator)) {
2863 Py_INCREF(accumulator);
2864 Py_RETURN_NONE;
2865 } else {
2866 MaybeRaiseExceptionFromStatus(
2867 tensorflow::errors::Internal(
2868 "A ForwardAccumulator was added to the active set twice."),
2869 nullptr);
2870 return nullptr;
2871 }
2872 }
2873
2874 void TFE_Py_ForwardAccumulatorSetRemove(PyObject* accumulator) {
2875 GetAccumulatorSet()->erase(
2876 reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
2877 Py_DECREF(accumulator);
2878 }
2879
2880 void TFE_Py_ForwardAccumulatorWatch(PyObject* accumulator, PyObject* tensor,
2881 PyObject* tangent) {
2882 int64_t tensor_id = FastTensorId(tensor);
2883 reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
2884 ->accumulator->Watch(tensor_id, tangent);
2885 RegisterForwardAccumulatorCleanup(tensor, tensor_id);
2886 }
2887
2888 // Returns a new reference to the JVP Tensor.
2889 PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator,
2890 PyObject* tensor) {
2891 PyObject* jvp = reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
2892 ->accumulator->FetchJVP(FastTensorId(tensor));
2893 if (jvp == nullptr) {
2894 jvp = Py_None;
2895 }
2896 Py_INCREF(jvp);
2897 return jvp;
2898 }
2899
2900 PyObject* TFE_Py_PackJVPs(PyObject* tensors) {
2901 if (!TapeCouldPossiblyRecord(tensors)) {
2902 tensorflow::Safe_PyObjectPtr empty_tuple(PyTuple_New(0));
2903 tensorflow::Safe_PyObjectPtr empty_list(PyList_New(0));
2904 return PyTuple_Pack(2, empty_tuple.get(), empty_list.get());
2905 }
2906 auto accumulators = *GetAccumulatorSet();
2907 tensorflow::Safe_PyObjectPtr tensors_fast(
2908 PySequence_Fast(tensors, "Expected a sequence of input Tensors."));
2909 if (tensors_fast == nullptr || PyErr_Occurred()) {
2910 return nullptr;
2911 }
2912 std::vector<tensorflow::int64> augmented_input_ids;
2913 int len = PySequence_Fast_GET_SIZE(tensors_fast.get());
2914 PyObject** tensors_fast_array = PySequence_Fast_ITEMS(tensors_fast.get());
2915 for (Py_ssize_t position = 0; position < len; ++position) {
2916 PyObject* input = tensors_fast_array[position];
2917 if (input == Py_None) {
2918 continue;
2919 }
2920 tensorflow::DataType input_dtype(tensorflow::PyTensor_DataType(input));
2921 if (input_dtype == tensorflow::DT_INVALID) {
2922 return nullptr;
2923 }
2924 augmented_input_ids.push_back(FastTensorId(input));
2925 }
2926 if (PyErr_Occurred()) {
2927 return nullptr;
2928 }
2929 // Find the innermost accumulator such that all outer accumulators are
2930 // recording. Any more deeply nested accumulators will not have their JVPs
2931 // saved.
2932 AccumulatorSet::const_iterator innermost_all_recording = accumulators.begin();
2933 for (; innermost_all_recording != accumulators.end();
2934 ++innermost_all_recording) {
2935 if ((*innermost_all_recording)->accumulator->BusyAccumulating()) {
2936 break;
2937 }
2938 }
2939 AccumulatorSet::const_reverse_iterator reverse_innermost_all_recording(
2940 innermost_all_recording);
2941
2942 bool saving_jvps = false;
2943 tensorflow::Safe_PyObjectPtr all_indices(PyTuple_New(accumulators.size()));
2944 std::vector<PyObject*> new_tensors;
2945 Py_ssize_t accumulator_index = 0;
2946 // Start with the innermost accumulators to give outer accumulators a chance
2947 // to find their higher-order JVPs.
2948 for (AccumulatorSet::const_reverse_iterator it = accumulators.rbegin();
2949 it != accumulators.rend(); ++it, ++accumulator_index) {
2950 std::vector<tensorflow::int64> new_input_ids;
2951 std::vector<std::pair<tensorflow::int64, tensorflow::int64>>
2952 accumulator_indices;
2953 if (it == reverse_innermost_all_recording) {
2954 saving_jvps = true;
2955 }
2956 if (saving_jvps) {
2957 for (int input_index = 0; input_index < augmented_input_ids.size();
2958 ++input_index) {
2959 int64_t existing_input = augmented_input_ids[input_index];
2960 PyObject* jvp = (*it)->accumulator->FetchJVP(existing_input);
2961 if (jvp != nullptr) {
2962 new_tensors.push_back(jvp);
2963 new_input_ids.push_back(FastTensorId(jvp));
2964 accumulator_indices.emplace_back(
2965 input_index,
2966 augmented_input_ids.size() + new_input_ids.size() - 1);
2967 }
2968 }
2969 }
2970 tensorflow::Safe_PyObjectPtr accumulator_indices_py(
2971 PyTuple_New(accumulator_indices.size()));
2972 for (int i = 0; i < accumulator_indices.size(); ++i) {
2973 tensorflow::Safe_PyObjectPtr from_index(
2974 GetPythonObjectFromInt(accumulator_indices[i].first));
2975 tensorflow::Safe_PyObjectPtr to_index(
2976 GetPythonObjectFromInt(accumulator_indices[i].second));
2977 PyTuple_SetItem(accumulator_indices_py.get(), i,
2978 PyTuple_Pack(2, from_index.get(), to_index.get()));
2979 }
2980 PyTuple_SetItem(all_indices.get(), accumulator_index,
2981 accumulator_indices_py.release());
2982 augmented_input_ids.insert(augmented_input_ids.end(), new_input_ids.begin(),
2983 new_input_ids.end());
2984 }
2985
2986 tensorflow::Safe_PyObjectPtr new_tensors_py(PyList_New(new_tensors.size()));
2987 for (int i = 0; i < new_tensors.size(); ++i) {
2988 PyObject* jvp = new_tensors[i];
2989 Py_INCREF(jvp);
2990 PyList_SET_ITEM(new_tensors_py.get(), i, jvp);
2991 }
2992 return PyTuple_Pack(2, all_indices.get(), new_tensors_py.get());
2993 }
2994
2995 namespace {
2996
2997 // Indices for the "args" tuple that's passed to TFE_Py_FastPathExecute_C.
2998 enum FastPathExecuteArgIndex {
2999 FAST_PATH_EXECUTE_ARG_CONTEXT = 0,
3000 FAST_PATH_EXECUTE_ARG_OP_NAME = 1,
3001 FAST_PATH_EXECUTE_ARG_NAME = 2,
3002 FAST_PATH_EXECUTE_ARG_INPUT_START = 3
3003 };
3004
3005 PyObject* GetPythonObjectFromString(tensorflow::StringPiece s) {
3006 #if PY_MAJOR_VERSION >= 3
3007 return PyUnicode_FromStringAndSize(s.data(), s.size());
3008 #else
3009 return PyBytes_FromStringAndSize(s.data(), s.size());
3010 #endif
3011 }
3012
3013 bool CheckResourceVariable(PyObject* item) {
3014 if (tensorflow::swig::IsResourceVariable(item)) {
3015 tensorflow::Safe_PyObjectPtr handle(
3016 PyObject_GetAttrString(item, "_handle"));
3017 return EagerTensor_CheckExact(handle.get());
3018 }
3019
3020 return false;
3021 }
3022
3023 bool IsNumberType(PyObject* item) {
3024 #if PY_MAJOR_VERSION >= 3
3025 return PyFloat_Check(item) || PyLong_Check(item);
3026 #else
3027 return PyFloat_Check(item) || PyInt_Check(item) || PyLong_Check(item);
3028 #endif
3029 }
3030
3031 bool CheckOneInput(PyObject* item) {
3032 if (EagerTensor_CheckExact(item) || CheckResourceVariable(item) ||
3033 PyArray_Check(item) || IsNumberType(item)) {
3034 return true;
3035 }
3036
3037 // Sequences are not fully handled here: a sequence containing only Python
3038 // numeric types works, but a sequence mixing EagerTensors with Python
3039 // numeric types does not.
3040 // TODO(nareshmodi): fix
3041 return false;
3042 }
3043
3044 bool CheckInputsOk(PyObject* seq, int start_index,
3045 const tensorflow::OpDef& op_def) {
3046 for (int i = 0; i < op_def.input_arg_size(); i++) {
3047 PyObject* item = PyTuple_GET_ITEM(seq, i + start_index);
3048 if (!op_def.input_arg(i).number_attr().empty() ||
3049 !op_def.input_arg(i).type_list_attr().empty()) {
3050 // This item should be a seq input.
3051 if (!PySequence_Check(item)) {
3052 VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
3053 << "\", Input \"" << op_def.input_arg(i).name()
3054 << "\" since we expected a sequence, but got "
3055 << item->ob_type->tp_name;
3056 return false;
3057 }
3058 tensorflow::Safe_PyObjectPtr fast_item(
3059 PySequence_Fast(item, "Could not parse sequence."));
3060 if (fast_item.get() == nullptr) {
3061 return false;
3062 }
3063 int len = PySequence_Fast_GET_SIZE(fast_item.get());
3064 PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get());
3065 for (Py_ssize_t j = 0; j < len; j++) {
3066 PyObject* inner_item = fast_item_array[j];
3067 if (!CheckOneInput(inner_item)) {
3068 VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
3069 << "\", Input \"" << op_def.input_arg(i).name()
3070 << "\", Index " << j
3071 << " since we expected an EagerTensor/ResourceVariable, "
3072 "but got "
3073 << inner_item->ob_type->tp_name;
3074 return false;
3075 }
3076 }
3077 } else if (!CheckOneInput(item)) {
3078 VLOG(1)
3079 << "Falling back to slow path for Op \"" << op_def.name()
3080 << "\", Input \"" << op_def.input_arg(i).name()
3081 << "\" since we expected an EagerTensor/ResourceVariable, but got "
3082 << item->ob_type->tp_name;
3083 return false;
3084 }
3085 }
3086
3087 return true;
3088 }
3089
3090 tensorflow::DataType MaybeGetDType(PyObject* item) {
3091 if (EagerTensor_CheckExact(item) || CheckResourceVariable(item)) {
3092 return tensorflow::PyTensor_DataType(item);
3093 }
3094
3095 return tensorflow::DT_INVALID;
3096 }
3097
3098 tensorflow::DataType MaybeGetDTypeForAttr(const string& attr,
3099 FastPathOpExecInfo* op_exec_info) {
3100 auto cached_it = op_exec_info->cached_dtypes.find(attr);
3101 if (cached_it != op_exec_info->cached_dtypes.end()) {
3102 return cached_it->second;
3103 }
3104
3105 auto it = op_exec_info->attr_to_inputs_map->find(attr);
3106 if (it == op_exec_info->attr_to_inputs_map->end()) {
3107 // No inputs are associated with this attr; this should never happen.
3108 return tensorflow::DT_INVALID;
3109 }
3110
3111 for (const auto& input_info : it->second) {
3112 PyObject* item = PyTuple_GET_ITEM(
3113 op_exec_info->args, FAST_PATH_EXECUTE_ARG_INPUT_START + input_info.i);
3114 if (input_info.is_list) {
3115 tensorflow::Safe_PyObjectPtr fast_item(
3116 PySequence_Fast(item, "Unable to allocate"));
3117 int len = PySequence_Fast_GET_SIZE(fast_item.get());
3118 PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get());
3119 for (int i = 0; i < len; i++) {
3120 auto dtype = MaybeGetDType(fast_item_array[i]);
3121 if (dtype != tensorflow::DT_INVALID) return dtype;
3122 }
3123 } else {
3124 auto dtype = MaybeGetDType(item);
3125 if (dtype != tensorflow::DT_INVALID) return dtype;
3126 }
3127 }
3128
3129 auto default_it = op_exec_info->default_dtypes->find(attr);
3130 if (default_it != op_exec_info->default_dtypes->end()) {
3131 return default_it->second;
3132 }
3133
3134 return tensorflow::DT_INVALID;
3135 }
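// Illustrative example (hypothetical op, not taken from this file): if inputs
// x and y both carry type attr "T" and the call provides a Python float for x
// and a float32 EagerTensor for y, the float yields DT_INVALID while the
// EagerTensor yields DT_FLOAT, so DT_FLOAT is returned for "T" (and is later
// cached via the dtype_setter passed to ConvertToTensor in AddInputToOp).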
3136
3137 PyObject* CopySequenceSettingIndicesToNull(
3138 PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) {
3139 tensorflow::Safe_PyObjectPtr fast_seq(
3140 PySequence_Fast(seq, "unable to allocate"));
3141 int len = PySequence_Fast_GET_SIZE(fast_seq.get());
3142 PyObject** fast_seq_array = PySequence_Fast_ITEMS(fast_seq.get());
3143 PyObject* result = PyTuple_New(len);
3144 for (int i = 0; i < len; i++) {
3145 PyObject* item;
3146 if (indices.find(i) != indices.end()) {
3147 item = Py_None;
3148 } else {
3149 item = fast_seq_array[i];
3150 }
3151 Py_INCREF(item);
3152 PyTuple_SET_ITEM(result, i, item);
3153 }
3154 return result;
3155 }
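// Example (illustrative): CopySequenceSettingIndicesToNull((a, b, c), {1})
// yields the new tuple (a, None, c). Kept items get an extra reference, and
// the caller owns the returned tuple.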
3156
3157 PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
3158 PyObject* results,
3159 PyObject* forward_pass_name_scope = nullptr) {
3160 std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
3161 if (PyErr_Occurred()) return nullptr;
3162 std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
3163 if (PyErr_Occurred()) return nullptr;
3164
3165 bool should_record = false;
3166 for (TFE_Py_Tape* tape : SafeTapeSet()) {
3167 if (tape->tape->ShouldRecord(input_ids, input_dtypes)) {
3168 should_record = true;
3169 break;
3170 }
3171 }
3172 if (!should_record) {
3173 for (TFE_Py_ForwardAccumulator* accumulator : SafeAccumulatorSet()) {
3174 if (accumulator->accumulator->ShouldRecord(input_ids, input_dtypes)) {
3175 should_record = true;
3176 break;
3177 }
3178 }
3179 }
3180 if (!should_record) Py_RETURN_NONE;
3181
3182 string c_op_name = TFE_GetPythonString(op_name);
3183
3184 PyObject* op_outputs;
3185 bool op_outputs_tuple_created = false;
3186
3187 if (const auto unused_output_indices =
3188 OpGradientUnusedOutputIndices(c_op_name)) {
3189 if (unused_output_indices->empty()) {
3190 op_outputs = Py_None;
3191 } else {
3192 op_outputs_tuple_created = true;
3193 op_outputs =
3194 CopySequenceSettingIndicesToNull(results, *unused_output_indices);
3195 }
3196 } else {
3197 op_outputs = results;
3198 }
3199
3200 PyObject* op_inputs;
3201 bool op_inputs_tuple_created = false;
3202
3203 if (const auto unused_input_indices =
3204 OpGradientUnusedInputIndices(c_op_name)) {
3205 if (unused_input_indices->empty()) {
3206 op_inputs = Py_None;
3207 } else {
3208 op_inputs_tuple_created = true;
3209 op_inputs =
3210 CopySequenceSettingIndicesToNull(inputs, *unused_input_indices);
3211 }
3212 } else {
3213 op_inputs = inputs;
3214 }
3215
3216 tensorflow::eager::ForwardFunction<PyObject> py_forward_function(
3217 [op_name, attrs, inputs, results](
3218 const std::vector<PyObject*>& input_tangents,
3219 std::vector<PyObject*>* output_tangents, bool use_batch) {
3220 return CallJVPFunction(op_name, attrs, inputs, results, input_tangents,
3221 output_tangents, use_batch);
3222 });
3223 tensorflow::eager::ForwardFunction<PyObject>* forward_function;
3224 if (c_op_name == "While" || c_op_name == "StatelessWhile" ||
3225 c_op_name == "If" || c_op_name == "StatelessIf") {
3226 // Control flow contains non-hashable attributes. Handling them in Python is
3227 // a headache, so instead we'll stay as close to GradientTape's handling as
3228 // possible (a null forward function means the accumulator forwards to a
3229 // tape).
3230 //
3231 // This is safe to do since we'll only see control flow when graph building,
3232 // in which case we can rely on pruning.
3233 forward_function = nullptr;
3234 } else {
3235 forward_function = &py_forward_function;
3236 }
3237
3238 PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));
3239
3240 if (!forward_pass_name_scope) forward_pass_name_scope = Py_None;
3241
3242 TapeSetRecordOperation(
3243 op_name, inputs, results, input_ids, input_dtypes,
3244 [op_name, attrs, num_inputs, op_inputs, op_outputs,
3245 forward_pass_name_scope]() {
3246 Py_INCREF(op_name);
3247 Py_INCREF(attrs);
3248 Py_INCREF(num_inputs);
3249 Py_INCREF(op_inputs);
3250 Py_INCREF(op_outputs);
3251 Py_INCREF(forward_pass_name_scope);
3252 PyBackwardFunction* function = new PyBackwardFunction(
3253 [op_name, attrs, num_inputs, op_inputs, op_outputs,
3254 forward_pass_name_scope](
3255 PyObject* output_grads,
3256 const std::vector<tensorflow::int64>& unneeded_gradients) {
3257 if (PyErr_Occurred()) {
3258 return static_cast<PyObject*>(nullptr);
3259 }
3260 tensorflow::Safe_PyObjectPtr skip_input_indices;
3261 if (!unneeded_gradients.empty()) {
3262 skip_input_indices.reset(
3263 PyTuple_New(unneeded_gradients.size()));
3264 for (int i = 0; i < unneeded_gradients.size(); i++) {
3265 PyTuple_SET_ITEM(
3266 skip_input_indices.get(), i,
3267 GetPythonObjectFromInt(unneeded_gradients[i]));
3268 }
3269 } else {
3270 Py_INCREF(Py_None);
3271 skip_input_indices.reset(Py_None);
3272 }
3273 tensorflow::Safe_PyObjectPtr callback_args(Py_BuildValue(
3274 "OOOOOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs,
3275 output_grads, skip_input_indices.get(),
3276 forward_pass_name_scope));
3277
3278 tensorflow::Safe_PyObjectPtr result(
3279 PyObject_CallObject(gradient_function, callback_args.get()));
3280
3281 if (PyErr_Occurred()) return static_cast<PyObject*>(nullptr);
3282
3283 return tensorflow::swig::Flatten(result.get());
3284 });
3285 return function;
3286 },
3287 [op_name, attrs, num_inputs, op_inputs, op_outputs,
3288 forward_pass_name_scope](PyBackwardFunction* backward_function) {
3289 Py_DECREF(op_name);
3290 Py_DECREF(attrs);
3291 Py_DECREF(num_inputs);
3292 Py_DECREF(op_inputs);
3293 Py_DECREF(op_outputs);
3294 Py_DECREF(forward_pass_name_scope);
3295
3296 delete backward_function;
3297 },
3298 forward_function);
3299
3300 Py_DECREF(num_inputs);
3301 if (op_outputs_tuple_created) Py_DECREF(op_outputs);
3302 if (op_inputs_tuple_created) Py_DECREF(op_inputs);
3303
3304 if (PyErr_Occurred()) {
3305 return nullptr;
3306 }
3307
3308 Py_RETURN_NONE;
3309 }
3310
3311 void MaybeNotifyVariableAccessed(PyObject* input) {
3312 DCHECK(CheckResourceVariable(input));
3313 DCHECK(PyObject_HasAttrString(input, "_trainable"));
3314
3315 tensorflow::Safe_PyObjectPtr trainable(
3316 PyObject_GetAttrString(input, "_trainable"));
3317 if (trainable.get() == Py_False) return;
3318 TFE_Py_TapeVariableAccessed(input);
3319 TFE_Py_VariableWatcherVariableAccessed(input);
3320 }
3321
3322 bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
3323 PyObject* input, tensorflow::Safe_PyObjectPtr* output,
3324 TF_Status* status) {
3325 MaybeNotifyVariableAccessed(input);
3326
3327 TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status);
3328 auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); });
3329 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
3330
3331 TFE_OpSetDevice(op, parent_op_exec_info.device_name, status);
3332 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
3333
3334 // Set dtype
3335 DCHECK(PyObject_HasAttrString(input, "_dtype"));
3336 tensorflow::Safe_PyObjectPtr dtype(PyObject_GetAttrString(input, "_dtype"));
3337 int value;
3338 if (!ParseTypeValue("_dtype", dtype.get(), status, &value)) {
3339 return false;
3340 }
3341 TFE_OpSetAttrType(op, "dtype", static_cast<TF_DataType>(value));
3342
3343 // Get handle
3344 tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(input, "_handle"));
3345 if (!EagerTensor_CheckExact(handle.get())) return false;
3346 TFE_OpAddInput(op, EagerTensor_Handle(handle.get()), status);
3347 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
3348
3349 int num_retvals = 1;
3350 TFE_TensorHandle* output_handle;
3351 TFE_Execute(op, &output_handle, &num_retvals, status);
3352 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
3353
3354 // Always create the py object (and correctly DECREF it) from the returned
3355 // value, else the data will leak.
3356 output->reset(EagerTensorFromHandle(output_handle));
3357
3358 // TODO(nareshmodi): Should we run post exec callbacks here?
3359 if (parent_op_exec_info.run_gradient_callback) {
3360 tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(1));
3361 PyTuple_SET_ITEM(inputs.get(), 0, handle.release());
3362
3363 tensorflow::Safe_PyObjectPtr outputs(PyTuple_New(1));
3364 Py_INCREF(output->get()); // stay alive after since tuple steals.
3365 PyTuple_SET_ITEM(outputs.get(), 0, output->get());
3366
3367 tensorflow::Safe_PyObjectPtr op_string(
3368 GetPythonObjectFromString("ReadVariableOp"));
3369 if (!RecordGradient(op_string.get(), inputs.get(), Py_None,
3370 outputs.get())) {
3371 return false;
3372 }
3373 }
3374
3375 return true;
3376 }
3377
3378 // Supports 3 cases at the moment:
3379 // i) input is an EagerTensor: it is used as-is.
3380 // ii) input is a ResourceVariable: it is read via ReadVariableOp.
3381 // iii) input is an arbitrary python object (e.g. a list/tuple or a scalar):
3382 //      it is converted with ConvertToEagerTensor (note that this handling
3383 //      doesn't support packing).
3384 //
3385 // NOTE: dtype_hint_getter is only consulted for case iii); it should return
3386 // the dtype to use as a conversion hint, or DT_INVALID if no hint is
3387 // available.
3388 //
3389 // NOTE: On failure this function sets a Python error directly and returns
3390 // false; the TF_Status is only passed in so we don't have to reallocate one.
3391 bool ConvertToTensor(
3392 const FastPathOpExecInfo& op_exec_info, PyObject* input,
3393 tensorflow::Safe_PyObjectPtr* output_handle,
3394 // This gets a hint for this particular input.
3395 const std::function<tensorflow::DataType()>& dtype_hint_getter,
3396 // This sets the dtype after conversion is complete.
3397 const std::function<void(const tensorflow::DataType dtype)>& dtype_setter,
3398 TF_Status* status) {
3399 if (EagerTensor_CheckExact(input)) {
3400 Py_INCREF(input);
3401 output_handle->reset(input);
3402 return true;
3403 } else if (CheckResourceVariable(input)) {
3404 return ReadVariableOp(op_exec_info, input, output_handle, status);
3405 }
3406
3407 // The hint comes from a supposedly similarly typed tensor.
3408 tensorflow::DataType dtype_hint = dtype_hint_getter();
3409
3410 TFE_TensorHandle* handle = tensorflow::ConvertToEagerTensor(
3411 op_exec_info.ctx, input, dtype_hint, op_exec_info.device_name);
3412 if (handle == nullptr) {
3413 return MaybeRaiseExceptionFromTFStatus(status, nullptr);
3414 }
3415
3416 output_handle->reset(EagerTensorFromHandle(handle));
3417 dtype_setter(
3418 static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(handle)));
3419
3420 return true;
3421 }
3422
3423 // Adds input and type attr to the op, and to the list of flattened
3424 // inputs/attrs.
3425 bool AddInputToOp(FastPathOpExecInfo* op_exec_info, PyObject* input,
3426 const bool add_type_attr,
3427 const tensorflow::OpDef::ArgDef& input_arg,
3428 std::vector<tensorflow::Safe_PyObjectPtr>* flattened_attrs,
3429 std::vector<tensorflow::Safe_PyObjectPtr>* flattened_inputs,
3430 TFE_Op* op, TF_Status* status) {
3431 // Ownership of py_eager_tensor is transferred to flattened_inputs if that
3432 // vector is present; otherwise the object is DECREF'd (and destroyed) when
3433 // it goes out of scope at the end of this function.
3434 tensorflow::Safe_PyObjectPtr py_eager_tensor = nullptr;
3435
3436 if (!ConvertToTensor(
3437 *op_exec_info, input, &py_eager_tensor,
3438 [&]() {
3439 if (input_arg.type() != tensorflow::DataType::DT_INVALID) {
3440 return input_arg.type();
3441 }
3442 return MaybeGetDTypeForAttr(input_arg.type_attr(), op_exec_info);
3443 },
3444 [&](const tensorflow::DataType dtype) {
3445 op_exec_info->cached_dtypes[input_arg.type_attr()] = dtype;
3446 },
3447 status)) {
3448 return false;
3449 }
3450
3451 TFE_TensorHandle* input_handle = EagerTensor_Handle(py_eager_tensor.get());
3452
3453 if (add_type_attr && !input_arg.type_attr().empty()) {
3454 auto dtype = TFE_TensorHandleDataType(input_handle);
3455 TFE_OpSetAttrType(op, input_arg.type_attr().data(), dtype);
3456 if (flattened_attrs != nullptr) {
3457 flattened_attrs->emplace_back(
3458 GetPythonObjectFromString(input_arg.type_attr()));
3459 flattened_attrs->emplace_back(PyLong_FromLong(dtype));
3460 }
3461 }
3462
3463 if (flattened_inputs != nullptr) {
3464 flattened_inputs->emplace_back(std::move(py_eager_tensor));
3465 }
3466
3467 TFE_OpAddInput(op, input_handle, status);
3468 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3469 return false;
3470 }
3471
3472 return true;
3473 }
3474
3475 const char* GetDeviceName(PyObject* py_device_name) {
3476 if (py_device_name != Py_None) {
3477 return TFE_GetPythonString(py_device_name);
3478 }
3479 return nullptr;
3480 }
3481
3482 bool RaiseIfNotPySequence(PyObject* seq, const string& attr_name) {
3483 if (!PySequence_Check(seq)) {
3484 PyErr_SetString(PyExc_TypeError,
3485 Printf("expected a sequence for attr %s, got %s instead",
3486 attr_name.data(), seq->ob_type->tp_name)
3487 .data());
3488
3489 return false;
3490 }
3491 if (PyArray_Check(seq) &&
3492 PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)) != 1) {
3493 PyErr_SetString(PyExc_ValueError,
3494 Printf("expected a sequence for attr %s, got an ndarray "
3495 "with rank %d instead",
3496 attr_name.data(),
3497 PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)))
3498 .data());
3499 return false;
3500 }
3501 return true;
3502 }
3503
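// Runs the registered callbacks for one fast-path execution: the gradient
// callback goes through RecordGradient, and each post-execution callback is
// invoked with the positional arguments
//   (op_name, inputs, attrs, flattened_result, name)
// as assembled by the Py_BuildValue call below. (The Python-side callback
// signature shown here is inferred from that call, not defined in this file.)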
3504 bool RunCallbacks(
3505 const FastPathOpExecInfo& op_exec_info, PyObject* args,
3506 int num_inferred_attrs,
3507 const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_inputs,
3508 const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_attrs,
3509 PyObject* flattened_result) {
3510 DCHECK(op_exec_info.run_callbacks);
3511
3512 tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(flattened_inputs.size()));
3513 for (int i = 0; i < flattened_inputs.size(); i++) {
3514 PyObject* input = flattened_inputs[i].get();
3515 Py_INCREF(input);
3516 PyTuple_SET_ITEM(inputs.get(), i, input);
3517 }
3518
3519 int num_non_inferred_attrs = PyTuple_GET_SIZE(args) - num_inferred_attrs;
3520 int num_attrs = flattened_attrs.size() + num_non_inferred_attrs;
3521 tensorflow::Safe_PyObjectPtr attrs(PyTuple_New(num_attrs));
3522
3523 for (int i = 0; i < num_non_inferred_attrs; i++) {
3524 auto* attr = PyTuple_GET_ITEM(args, num_inferred_attrs + i);
3525 Py_INCREF(attr);
3526 PyTuple_SET_ITEM(attrs.get(), i, attr);
3527 }
3528
3529 for (int i = num_non_inferred_attrs; i < num_attrs; i++) {
3530 PyObject* attr_or_name =
3531 flattened_attrs.at(i - num_non_inferred_attrs).get();
3532 Py_INCREF(attr_or_name);
3533 PyTuple_SET_ITEM(attrs.get(), i, attr_or_name);
3534 }
3535
3536 if (op_exec_info.run_gradient_callback) {
3537 if (!RecordGradient(op_exec_info.op_name, inputs.get(), attrs.get(),
3538 flattened_result)) {
3539 return false;
3540 }
3541 }
3542
3543 if (op_exec_info.run_post_exec_callbacks) {
3544 tensorflow::Safe_PyObjectPtr callback_args(
3545 Py_BuildValue("OOOOO", op_exec_info.op_name, inputs.get(), attrs.get(),
3546 flattened_result, op_exec_info.name));
3547 for (Py_ssize_t i = 0; i < PyList_Size(op_exec_info.callbacks); i++) {
3548 PyObject* callback_fn = PyList_GET_ITEM(op_exec_info.callbacks, i);
3549 if (!PyCallable_Check(callback_fn)) {
3550 PyErr_SetString(
3551 PyExc_TypeError,
3552 Printf("expected a function for "
3553 "post execution callback in index %ld, got %s instead",
3554 i, callback_fn->ob_type->tp_name)
3555 .c_str());
3556 return false;
3557 }
3558 PyObject* callback_result =
3559 PyObject_CallObject(callback_fn, callback_args.get());
3560 if (!callback_result) {
3561 return false;
3562 }
3563 Py_DECREF(callback_result);
3564 }
3565 }
3566
3567 return true;
3568 }
3569
3570 } // namespace
3571
3572 PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
3573 tensorflow::profiler::TraceMe activity(
3574 "TFE_Py_FastPathExecute_C", tensorflow::profiler::TraceMeLevel::kInfo);
3575 Py_ssize_t args_size = PyTuple_GET_SIZE(args);
3576 if (args_size < FAST_PATH_EXECUTE_ARG_INPUT_START) {
3577 PyErr_SetString(
3578 PyExc_ValueError,
3579 Printf("There must be at least %d items in the input tuple.",
3580 FAST_PATH_EXECUTE_ARG_INPUT_START)
3581 .c_str());
3582 return nullptr;
3583 }
3584
3585 FastPathOpExecInfo op_exec_info;
3586
3587 PyObject* py_eager_context =
3588 PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_CONTEXT);
3589
3590 // TODO(edoper): Use interned string here
3591 PyObject* eager_context_handle =
3592 PyObject_GetAttrString(py_eager_context, "_context_handle");
3593
3594 TFE_Context* ctx = reinterpret_cast<TFE_Context*>(
3595 PyCapsule_GetPointer(eager_context_handle, nullptr));
3596 op_exec_info.ctx = ctx;
3597 op_exec_info.args = args;
3598
3599 if (ctx == nullptr) {
3600 // The context hasn't been initialized yet; it will be initialized on the slow path.
3601 RaiseFallbackException(
3602 "This function does not handle the case where "
3603 "not all inputs are already EagerTensors.");
3604 return nullptr;
3605 }
3606
3607 auto* tld = tensorflow::GetEagerContextThreadLocalData(py_eager_context);
3608 if (tld == nullptr) {
3609 return nullptr;
3610 }
3611 op_exec_info.device_name = GetDeviceName(tld->device_name.get());
3612 op_exec_info.callbacks = tld->op_callbacks.get();
3613
3614 op_exec_info.op_name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_OP_NAME);
3615 op_exec_info.name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_NAME);
3616
3617 // TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks
3618 // (similar to benchmark_tf_gradient_function_*). Also consider using an
3619 // InlinedVector for flattened_attrs and flattened_inputs if the benchmarks
3620 // point out problems with heap allocs.
3621 op_exec_info.run_gradient_callback =
3622 !*ThreadTapeIsStopped() && HasAccumulatorOrTape();
3623 op_exec_info.run_post_exec_callbacks =
3624 op_exec_info.callbacks != Py_None &&
3625 PyList_Size(op_exec_info.callbacks) > 0;
3626 op_exec_info.run_callbacks = op_exec_info.run_gradient_callback ||
3627 op_exec_info.run_post_exec_callbacks;
3628
3629 TF_Status* status = GetStatus();
3630 const char* op_name = TFE_GetPythonString(op_exec_info.op_name);
3631 if (op_name == nullptr) {
3632 PyErr_SetString(PyExc_TypeError,
3633 Printf("expected a string for op_name, got %s instead",
3634 op_exec_info.op_name->ob_type->tp_name)
3635 .c_str());
3636 return nullptr;
3637 }
3638
3639 TFE_Op* op = GetOp(ctx, op_name, op_exec_info.device_name, status);
3640
3641 auto cleaner = tensorflow::gtl::MakeCleanup([status, ctx, op] {
3642 ReturnStatus(status);
3643 ReturnOp(ctx, op);
3644 });
3645
3646 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3647 return nullptr;
3648 }
3649
3650 tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace(
3651 tensorflow::StackTrace::kStackTraceInitialSize));
3652
3653 const tensorflow::OpDef* op_def = tensorflow::unwrap(op)->OpDef();
3654 if (op_def == nullptr) return nullptr;
3655
3656 if (args_size <
3657 FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size()) {
3658 PyErr_SetString(
3659 PyExc_ValueError,
3660 Printf("Tuple size smaller than intended. Expected to be at least %d, "
3661 "was %ld",
3662 FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
3663 args_size)
3664 .c_str());
3665 return nullptr;
3666 }
3667
3668 if (!CheckInputsOk(args, FAST_PATH_EXECUTE_ARG_INPUT_START, *op_def)) {
3669 RaiseFallbackException(
3670 "This function does not handle the case where "
3671 "not all inputs are already EagerTensors.");
3672 return nullptr;
3673 }
3674
3675 op_exec_info.attr_to_inputs_map = GetAttrToInputsMapHoldingGIL(*op_def);
3676 op_exec_info.default_dtypes = GetAttrToDefaultsMapHoldingGIL(*op_def);
3677
3678 // Mapping of attr name to size - used to calculate the number of values
3679 // to be expected by the TFE_Execute run.
3680 tensorflow::gtl::FlatMap<string, tensorflow::int64> attr_list_sizes;
3681
3682 // Set non-inferred attrs, including setting defaults if the attr is passed in
3683 // as None.
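// For example (illustrative values): if the op's inputs occupy slots 3 and 4,
// a trailing args segment of ("transpose_a", False, "transpose_b", False) is
// consumed two entries at a time as (attr name, attr value) pairs below.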
3684 for (int i = FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size();
3685 i < args_size; i += 2) {
3686 PyObject* py_attr_name = PyTuple_GET_ITEM(args, i);
3687 const char* attr_name = TFE_GetPythonString(py_attr_name);
3688 PyObject* py_attr_value = PyTuple_GET_ITEM(args, i + 1);
3689
3690 // Not creating an index since most of the time there are not more than a
3691 // few attrs.
3692 // TODO(nareshmodi): Maybe include the index as part of the
3693 // OpRegistrationData.
3694 for (const auto& attr : op_def->attr()) {
3695 if (tensorflow::StringPiece(attr_name) == attr.name()) {
3696 SetOpAttrWithDefaults(ctx, op, attr, attr_name, py_attr_value,
3697 &attr_list_sizes, status);
3698
3699 if (!status->status.ok()) {
3700 VLOG(1) << "Falling back to slow path for Op \"" << op_def->name()
3701 << "\" since we are unable to set the value for attr \""
3702 << attr.name() << "\" due to: " << TF_Message(status);
3703 RaiseFallbackException(TF_Message(status));
3704 return nullptr;
3705 }
3706
3707 break;
3708 }
3709 }
3710 }
3711
3712 // Flat attrs and inputs as required by the record_gradient call. The attrs
3713 // here only contain inferred attrs (non-inferred attrs are added directly
3714 // from the input args).
3715 // All items in flattened_attrs and flattened_inputs contain
3716 // Safe_PyObjectPtr - any time something steals a reference to this, it must
3717 // INCREF.
3718 // TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work
3719 // directly.
3720 std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_attrs =
3721 nullptr;
3722 std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_inputs =
3723 nullptr;
3724
3725 // TODO(nareshmodi): Encapsulate callbacks information into a struct.
3726 if (op_exec_info.run_callbacks) {
3727 flattened_attrs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
3728 flattened_inputs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
3729 }
3730
3731 // Add inferred attrs and inputs.
3732 // The following code might set duplicate type attrs. This will result in
3733 // the CacheKey for the generated AttrBuilder possibly differing from
3734 // those where the type attrs are correctly set. Inconsistent CacheKeys
3735 // for ops means that there might be unnecessarily duplicated kernels.
3736 // TODO(nareshmodi): Fix this.
3737 for (int i = 0; i < op_def->input_arg_size(); i++) {
3738 const auto& input_arg = op_def->input_arg(i);
3739
3740 PyObject* input =
3741 PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_INPUT_START + i);
3742 if (!input_arg.number_attr().empty()) {
3743 // The item is a homogeneous list.
3744 if (!RaiseIfNotPySequence(input, input_arg.number_attr())) return nullptr;
3745 tensorflow::Safe_PyObjectPtr fast_input(
3746 PySequence_Fast(input, "Could not parse sequence."));
3747 if (fast_input.get() == nullptr) {
3748 return nullptr;
3749 }
3750 Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
3751 PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get());
3752
3753 TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len);
3754 if (op_exec_info.run_callbacks) {
3755 flattened_attrs->emplace_back(
3756 GetPythonObjectFromString(input_arg.number_attr()));
3757 flattened_attrs->emplace_back(PyLong_FromLong(len));
3758 }
3759 attr_list_sizes[input_arg.number_attr()] = len;
3760
3761 if (len > 0) {
3762 // First item adds the type attr.
3763 if (!AddInputToOp(&op_exec_info, fast_input_array[0], true, input_arg,
3764 flattened_attrs.get(), flattened_inputs.get(), op,
3765 status)) {
3766 return nullptr;
3767 }
3768
3769 for (Py_ssize_t j = 1; j < len; j++) {
3770 // Since the list is homogeneous, we don't need to re-add the attr.
3771 if (!AddInputToOp(&op_exec_info, fast_input_array[j], false,
3772 input_arg, nullptr /* flattened_attrs */,
3773 flattened_inputs.get(), op, status)) {
3774 return nullptr;
3775 }
3776 }
3777 }
3778 } else if (!input_arg.type_list_attr().empty()) {
3779 // The item is a heterogeneous list.
3780 if (!RaiseIfNotPySequence(input, input_arg.type_list_attr())) {
3781 return nullptr;
3782 }
3783 tensorflow::Safe_PyObjectPtr fast_input(
3784 PySequence_Fast(input, "Could not parse sequence."));
3785 if (fast_input.get() == nullptr) {
3786 return nullptr;
3787 }
3788 const string& attr_name = input_arg.type_list_attr();
3789 Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
3790 PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get());
3791 tensorflow::gtl::InlinedVector<TF_DataType, 4> attr_value(len);
3792 PyObject* py_attr_value = nullptr;
3793 if (op_exec_info.run_callbacks) {
3794 py_attr_value = PyTuple_New(len);
3795 }
3796 for (Py_ssize_t j = 0; j < len; j++) {
3797 PyObject* py_input = fast_input_array[j];
3798 tensorflow::Safe_PyObjectPtr py_eager_tensor;
3799 if (!ConvertToTensor(
3800 op_exec_info, py_input, &py_eager_tensor,
3801 []() { return tensorflow::DT_INVALID; },
3802 [](const tensorflow::DataType dtype) {}, status)) {
3803 return nullptr;
3804 }
3805
3806 TFE_TensorHandle* input_handle =
3807 EagerTensor_Handle(py_eager_tensor.get());
3808
3809 attr_value[j] = TFE_TensorHandleDataType(input_handle);
3810
3811 TFE_OpAddInput(op, input_handle, status);
3812 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3813 return nullptr;
3814 }
3815
3816 if (op_exec_info.run_callbacks) {
3817 flattened_inputs->emplace_back(std::move(py_eager_tensor));
3818
3819 PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j]));
3820 }
3821 }
3822 if (op_exec_info.run_callbacks) {
3823 flattened_attrs->emplace_back(GetPythonObjectFromString(attr_name));
3824 flattened_attrs->emplace_back(py_attr_value);
3825 }
3826 TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(),
3827 attr_value.size());
3828 attr_list_sizes[attr_name] = len;
3829 } else {
3830 // The item is a single item.
3831 if (!AddInputToOp(&op_exec_info, input, true, input_arg,
3832 flattened_attrs.get(), flattened_inputs.get(), op,
3833 status)) {
3834 return nullptr;
3835 }
3836 }
3837 }
3838
3839 int64_t num_outputs = 0;
3840 for (int i = 0; i < op_def->output_arg_size(); i++) {
3841 const auto& output_arg = op_def->output_arg(i);
3842 int64_t delta = 1;
3843 if (!output_arg.number_attr().empty()) {
3844 delta = attr_list_sizes[output_arg.number_attr()];
3845 } else if (!output_arg.type_list_attr().empty()) {
3846 delta = attr_list_sizes[output_arg.type_list_attr()];
3847 }
3848 if (delta < 0) {
3849 RaiseFallbackException(
3850 "Attributes suggest that the size of an output list is less than 0");
3851 return nullptr;
3852 }
3853 num_outputs += delta;
3854 }
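// Example (hypothetical op def): with output args values (number_attr "N")
// and indices (type_list_attr "Tidx"), and attr_list_sizes recording N == 3
// and len(Tidx) == 2 from the inputs above, num_outputs becomes 3 + 2 == 5.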
3855
3856 // If number of retvals is larger than int32, we error out.
3857 if (static_cast<int64_t>(static_cast<int32_t>(num_outputs)) != num_outputs) {
3858 PyErr_SetString(
3859 PyExc_ValueError,
3860 Printf("Number of outputs is too big: %ld", num_outputs).c_str());
3861 return nullptr;
3862 }
3863 int num_retvals = num_outputs;
3864
3865 tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);
3866
3867 Py_BEGIN_ALLOW_THREADS;
3868 TFE_Execute(op, retvals.data(), &num_retvals, status);
3869 Py_END_ALLOW_THREADS;
3870
3871 if (!status->status.ok()) {
3872 // Augment the status with the op_name for easier debugging similar to
3873 // TFE_Py_Execute.
3874 std::vector<tensorflow::StackFrame> stack_trace =
3875 status->status.stack_trace();
3876 status->status = tensorflow::Status(
3877 status->status.code(),
3878 tensorflow::strings::StrCat(
3879 TF_Message(status),
3880 " [Op:", TFE_GetPythonString(op_exec_info.op_name), "]"),
3881 std::move(stack_trace));
3882
3883 MaybeRaiseExceptionFromTFStatus(status, nullptr);
3884 return nullptr;
3885 }
3886
3887 tensorflow::Safe_PyObjectPtr flat_result(PyList_New(num_retvals));
3888 for (int i = 0; i < num_retvals; ++i) {
3889 PyList_SET_ITEM(flat_result.get(), i, EagerTensorFromHandle(retvals[i]));
3890 }
3891
3892 if (op_exec_info.run_callbacks) {
3893 if (!RunCallbacks(
3894 op_exec_info, args,
3895 FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
3896 *flattened_inputs, *flattened_attrs, flat_result.get())) {
3897 return nullptr;
3898 }
3899 }
3900
3901 // Unflatten results.
3902 if (op_def->output_arg_size() == 0) {
3903 Py_RETURN_NONE;
3904 }
3905
3906 if (op_def->output_arg_size() == 1) {
3907 if (!op_def->output_arg(0).number_attr().empty() ||
3908 !op_def->output_arg(0).type_list_attr().empty()) {
3909 return flat_result.release();
3910 } else {
3911 auto* result = PyList_GET_ITEM(flat_result.get(), 0);
3912 Py_INCREF(result);
3913 return result;
3914 }
3915 }
3916
3917 // Build one list entry per output arg; these results are later made into a namedtuple by the Python caller.
3918 PyObject* result = PyList_New(op_def->output_arg_size());
3919 int flat_result_index = 0;
3920 for (int i = 0; i < op_def->output_arg_size(); i++) {
3921 if (!op_def->output_arg(i).number_attr().empty()) {
3922 int list_length = attr_list_sizes[op_def->output_arg(i).number_attr()];
3923 PyObject* inner_list = PyList_New(list_length);
3924 for (int j = 0; j < list_length; j++) {
3925 PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
3926 Py_INCREF(obj);
3927 PyList_SET_ITEM(inner_list, j, obj);
3928 }
3929 PyList_SET_ITEM(result, i, inner_list);
3930 } else if (!op_def->output_arg(i).type_list_attr().empty()) {
3931 int list_length = attr_list_sizes[op_def->output_arg(i).type_list_attr()];
3932 PyObject* inner_list = PyList_New(list_length);
3933 for (int j = 0; j < list_length; j++) {
3934 PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
3935 Py_INCREF(obj);
3936 PyList_SET_ITEM(inner_list, j, obj);
3937 }
3938 PyList_SET_ITEM(result, i, inner_list);
3939 } else {
3940 PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
3941 Py_INCREF(obj);
3942 PyList_SET_ITEM(result, i, obj);
3943 }
3944 }
3945 return result;
3946 }
3947
3948 PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
3949 PyObject* attrs, PyObject* results,
3950 PyObject* forward_pass_name_scope) {
3951 if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
3952 Py_RETURN_NONE;
3953 }
3954
3955 return RecordGradient(op_name, inputs, attrs, results,
3956 forward_pass_name_scope);
3957 }
3958
3959 namespace {
3960 const char kTensor[] = "T";
3961 const char kList[] = "L";
3962 const char kListEnd[] = "l";
3963 const char kTuple[] = "U";
3964 const char kTupleEnd[] = "u";
3965 const char kDIter[] = "I";
3966 const char kDict[] = "D";
3967 const char kRaw[] = "R";
3968 const char kShape[] = "s";
3969 const char kShapeDelim[] = "-";
3970 const char kDType[] = "d";
3971 const char kNone[] = "n";
3972 const char kCompositeTensor[] = "C";
3973 const char kAttrs[] = "A";
3974 const char kAttrsEnd[] = "a";
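// Illustrative encodings built from these markers (assuming DT_FLOAT
// enumerates to 1; see TFE_Py_EncodeTensor / TFE_Py_EncodeSequence below):
//   - a float32 EagerTensor of shape [2, 3]          -> "Td1s2-3-"
//   - a list of two such tensors                     -> "LTd1s2-3-Td1s2-3-l"
//   - the same list with include_tensor_ranks_only   -> "LTd1s2Td1s2l"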
3975
3976 struct EncodeResult {
3977 string str;
3978 std::vector<PyObject*> objects;
3979
3980 PyObject* ToPyTuple() {
3981 PyObject* result = PyTuple_New(2);
3982
3983 PyTuple_SET_ITEM(result, 0, GetPythonObjectFromString(str));
3984
3985 if (objects.empty()) {
3986 Py_INCREF(Py_None);
3987 PyTuple_SET_ITEM(result, 1, Py_None);
3988 } else {
3989 PyObject* objects_tuple = PyTuple_New(objects.size());
3990
3991 for (int i = 0; i < objects.size(); i++) {
3992 PyTuple_SET_ITEM(objects_tuple, i, objects[i]);
3993 }
3994
3995 PyTuple_SET_ITEM(result, 1, objects_tuple);
3996 }
3997
3998 return result;
3999 }
4000 };
4001
4002 tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg,
4003 bool include_tensor_ranks_only,
4004 EncodeResult* result) {
4005 if (EagerTensor_CheckExact(arg)) {
4006 tensorflow::ImmediateExecutionTensorHandle* handle =
4007 tensorflow::unwrap(EagerTensor_Handle(arg));
4008
4009 absl::StrAppend(&result->str, kDType,
4010 static_cast<tensorflow::DataType>(handle->DataType()));
4011 absl::StrAppend(&result->str, kShape);
4012
4013 int num_dims;
4014 tensorflow::Status status = handle->NumDims(&num_dims);
4015 if (!status.ok()) return status;
4016
4017 if (include_tensor_ranks_only) {
4018 absl::StrAppend(&result->str, num_dims);
4019 } else {
4020 for (int i = 0; i < num_dims; ++i) {
4021 int64_t dim_size;
4022 status = handle->Dim(i, &dim_size);
4023 if (!status.ok()) return status;
4024 absl::StrAppend(&result->str, dim_size, kShapeDelim);
4025 }
4026 }
4027 return tensorflow::Status::OK();
4028 }
4029
4030 tensorflow::Safe_PyObjectPtr dtype_object(
4031 PyObject_GetAttrString(arg, "dtype"));
4032
4033 if (dtype_object == nullptr) {
4034 return tensorflow::errors::InvalidArgument(
4035 "ops.Tensor object doesn't have dtype() attr.");
4036 }
4037
4038 tensorflow::Safe_PyObjectPtr dtype_enum(
4039 PyObject_GetAttrString(dtype_object.get(), "_type_enum"));
4040
4041 if (dtype_enum == nullptr) {
4042 return tensorflow::errors::InvalidArgument(
4043 "ops.Tensor's dtype object doesn't have _type_enum() attr.");
4044 }
4045
4046 tensorflow::DataType dtype =
4047 static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get()));
4048
4049 absl::StrAppend(&result->str, kDType, dtype);
4050
4051 static char _shape_tuple[] = "_shape_tuple";
4052 tensorflow::Safe_PyObjectPtr shape_tuple(
4053 PyObject_CallMethod(arg, _shape_tuple, nullptr));
4054
4055 if (shape_tuple == nullptr) {
4056 return tensorflow::errors::InvalidArgument(
4057 "ops.Tensor object doesn't have _shape_tuple() method.");
4058 }
4059
4060 if (shape_tuple.get() == Py_None) {
4061 // Unknown shape, encode that directly.
4062 absl::StrAppend(&result->str, kNone);
4063 return tensorflow::Status::OK();
4064 }
4065
4066 absl::StrAppend(&result->str, kShape);
4067 tensorflow::Safe_PyObjectPtr shape_seq(PySequence_Fast(
4068 shape_tuple.get(), "shape_tuple didn't return a sequence"));
4069
4070 int len = PySequence_Fast_GET_SIZE(shape_seq.get());
4071 PyObject** shape_seq_array = PySequence_Fast_ITEMS(shape_seq.get());
4072
4073 if (include_tensor_ranks_only) {
4074 absl::StrAppend(&result->str, len);
4075 } else {
4076 for (int i = 0; i < len; ++i) {
4077 PyObject* item = shape_seq_array[i];
4078 if (item == Py_None) {
4079 absl::StrAppend(&result->str, kNone);
4080 } else {
4081 absl::StrAppend(&result->str, MakeInt(item));
4082 }
4083 }
4084 }
4085 return tensorflow::Status::OK();
4086 }
4087
4088 tensorflow::Status TFE_Py_EncodeArgHelperInternal(
4089 PyObject* arg, bool include_tensor_ranks_only, std::vector<int>& res_vec,
4090 absl::flat_hash_map<int, int>& res_map, int& cur_res, EncodeResult* result);
4091
4092 // Encodes a list/tuple: appends the opening `type` marker, then each
4093 // element's encoding (kNone for None elements), then the closing `end_type` marker.
4093 tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type,
4094 const char* end_type,
4095 bool include_tensor_ranks_only,
4096 std::vector<int>& res_vec,
4097 absl::flat_hash_map<int, int>& res_map,
4098 int& cur_res, EncodeResult* result) {
4099 tensorflow::Safe_PyObjectPtr arg_seq(
4100 PySequence_Fast(arg, "unable to create seq from list/tuple"));
4101
4102 absl::StrAppend(&result->str, type);
4103 int len = PySequence_Fast_GET_SIZE(arg_seq.get());
4104 PyObject** arg_seq_array = PySequence_Fast_ITEMS(arg_seq.get());
4105 for (int i = 0; i < len; ++i) {
4106 PyObject* item = arg_seq_array[i];
4107 if (item == Py_None) {
4108 absl::StrAppend(&result->str, kNone);
4109 } else {
4110 TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelperInternal(
4111 item, include_tensor_ranks_only, res_vec, res_map, cur_res, result));
4112 }
4113 }
4114 absl::StrAppend(&result->str, end_type);
4115
4116 return tensorflow::Status::OK();
4117 }
4118
4119 void UpdateResourceCount(int res_id, std::vector<int>& res_vec,
4120 absl::flat_hash_map<int, int>& res_map, int& cur_res) {
4121 const auto& it = res_map.find(res_id);
4122 if (it == res_map.end()) {
4123 res_map[res_id] = cur_res;
4124 res_vec.push_back(cur_res);
4125 ++cur_res;
4126 } else {
4127 res_vec.push_back(it->second);
4128 }
4129 }
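// Example (illustrative): if iterator resource ids are encountered in the
// order 42, 7, 42, then res_vec becomes {0, 1, 0}: each distinct id is
// renumbered by first appearance, so the encoding depends only on the
// aliasing pattern and not on the raw resource ids.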
4130
4131 tensorflow::Status TFE_Py_EncodeArgHelperInternal(
4132 PyObject* arg, bool include_tensor_ranks_only, std::vector<int>& res_vec,
4133 absl::flat_hash_map<int, int>& res_map, int& cur_res,
4134 EncodeResult* result) {
4135 if (tensorflow::swig::IsTensor(arg)) {
4136 absl::StrAppend(&result->str, kTensor);
4137 TF_RETURN_IF_ERROR(
4138 TFE_Py_EncodeTensor(arg, include_tensor_ranks_only, result));
4139 } else if (tensorflow::swig::IsOwnedIterator(arg)) {
4140 // TODO(jiaweix): distinguish other resource types
4141 // Similar to IsCompositeTensor below, plus resource id
4142 PyObject* type_spec(PyObject_GetAttrString(arg, "_type_spec"));
4143 if (type_spec == nullptr) {
4144 return tensorflow::errors::InvalidArgument(
4145 "Error while reading OwnedIterator._type_spec.");
4146 }
4147 result->objects.push_back(type_spec);
4148
4149 // Add resource tracking
4150 tensorflow::Safe_PyObjectPtr itr_res(
4151 PyObject_GetAttrString(arg, "_iterator_resource"));
4152 if (itr_res == nullptr) {
4153 return tensorflow::errors::InvalidArgument(
4154 "Error while reading Dataset iterator resource.");
4155 }
4156 // OwnedIterator does not always have a unique resource id,
4157 // because a Dataset object is not required for OwnedIterator.__init__.
4158 // As a result we check whether '_iterator_resource' is a Tensor.
4159 if (tensorflow::swig::IsTensor(itr_res.get())) {
4160 absl::StrAppend(&result->str, kDIter);
4161 tensorflow::Safe_PyObjectPtr p_res_id(
4162 PyObject_GetAttrString(itr_res.get(), "_id"));
4163 if (p_res_id == nullptr) {
4164 return tensorflow::errors::InvalidArgument(
4165 "Error while reading Dataset iterator resource id.");
4166 }
4167 int res_id = PyLong_AsSize_t(p_res_id.get());
4168 if (res_id < 0) {
4169 return tensorflow::errors::InvalidArgument("PyLong_AsSize_t failure");
4170 }
4171 UpdateResourceCount(res_id, res_vec, res_map, cur_res);
4172 } else {
4173 // If '_iterator_resource' is not a Tensor, there is no resource id.
4174 // Instead we treat it the same way as a CompositeTensor
4175 absl::StrAppend(&result->str, kCompositeTensor);
4176 }
4177 } else if (PyList_Check(arg)) {
4178 TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(arg, kList, kListEnd,
4179 include_tensor_ranks_only, res_vec,
4180 res_map, cur_res, result));
4181 } else if (tensorflow::swig::IsTuple(arg)) {
4182 TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(arg, kTuple, kTupleEnd,
4183 include_tensor_ranks_only, res_vec,
4184 res_map, cur_res, result));
4185 } else if (tensorflow::swig::IsMapping(arg)) {
4186 tensorflow::Safe_PyObjectPtr keys(tensorflow::swig::MappingKeys(arg));
4187 if (PyList_Sort(keys.get()) == -1) {
4188 return tensorflow::errors::Internal("Unable to sort keys");
4189 }
4190
4191 absl::StrAppend(&result->str, kDict);
4192 int len = PyList_Size(keys.get());
4193
4194 for (int i = 0; i < len; i++) {
4195 PyObject* key = PyList_GetItem(keys.get(), i);
4196 TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelperInternal(
4197 key, include_tensor_ranks_only, res_vec, res_map, cur_res, result));
4198 tensorflow::Safe_PyObjectPtr value(PyObject_GetItem(arg, key));
4199 TF_RETURN_IF_ERROR(
4200 TFE_Py_EncodeArgHelperInternal(value.get(), include_tensor_ranks_only,
4201 res_vec, res_map, cur_res, result));
4202 }
4203 } else if (tensorflow::swig::IsCompositeTensor(arg)) {
4204 absl::StrAppend(&result->str, kCompositeTensor);
4205
4206 // Add the typespec to the list of objects. (Do *not* use a weakref,
4207 // since the type spec is often a temporary object.)
4208 PyObject* type_spec(PyObject_GetAttrString(arg, "_type_spec"));
4209 if (type_spec == nullptr) {
4210 return tensorflow::errors::InvalidArgument(
4211 "Error while reading CompositeTensor._type_spec.");
4212 }
4213 result->objects.push_back(type_spec);
4214 } else if (tensorflow::swig::IsTypeSpec(arg)) {
4215 // Add the typespec (not a weakref) in case it's a temporary object.
4216 absl::StrAppend(&result->str, kRaw);
4217 Py_INCREF(arg);
4218 result->objects.push_back(arg);
4219 } else if (tensorflow::swig::IsAttrs(arg)) {
4220 absl::StrAppend(&result->str, kAttrs);
4221 tensorflow::Safe_PyObjectPtr attrs(
4222 PyObject_GetAttrString(arg, "__attrs_attrs__"));
4223 tensorflow::Safe_PyObjectPtr iter(PyObject_GetIter(attrs.get()));
4224 for (tensorflow::Safe_PyObjectPtr item(PyIter_Next(iter.get())); item;
4225 item.reset(PyIter_Next(iter.get()))) {
4226 tensorflow::Safe_PyObjectPtr name(
4227 PyObject_GetAttrString(item.get(), "name"));
4228 tensorflow::Safe_PyObjectPtr attr_arg(PyObject_GetAttr(arg, name.get()));
4229 TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelperInternal(
4230 attr_arg.get(), include_tensor_ranks_only, res_vec, res_map, cur_res,
4231 result));
4232 }
4233 absl::StrAppend(&result->str, kAttrsEnd);
4234 } else {
4235 PyObject* object = PyWeakref_NewRef(arg, nullptr);
4236
4237 if (object == nullptr) {
4238 PyErr_Clear();
4239
4240 object = arg;
4241 Py_INCREF(object);
4242 }
4243
4244 absl::StrAppend(&result->str, kRaw);
4245 result->objects.push_back(object);
4246 }
4247
4248 return tensorflow::Status::OK();
4249 }
4250
4251 tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
4252 bool include_tensor_ranks_only,
4253 EncodeResult* result) {
4254 std::vector<int> res_vec;
4255 absl::flat_hash_map<int, int> res_map;
4256 int cur_res = 0;
4257 auto status = TFE_Py_EncodeArgHelperInternal(
4258 arg, include_tensor_ranks_only, res_vec, res_map, cur_res, result);
4259
4260 // Add 'encoding' of resources
4261 std::string str_resource_encoding = "";
4262 for (auto&& i : res_vec) {
4263 str_resource_encoding.append(std::to_string(i));
4264 str_resource_encoding.append("_");
4265 }
4266 if (!str_resource_encoding.empty()) {
4267 result->objects.push_back(
4268 PyUnicode_FromString(str_resource_encoding.c_str()));
4269 }
4270
4271 return status;
4272 }
4273
4274 } // namespace
4275
4276 // `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes
4277 // are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes
4278 // are used for both performance reasons, as much TensorFlow code specializes
4279 // on known shapes to produce slimmer graphs, and correctness, as some
4280 // high-level APIs require shapes to be fully-known.
4281 //
4282 // `include_tensor_ranks_only` allows caching on arguments excluding shape info,
4283 // so that a slow path using relaxed shape can rely on a cache key that excludes
4284 // shapes.
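// For instance (illustrative shapes): with include_tensor_ranks_only == false,
// float32 tensors of shapes [2, 3] and [5, 7] produce different keys, while
// with include_tensor_ranks_only == true both encode only their rank (2) and
// therefore share a key.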
4285 PyObject* TFE_Py_EncodeArg(PyObject* arg, bool include_tensor_ranks_only) {
4286 EncodeResult result;
4287 const auto status =
4288 TFE_Py_EncodeArgHelper(arg, include_tensor_ranks_only, &result);
4289 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
4290 return nullptr;
4291 }
4292
4293 return result.ToPyTuple();
4294 }
4295
4296 // Prints incoming log messages directly to Python's stdout using Python's
4297 // C API. This is necessary in Jupyter notebooks and colabs, where messages
4298 // written to the C stdout don't reach the notebook cell outputs, but calls
4299 // to Python's stdout do.
4300 void PrintToPythonStdout(const char* msg) {
4301 if (Py_IsInitialized()) {
4302 PyGILState_STATE py_threadstate;
4303 py_threadstate = PyGILState_Ensure();
4304
4305 string string_msg = msg;
4306 // PySys_WriteStdout truncates strings over 1000 bytes, so
4307 // we write the message in chunks small enough to not be truncated.
4308 int CHUNK_SIZE = 900;
4309 auto len = string_msg.length();
4310 for (int i = 0; i < len; i += CHUNK_SIZE) {
4311 PySys_WriteStdout("%s", string_msg.substr(i, CHUNK_SIZE).c_str());
4312 }
4313
4314 // Force flushing to make sure print newlines aren't interleaved in
4315 // some colab environments
4316 PyRun_SimpleString("import sys; sys.stdout.flush()");
4317
4318 PyGILState_Release(py_threadstate);
4319 }
4320 }
4321
4322 // Registers PrintToPythonStdout as a log listener so that log printing
4323 // works in colabs and Jupyter notebooks.
4324 void TFE_Py_EnableInteractivePythonLogging() {
4325 static bool enabled_interactive_logging = false;
4326 if (!enabled_interactive_logging) {
4327 enabled_interactive_logging = true;
4328 TF_RegisterLogListener(PrintToPythonStdout);
4329 }
4330 }
4331
4332 namespace {
4333 // Weak reference to the currently active Python Context object.
4334 PyObject* weak_eager_context = nullptr;
4335 } // namespace
4336
4337 PyObject* TFE_Py_SetEagerContext(PyObject* py_context) {
4338 Py_XDECREF(weak_eager_context);
4339 weak_eager_context = PyWeakref_NewRef(py_context, nullptr);
4340 if (weak_eager_context == nullptr) {
4341 return nullptr;
4342 }
4343 Py_RETURN_NONE;
4344 }
4345
4346 PyObject* GetPyEagerContext() {
4347 if (weak_eager_context == nullptr) {
4348 PyErr_SetString(PyExc_RuntimeError, "Python eager context is not set");
4349 return nullptr;
4350 }
4351 PyObject* py_context = PyWeakref_GET_OBJECT(weak_eager_context);
4352 if (py_context == Py_None) {
4353 PyErr_SetString(PyExc_RuntimeError, "Eager context has been destroyed");
4354 return nullptr;
4355 }
4356 Py_INCREF(py_context);
4357 return py_context;
4358 }
4359
4360 namespace {
4361
4362 // Default values for thread_local_data fields.
4363 struct EagerContextThreadLocalDataDefaults {
4364 tensorflow::Safe_PyObjectPtr is_eager;
4365 tensorflow::Safe_PyObjectPtr device_spec;
4366 };
4367
4368 // Maps each py_eager_context object to its thread_local_data.
4369 //
4370 // Note: we need to use the python Context object as the key here (and not
4371 // its handle object), because the handle object isn't created until the
4372 // context is initialized; but thread_local_data is potentially accessed
4373 // before then.
4374 using EagerContextThreadLocalDataMap = absl::flat_hash_map<
4375 PyObject*, std::unique_ptr<tensorflow::EagerContextThreadLocalData>>;
4376 thread_local EagerContextThreadLocalDataMap*
4377 eager_context_thread_local_data_map = nullptr;
4378
4379 // Maps each py_eager_context object to default values.
4380 using EagerContextThreadLocalDataDefaultsMap =
4381 absl::flat_hash_map<PyObject*, EagerContextThreadLocalDataDefaults>;
4382 EagerContextThreadLocalDataDefaultsMap*
4383 eager_context_thread_local_data_defaults = nullptr;
4384
4385 } // namespace
4386
4387 namespace tensorflow {
4388
4389 void MakeEagerContextThreadLocalData(PyObject* py_eager_context,
4390 PyObject* is_eager,
4391 PyObject* device_spec) {
4392 DCheckPyGilState();
4393 if (eager_context_thread_local_data_defaults == nullptr) {
4394 absl::LeakCheckDisabler disabler;
4395 eager_context_thread_local_data_defaults =
4396 new EagerContextThreadLocalDataDefaultsMap();
4397 }
4398 if (eager_context_thread_local_data_defaults->count(py_eager_context) > 0) {
4399 PyErr_SetString(PyExc_AssertionError,
4400 "MakeEagerContextThreadLocalData may not be called "
4401 "twice on the same eager Context object.");
4402 }
4403
4404 auto& defaults =
4405 (*eager_context_thread_local_data_defaults)[py_eager_context];
4406 Py_INCREF(is_eager);
4407 defaults.is_eager.reset(is_eager);
4408 Py_INCREF(device_spec);
4409 defaults.device_spec.reset(device_spec);
4410 }
4411
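// A minimal sketch of the expected lifecycle (the call sites live in the
// Python Context wrapper and are an assumption, not shown in this file):
//   MakeEagerContextThreadLocalData(py_ctx, is_eager_fn, device_spec);
//   ...
//   auto* tld = GetEagerContextThreadLocalData(py_ctx);  // per calling thread
//   ...
//   DestroyEagerContextThreadLocalData(py_ctx);  // on Context teardown
// MakeEagerContextThreadLocalData must run before the first Get for a given
// Context object, and may be called at most once for that object.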
4412 EagerContextThreadLocalData* GetEagerContextThreadLocalData(
4413 PyObject* py_eager_context) {
4414 if (eager_context_thread_local_data_defaults == nullptr) {
4415 PyErr_SetString(PyExc_AssertionError,
4416 "MakeEagerContextThreadLocalData must be called "
4417 "before GetEagerContextThreadLocalData.");
4418 return nullptr;
4419 }
4420 auto defaults =
4421 eager_context_thread_local_data_defaults->find(py_eager_context);
4422 if (defaults == eager_context_thread_local_data_defaults->end()) {
4423 PyErr_SetString(PyExc_AssertionError,
4424 "MakeEagerContextThreadLocalData must be called "
4425 "before GetEagerContextThreadLocalData.");
4426 return nullptr;
4427 }
4428
4429 if (eager_context_thread_local_data_map == nullptr) {
4430 absl::LeakCheckDisabler disabler;
4431 eager_context_thread_local_data_map = new EagerContextThreadLocalDataMap();
4432 }
4433 auto& thread_local_data =
4434 (*eager_context_thread_local_data_map)[py_eager_context];
4435
4436 if (!thread_local_data) {
4437 thread_local_data.reset(new EagerContextThreadLocalData());
4438
4439 Safe_PyObjectPtr is_eager(
4440 PyObject_CallFunctionObjArgs(defaults->second.is_eager.get(), nullptr));
4441 if (!is_eager) return nullptr;
4442 thread_local_data->is_eager = PyObject_IsTrue(is_eager.get());
4443
4444 #if PY_MAJOR_VERSION >= 3
4445 PyObject* scope_name = PyUnicode_FromString("");
4446 #else
4447 PyObject* scope_name = PyString_FromString("");
4448 #endif
4449 thread_local_data->scope_name.reset(scope_name);
4450
4451 #if PY_MAJOR_VERSION >= 3
4452 PyObject* device_name = PyUnicode_FromString("");
4453 #else
4454 PyObject* device_name = PyString_FromString("");
4455 #endif
4456 thread_local_data->device_name.reset(device_name);
4457
4458 Py_INCREF(defaults->second.device_spec.get());
4459 thread_local_data->device_spec.reset(defaults->second.device_spec.get());
4460
4461 Py_INCREF(Py_None);
4462 thread_local_data->function_call_options.reset(Py_None);
4463
4464 Py_INCREF(Py_None);
4465 thread_local_data->executor.reset(Py_None);
4466
4467 thread_local_data->op_callbacks.reset(PyList_New(0));
4468 }
4469 return thread_local_data.get();
4470 }
4471
4472 void DestroyEagerContextThreadLocalData(PyObject* py_eager_context) {
4473 DCheckPyGilState();
4474 if (eager_context_thread_local_data_defaults) {
4475 eager_context_thread_local_data_defaults->erase(py_eager_context);
4476 }
4477 if (eager_context_thread_local_data_map) {
4478 eager_context_thread_local_data_map->erase(py_eager_context);
4479 }
4480 }
4481
4482 } // namespace tensorflow
4483