1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <atomic>
17 #include <cstring>
18 #include <unordered_map>
19
20 #include "absl/debugging/leak_check.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/types/variant.h"
23 #include "tensorflow/c/c_api.h"
24 #include "tensorflow/c/c_api_internal.h"
25 #include "tensorflow/c/eager/c_api.h"
26 #include "tensorflow/c/eager/c_api_internal.h"
27 #include "tensorflow/c/eager/tape.h"
28 #include "tensorflow/c/eager/tfe_context_internal.h"
29 #include "tensorflow/c/eager/tfe_op_internal.h"
30 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
31 #include "tensorflow/c/tf_status.h"
32 #include "tensorflow/core/framework/types.pb.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/gtl/cleanup.h"
35 #include "tensorflow/core/lib/gtl/compactptrset.h"
36 #include "tensorflow/core/lib/gtl/flatmap.h"
37 #include "tensorflow/core/lib/gtl/flatset.h"
38 #include "tensorflow/core/lib/strings/strcat.h"
39 #include "tensorflow/core/lib/strings/stringprintf.h"
40 #include "tensorflow/core/platform/casts.h"
41 #include "tensorflow/core/platform/mutex.h"
42 #include "tensorflow/core/platform/protobuf.h"
43 #include "tensorflow/core/platform/status.h"
44 #include "tensorflow/core/platform/types.h"
45 #include "tensorflow/core/profiler/lib/traceme.h"
46 #include "tensorflow/core/util/managed_stack_trace.h"
47 #include "tensorflow/python/eager/pywrap_gradient_exclusions.h"
48 #include "tensorflow/python/eager/pywrap_tensor.h"
49 #include "tensorflow/python/eager/pywrap_tfe.h"
50 #include "tensorflow/python/lib/core/py_util.h"
51 #include "tensorflow/python/lib/core/safe_ptr.h"
52 #include "tensorflow/python/util/stack_trace.h"
53 #include "tensorflow/python/util/util.h"
54
55 using tensorflow::Status;
56 using tensorflow::string;
57 using tensorflow::strings::Printf;
58
59 namespace {
60 // NOTE: Items are retrieved from and returned to these unique_ptrs, and they
61 // act as arenas. This is important if the same thread requests 2 items without
62 // releasing one.
63 // The following sequence of events on the same thread will still succeed:
64 // - GetOp <- Returns existing.
65 // - GetOp <- Allocates and returns a new pointer.
66 // - ReleaseOp <- Sets the item in the unique_ptr.
67 // - ReleaseOp <- Sets the item in the unique_ptr, deleting the old one.
68 // This occurs when a PyFunc kernel is run. This behavior makes it safe in that
69 // case, as well as the case where python decides to reuse the underlying
70 // C++ thread in 2 python threads case.
71 struct OpDeleter {
operator ()__anon12da81bc0111::OpDeleter72 void operator()(TFE_Op* op) const { TFE_DeleteOp(op); }
73 };
74 thread_local std::unordered_map<TFE_Context*,
75 std::unique_ptr<TFE_Op, OpDeleter>>
76 thread_local_eager_operation_map; // NOLINT
77 thread_local std::unique_ptr<TF_Status> thread_local_tf_status = // NOLINT
78 nullptr;
79
ReleaseThreadLocalOp(TFE_Context * ctx)80 std::unique_ptr<TFE_Op, OpDeleter> ReleaseThreadLocalOp(TFE_Context* ctx) {
81 auto it = thread_local_eager_operation_map.find(ctx);
82 if (it == thread_local_eager_operation_map.end()) {
83 return nullptr;
84 }
85 return std::move(it->second);
86 }
87
GetOp(TFE_Context * ctx,const char * op_or_function_name,const char * raw_device_name,TF_Status * status)88 TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
89 const char* raw_device_name, TF_Status* status) {
90 auto op = ReleaseThreadLocalOp(ctx);
91 if (!op) {
92 op.reset(tensorflow::wrap(tensorflow::unwrap(ctx)->CreateOperation()));
93 }
94 status->status =
95 tensorflow::unwrap(op.get())->Reset(op_or_function_name, raw_device_name);
96 if (!status->status.ok()) {
97 op.reset();
98 }
99 return op.release();
100 }
101
ReturnOp(TFE_Context * ctx,TFE_Op * op)102 void ReturnOp(TFE_Context* ctx, TFE_Op* op) {
103 if (op) {
104 tensorflow::unwrap(op)->Clear();
105 thread_local_eager_operation_map[ctx].reset(op);
106 }
107 }
108
ReleaseThreadLocalStatus()109 TF_Status* ReleaseThreadLocalStatus() {
110 if (thread_local_tf_status == nullptr) {
111 return nullptr;
112 }
113 return thread_local_tf_status.release();
114 }
115
116 struct InputInfo {
InputInfo__anon12da81bc0111::InputInfo117 InputInfo(int i, bool is_list) : i(i), is_list(is_list) {}
118
119 int i;
120 bool is_list = false;
121 };
122
123 // Takes in output gradients, returns input gradients.
124 typedef std::function<PyObject*(PyObject*,
125 const std::vector<tensorflow::int64>&)>
126 PyBackwardFunction;
127
128 using AttrToInputsMap =
129 tensorflow::gtl::FlatMap<string,
130 tensorflow::gtl::InlinedVector<InputInfo, 4>>;
131
GetAllAttrToInputsMaps()132 tensorflow::gtl::FlatMap<string, AttrToInputsMap*>* GetAllAttrToInputsMaps() {
133 static auto* all_attr_to_input_maps =
134 new tensorflow::gtl::FlatMap<string, AttrToInputsMap*>;
135 return all_attr_to_input_maps;
136 }
137
138 // This function doesn't use a lock, since we depend on the GIL directly.
GetAttrToInputsMapHoldingGIL(const tensorflow::OpDef & op_def)139 AttrToInputsMap* GetAttrToInputsMapHoldingGIL(const tensorflow::OpDef& op_def) {
140 #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4
141 DCHECK(PyGILState_Check())
142 << "This function needs to hold the GIL when called.";
143 #endif
144 auto* all_attr_to_input_maps = GetAllAttrToInputsMaps();
145 auto* output =
146 tensorflow::gtl::FindPtrOrNull(*all_attr_to_input_maps, op_def.name());
147 if (output != nullptr) {
148 return output;
149 }
150
151 std::unique_ptr<AttrToInputsMap> m(new AttrToInputsMap);
152
153 // Store a list of InputIndex -> List of corresponding inputs.
154 for (int i = 0; i < op_def.input_arg_size(); i++) {
155 if (!op_def.input_arg(i).type_attr().empty()) {
156 auto it = m->find(op_def.input_arg(i).type_attr());
157 if (it == m->end()) {
158 it = m->insert({op_def.input_arg(i).type_attr(), {}}).first;
159 }
160 it->second.emplace_back(i, !op_def.input_arg(i).number_attr().empty());
161 }
162 }
163
164 auto* retval = m.get();
165 (*all_attr_to_input_maps)[op_def.name()] = m.release();
166
167 return retval;
168 }
169
170 // This function doesn't use a lock, since we depend on the GIL directly.
171 tensorflow::gtl::FlatMap<
172 string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>*
GetAllAttrToDefaultsMaps()173 GetAllAttrToDefaultsMaps() {
174 static auto* all_attr_to_defaults_maps = new tensorflow::gtl::FlatMap<
175 string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>;
176 return all_attr_to_defaults_maps;
177 }
178
179 tensorflow::gtl::FlatMap<string, tensorflow::DataType>*
GetAttrToDefaultsMapHoldingGIL(const tensorflow::OpDef & op_def)180 GetAttrToDefaultsMapHoldingGIL(const tensorflow::OpDef& op_def) {
181 #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4
182 DCHECK(PyGILState_Check())
183 << "This function needs to hold the GIL when called.";
184 #endif
185 auto* all_attr_to_defaults_maps = GetAllAttrToDefaultsMaps();
186 auto* output =
187 tensorflow::gtl::FindPtrOrNull(*all_attr_to_defaults_maps, op_def.name());
188 if (output != nullptr) {
189 return output;
190 }
191
192 auto* new_map = new tensorflow::gtl::FlatMap<string, tensorflow::DataType>;
193
194 for (const auto& attr : op_def.attr()) {
195 if (attr.type() == "type" && attr.has_default_value()) {
196 new_map->insert({attr.name(), attr.default_value().type()});
197 }
198 }
199
200 (*all_attr_to_defaults_maps)[op_def.name()] = new_map;
201
202 return new_map;
203 }
204
205 struct FastPathOpExecInfo {
206 TFE_Context* ctx;
207 const char* device_name;
208
209 bool run_callbacks;
210 bool run_post_exec_callbacks;
211 bool run_gradient_callback;
212
213 // The op name of the main op being executed.
214 PyObject* name;
215 // The op type name of the main op being executed.
216 PyObject* op_name;
217 PyObject* callbacks;
218
219 // All the args passed into the FastPathOpExecInfo.
220 PyObject* args;
221
222 // DTypes can come from another input that has the same attr. So build that
223 // map.
224 const AttrToInputsMap* attr_to_inputs_map;
225 const tensorflow::gtl::FlatMap<string, tensorflow::DataType>* default_dtypes;
226 tensorflow::gtl::FlatMap<string, tensorflow::DataType> cached_dtypes;
227 };
228
229 #define PARSE_VALUE(fn_name, type, check_fn, parse_fn) \
230 bool fn_name(const string& key, PyObject* py_value, TF_Status* status, \
231 type* value) { \
232 if (check_fn(py_value)) { \
233 *value = static_cast<type>(parse_fn(py_value)); \
234 return true; \
235 } else { \
236 TF_SetStatus(status, TF_INVALID_ARGUMENT, \
237 tensorflow::strings::StrCat( \
238 "Expecting " #type " value for attr ", key, ", got ", \
239 py_value->ob_type->tp_name) \
240 .c_str()); \
241 return false; \
242 } \
243 }
244
245 #if PY_MAJOR_VERSION >= 3
PARSE_VALUE(ParseIntValue,int,PyLong_Check,PyLong_AsLong)246 PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
247 PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLongLong)
248 #else
249 PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
250 #endif
251 PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)
252 #undef PARSE_VALUE
253
254 #if PY_MAJOR_VERSION < 3
255 bool ParseInt64Value(const string& key, PyObject* py_value, TF_Status* status,
256 int64_t* value) {
257 if (PyInt_Check(py_value)) {
258 *value = static_cast<int64_t>(PyInt_AsLong(py_value));
259 return true;
260 } else if (PyLong_Check(py_value)) {
261 *value = static_cast<int64_t>(PyLong_AsLong(py_value));
262 return true;
263 }
264 TF_SetStatus(
265 status, TF_INVALID_ARGUMENT,
266 tensorflow::strings::StrCat("Expecting int or long value for attr ", key,
267 ", got ", py_value->ob_type->tp_name)
268 .c_str());
269 return false;
270 }
271 #endif
272
TensorShapeNumDims(PyObject * value)273 Py_ssize_t TensorShapeNumDims(PyObject* value) {
274 const auto size = PySequence_Size(value);
275 if (size == -1) {
276 // TensorShape.__len__ raises an error in the scenario where the shape is an
277 // unknown, which needs to be cleared.
278 // TODO(nareshmodi): ensure that this is actually a TensorShape.
279 PyErr_Clear();
280 }
281 return size;
282 }
283
IsInteger(PyObject * py_value)284 bool IsInteger(PyObject* py_value) {
285 #if PY_MAJOR_VERSION >= 3
286 return PyLong_Check(py_value);
287 #else
288 return PyInt_Check(py_value) || PyLong_Check(py_value);
289 #endif
290 }
291
292 // This function considers a Dimension._value of None to be valid, and sets the
293 // value to be -1 in that case.
ParseDimensionValue(const string & key,PyObject * py_value,TF_Status * status,int64_t * value)294 bool ParseDimensionValue(const string& key, PyObject* py_value,
295 TF_Status* status, int64_t* value) {
296 if (IsInteger(py_value)) {
297 return ParseInt64Value(key, py_value, status, value);
298 }
299
300 tensorflow::Safe_PyObjectPtr dimension_value(
301 PyObject_GetAttrString(py_value, "_value"));
302 if (dimension_value == nullptr) {
303 PyErr_Clear();
304 TF_SetStatus(
305 status, TF_INVALID_ARGUMENT,
306 tensorflow::strings::StrCat("Expecting a Dimension for attr ", key,
307 ", got ", py_value->ob_type->tp_name)
308 .c_str());
309 return false;
310 }
311
312 if (dimension_value.get() == Py_None) {
313 *value = -1;
314 return true;
315 }
316
317 return ParseInt64Value(key, dimension_value.get(), status, value);
318 }
319
ParseStringValue(const string & key,PyObject * py_value,TF_Status * status,tensorflow::StringPiece * value)320 bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status,
321 tensorflow::StringPiece* value) {
322 if (PyBytes_Check(py_value)) {
323 Py_ssize_t size = 0;
324 char* buf = nullptr;
325 if (PyBytes_AsStringAndSize(py_value, &buf, &size) < 0) return false;
326 *value = tensorflow::StringPiece(buf, size);
327 return true;
328 }
329 #if PY_MAJOR_VERSION >= 3
330 if (PyUnicode_Check(py_value)) {
331 Py_ssize_t size = 0;
332 const char* buf = PyUnicode_AsUTF8AndSize(py_value, &size);
333 if (buf == nullptr) return false;
334 *value = tensorflow::StringPiece(buf, size);
335 return true;
336 }
337 #endif
338 TF_SetStatus(
339 status, TF_INVALID_ARGUMENT,
340 tensorflow::strings::StrCat("Expecting a string value for attr ", key,
341 ", got ", py_value->ob_type->tp_name)
342 .c_str());
343 return false;
344 }
345
ParseBoolValue(const string & key,PyObject * py_value,TF_Status * status,unsigned char * value)346 bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status,
347 unsigned char* value) {
348 *value = PyObject_IsTrue(py_value);
349 return true;
350 }
351
352 // The passed in py_value is expected to be an object of the python type
353 // dtypes.DType or an int.
ParseTypeValue(const string & key,PyObject * py_value,TF_Status * status,int * value)354 bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status,
355 int* value) {
356 if (IsInteger(py_value)) {
357 return ParseIntValue(key, py_value, status, value);
358 }
359
360 tensorflow::Safe_PyObjectPtr py_type_enum(
361 PyObject_GetAttrString(py_value, "_type_enum"));
362 if (py_type_enum == nullptr) {
363 PyErr_Clear();
364 TF_SetStatus(
365 status, TF_INVALID_ARGUMENT,
366 tensorflow::strings::StrCat("Expecting a DType.dtype for attr ", key,
367 ", got ", py_value->ob_type->tp_name)
368 .c_str());
369 return false;
370 }
371
372 return ParseIntValue(key, py_type_enum.get(), status, value);
373 }
374
SetOpAttrList(TFE_Context * ctx,TFE_Op * op,const char * key,PyObject * py_list,TF_AttrType type,tensorflow::gtl::FlatMap<string,tensorflow::int64> * attr_list_sizes,TF_Status * status)375 bool SetOpAttrList(
376 TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_list,
377 TF_AttrType type,
378 tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
379 TF_Status* status) {
380 if (!PySequence_Check(py_list)) {
381 TF_SetStatus(
382 status, TF_INVALID_ARGUMENT,
383 tensorflow::strings::StrCat("Expecting sequence value for attr ", key,
384 ", got ", py_list->ob_type->tp_name)
385 .c_str());
386 return false;
387 }
388 const int num_values = PySequence_Size(py_list);
389 if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values;
390
391 #define PARSE_LIST(c_type, parse_fn) \
392 std::unique_ptr<c_type[]> values(new c_type[num_values]); \
393 for (int i = 0; i < num_values; ++i) { \
394 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)); \
395 if (!parse_fn(key, py_value.get(), status, &values[i])) return false; \
396 }
397
398 if (type == TF_ATTR_STRING) {
399 std::unique_ptr<const void*[]> values(new const void*[num_values]);
400 std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
401 for (int i = 0; i < num_values; ++i) {
402 tensorflow::StringPiece value;
403 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
404 if (!ParseStringValue(key, py_value.get(), status, &value)) return false;
405 values[i] = value.data();
406 lengths[i] = value.size();
407 }
408 TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
409 } else if (type == TF_ATTR_INT) {
410 PARSE_LIST(int64_t, ParseInt64Value);
411 TFE_OpSetAttrIntList(op, key, values.get(), num_values);
412 } else if (type == TF_ATTR_FLOAT) {
413 PARSE_LIST(float, ParseFloatValue);
414 TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
415 } else if (type == TF_ATTR_BOOL) {
416 PARSE_LIST(unsigned char, ParseBoolValue);
417 TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
418 } else if (type == TF_ATTR_TYPE) {
419 PARSE_LIST(int, ParseTypeValue);
420 TFE_OpSetAttrTypeList(op, key,
421 reinterpret_cast<const TF_DataType*>(values.get()),
422 num_values);
423 } else if (type == TF_ATTR_SHAPE) {
424 // Make one pass through the input counting the total number of
425 // dims across all the input lists.
426 int total_dims = 0;
427 for (int i = 0; i < num_values; ++i) {
428 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
429 if (py_value.get() != Py_None) {
430 if (!PySequence_Check(py_value.get())) {
431 TF_SetStatus(
432 status, TF_INVALID_ARGUMENT,
433 tensorflow::strings::StrCat(
434 "Expecting None or sequence value for element", i,
435 " of attr ", key, ", got ", py_value->ob_type->tp_name)
436 .c_str());
437 return false;
438 }
439 const auto size = TensorShapeNumDims(py_value.get());
440 if (size >= 0) {
441 total_dims += size;
442 }
443 }
444 }
445 // Allocate a buffer that can fit all of the dims together.
446 std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
447 // Copy the input dims into the buffer and set dims to point to
448 // the start of each list's dims.
449 std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
450 std::unique_ptr<int[]> num_dims(new int[num_values]);
451 int64_t* offset = buffer.get();
452 for (int i = 0; i < num_values; ++i) {
453 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
454 if (py_value.get() == Py_None) {
455 dims[i] = nullptr;
456 num_dims[i] = -1;
457 } else {
458 const auto size = TensorShapeNumDims(py_value.get());
459 if (size == -1) {
460 dims[i] = nullptr;
461 num_dims[i] = -1;
462 continue;
463 }
464 dims[i] = offset;
465 num_dims[i] = size;
466 for (int j = 0; j < size; ++j) {
467 tensorflow::Safe_PyObjectPtr inner_py_value(
468 PySequence_ITEM(py_value.get(), j));
469 if (inner_py_value.get() == Py_None) {
470 *offset = -1;
471 } else if (!ParseDimensionValue(key, inner_py_value.get(), status,
472 offset)) {
473 return false;
474 }
475 ++offset;
476 }
477 }
478 }
479 TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
480 status);
481 if (!status->status.ok()) return false;
482 } else if (type == TF_ATTR_FUNC) {
483 std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
484 for (int i = 0; i < num_values; ++i) {
485 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
486 // Allow:
487 // (1) String function name, OR
488 // (2) A Python object with a .name attribute
489 // (A crude test for being a
490 // tensorflow.python.framework.function._DefinedFunction)
491 // (which is what the various "defun" or "Defun" decorators do).
492 // And in the future also allow an object that can encapsulate
493 // the function name and its attribute values.
494 tensorflow::StringPiece func_name;
495 if (!ParseStringValue(key, py_value.get(), status, &func_name)) {
496 PyObject* name_attr = PyObject_GetAttrString(py_value.get(), "name");
497 if (name_attr == nullptr ||
498 !ParseStringValue(key, name_attr, status, &func_name)) {
499 TF_SetStatus(
500 status, TF_INVALID_ARGUMENT,
501 tensorflow::strings::StrCat(
502 "unable to set function value attribute from a ",
503 py_value.get()->ob_type->tp_name,
504 " object. If you think this is an error, please file an "
505 "issue at "
506 "https://github.com/tensorflow/tensorflow/issues/new")
507 .c_str());
508 return false;
509 }
510 }
511 funcs[i] = TFE_NewOp(ctx, func_name.data(), status);
512 if (!status->status.ok()) return false;
513 }
514 TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
515 if (!status->status.ok()) return false;
516 } else {
517 TF_SetStatus(status, TF_UNIMPLEMENTED,
518 tensorflow::strings::StrCat("Attr ", key,
519 " has unhandled list type ", type)
520 .c_str());
521 return false;
522 }
523 #undef PARSE_LIST
524 return true;
525 }
526
GetFunc(TFE_Context * ctx,const tensorflow::NameAttrList & func,TF_Status * status)527 TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
528 TF_Status* status) {
529 TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
530 for (const auto& attr : func.attr()) {
531 if (!status->status.ok()) return nullptr;
532 SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
533 if (!status->status.ok()) return nullptr;
534 }
535 return func_op;
536 }
537
SetOpAttrListDefault(TFE_Context * ctx,TFE_Op * op,const tensorflow::OpDef::AttrDef & attr,const char * key,TF_AttrType type,tensorflow::gtl::FlatMap<string,tensorflow::int64> * attr_list_sizes,TF_Status * status)538 void SetOpAttrListDefault(
539 TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
540 const char* key, TF_AttrType type,
541 tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
542 TF_Status* status) {
543 if (type == TF_ATTR_STRING) {
544 int num_values = attr.default_value().list().s_size();
545 std::unique_ptr<const void*[]> values(new const void*[num_values]);
546 std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
547 (*attr_list_sizes)[key] = num_values;
548 for (int i = 0; i < num_values; i++) {
549 const string& v = attr.default_value().list().s(i);
550 values[i] = v.data();
551 lengths[i] = v.size();
552 }
553 TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
554 } else if (type == TF_ATTR_INT) {
555 int num_values = attr.default_value().list().i_size();
556 std::unique_ptr<int64_t[]> values(new int64_t[num_values]);
557 (*attr_list_sizes)[key] = num_values;
558 for (int i = 0; i < num_values; i++) {
559 values[i] = attr.default_value().list().i(i);
560 }
561 TFE_OpSetAttrIntList(op, key, values.get(), num_values);
562 } else if (type == TF_ATTR_FLOAT) {
563 int num_values = attr.default_value().list().f_size();
564 std::unique_ptr<float[]> values(new float[num_values]);
565 (*attr_list_sizes)[key] = num_values;
566 for (int i = 0; i < num_values; i++) {
567 values[i] = attr.default_value().list().f(i);
568 }
569 TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
570 } else if (type == TF_ATTR_BOOL) {
571 int num_values = attr.default_value().list().b_size();
572 std::unique_ptr<unsigned char[]> values(new unsigned char[num_values]);
573 (*attr_list_sizes)[key] = num_values;
574 for (int i = 0; i < num_values; i++) {
575 values[i] = attr.default_value().list().b(i);
576 }
577 TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
578 } else if (type == TF_ATTR_TYPE) {
579 int num_values = attr.default_value().list().type_size();
580 std::unique_ptr<int[]> values(new int[num_values]);
581 (*attr_list_sizes)[key] = num_values;
582 for (int i = 0; i < num_values; i++) {
583 values[i] = attr.default_value().list().type(i);
584 }
585 TFE_OpSetAttrTypeList(op, key,
586 reinterpret_cast<const TF_DataType*>(values.get()),
587 attr.default_value().list().type_size());
588 } else if (type == TF_ATTR_SHAPE) {
589 int num_values = attr.default_value().list().shape_size();
590 (*attr_list_sizes)[key] = num_values;
591 int total_dims = 0;
592 for (int i = 0; i < num_values; ++i) {
593 if (!attr.default_value().list().shape(i).unknown_rank()) {
594 total_dims += attr.default_value().list().shape(i).dim_size();
595 }
596 }
597 // Allocate a buffer that can fit all of the dims together.
598 std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
599 // Copy the input dims into the buffer and set dims to point to
600 // the start of each list's dims.
601 std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
602 std::unique_ptr<int[]> num_dims(new int[num_values]);
603 int64_t* offset = buffer.get();
604 for (int i = 0; i < num_values; ++i) {
605 const auto& shape = attr.default_value().list().shape(i);
606 if (shape.unknown_rank()) {
607 dims[i] = nullptr;
608 num_dims[i] = -1;
609 } else {
610 for (int j = 0; j < shape.dim_size(); j++) {
611 *offset = shape.dim(j).size();
612 ++offset;
613 }
614 }
615 }
616 TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
617 status);
618 } else if (type == TF_ATTR_FUNC) {
619 int num_values = attr.default_value().list().func_size();
620 (*attr_list_sizes)[key] = num_values;
621 std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
622 for (int i = 0; i < num_values; i++) {
623 funcs[i] = GetFunc(ctx, attr.default_value().list().func(i), status);
624 }
625 TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
626 } else {
627 TF_SetStatus(status, TF_UNIMPLEMENTED,
628 "Lists of tensors are not yet implemented for default valued "
629 "attributes for an operation.");
630 }
631 }
632
SetOpAttrScalar(TFE_Context * ctx,TFE_Op * op,const char * key,PyObject * py_value,TF_AttrType type,tensorflow::gtl::FlatMap<string,tensorflow::int64> * attr_list_sizes,TF_Status * status)633 bool SetOpAttrScalar(
634 TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_value,
635 TF_AttrType type,
636 tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
637 TF_Status* status) {
638 if (type == TF_ATTR_STRING) {
639 tensorflow::StringPiece value;
640 if (!ParseStringValue(key, py_value, status, &value)) return false;
641 TFE_OpSetAttrString(op, key, value.data(), value.size());
642 } else if (type == TF_ATTR_INT) {
643 int64_t value;
644 if (!ParseInt64Value(key, py_value, status, &value)) return false;
645 TFE_OpSetAttrInt(op, key, value);
646 // attr_list_sizes is set for all int attributes (since at this point we are
647 // not aware if that attribute might be used to calculate the size of an
648 // output list or not).
649 if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = value;
650 } else if (type == TF_ATTR_FLOAT) {
651 float value;
652 if (!ParseFloatValue(key, py_value, status, &value)) return false;
653 TFE_OpSetAttrFloat(op, key, value);
654 } else if (type == TF_ATTR_BOOL) {
655 unsigned char value;
656 if (!ParseBoolValue(key, py_value, status, &value)) return false;
657 TFE_OpSetAttrBool(op, key, value);
658 } else if (type == TF_ATTR_TYPE) {
659 int value;
660 if (!ParseTypeValue(key, py_value, status, &value)) return false;
661 TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value));
662 } else if (type == TF_ATTR_SHAPE) {
663 if (py_value == Py_None) {
664 TFE_OpSetAttrShape(op, key, nullptr, -1, status);
665 } else {
666 if (!PySequence_Check(py_value)) {
667 TF_SetStatus(status, TF_INVALID_ARGUMENT,
668 tensorflow::strings::StrCat(
669 "Expecting None or sequence value for attr", key,
670 ", got ", py_value->ob_type->tp_name)
671 .c_str());
672 return false;
673 }
674 const auto num_dims = TensorShapeNumDims(py_value);
675 if (num_dims == -1) {
676 TFE_OpSetAttrShape(op, key, nullptr, -1, status);
677 return true;
678 }
679 std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
680 for (int i = 0; i < num_dims; ++i) {
681 tensorflow::Safe_PyObjectPtr inner_py_value(
682 PySequence_ITEM(py_value, i));
683 if (inner_py_value.get() == Py_None) {
684 dims[i] = -1;
685 } else if (!ParseDimensionValue(key, inner_py_value.get(), status,
686 &dims[i])) {
687 return false;
688 }
689 }
690 TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status);
691 }
692 if (!status->status.ok()) return false;
693 } else if (type == TF_ATTR_FUNC) {
694 // Allow:
695 // (1) String function name, OR
696 // (2) A Python object with a .name attribute
697 // (A crude test for being a
698 // tensorflow.python.framework.function._DefinedFunction)
699 // (which is what the various "defun" or "Defun" decorators do).
700 // And in the future also allow an object that can encapsulate
701 // the function name and its attribute values.
702 tensorflow::StringPiece func_name;
703 if (!ParseStringValue(key, py_value, status, &func_name)) {
704 PyObject* name_attr = PyObject_GetAttrString(py_value, "name");
705 if (name_attr == nullptr ||
706 !ParseStringValue(key, name_attr, status, &func_name)) {
707 TF_SetStatus(
708 status, TF_INVALID_ARGUMENT,
709 tensorflow::strings::StrCat(
710 "unable to set function value attribute from a ",
711 py_value->ob_type->tp_name,
712 " object. If you think this is an error, please file an issue "
713 "at https://github.com/tensorflow/tensorflow/issues/new")
714 .c_str());
715 return false;
716 }
717 }
718 TF_SetStatus(status, TF_OK, "");
719 TFE_OpSetAttrFunctionName(op, key, func_name.data(), func_name.size());
720 } else {
721 TF_SetStatus(
722 status, TF_UNIMPLEMENTED,
723 tensorflow::strings::StrCat("Attr ", key, " has unhandled type ", type)
724 .c_str());
725 return false;
726 }
727 return true;
728 }
729
SetOpAttrScalarDefault(TFE_Context * ctx,TFE_Op * op,const tensorflow::AttrValue & default_value,const char * attr_name,tensorflow::gtl::FlatMap<string,tensorflow::int64> * attr_list_sizes,TF_Status * status)730 void SetOpAttrScalarDefault(
731 TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value,
732 const char* attr_name,
733 tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
734 TF_Status* status) {
735 SetOpAttrValueScalar(ctx, op, default_value, attr_name, status);
736 if (default_value.value_case() == tensorflow::AttrValue::kI) {
737 (*attr_list_sizes)[attr_name] = default_value.i();
738 }
739 }
740
741 // start_index is the index at which the Tuple/List attrs will start getting
742 // processed.
SetOpAttrs(TFE_Context * ctx,TFE_Op * op,PyObject * attrs,int start_index,TF_Status * out_status)743 void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index,
744 TF_Status* out_status) {
745 if (attrs == Py_None) return;
746 Py_ssize_t len = PyTuple_GET_SIZE(attrs) - start_index;
747 if ((len & 1) != 0) {
748 TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
749 "Expecting attrs tuple to have even length.");
750 return;
751 }
752 // Parse attrs
753 for (Py_ssize_t i = 0; i < len; i += 2) {
754 PyObject* py_key = PyTuple_GET_ITEM(attrs, start_index + i);
755 PyObject* py_value = PyTuple_GET_ITEM(attrs, start_index + i + 1);
756 #if PY_MAJOR_VERSION >= 3
757 const char* key = PyBytes_Check(py_key) ? PyBytes_AsString(py_key)
758 : PyUnicode_AsUTF8(py_key);
759 #else
760 const char* key = PyBytes_AsString(py_key);
761 #endif
762 unsigned char is_list = 0;
763 const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status);
764 if (!out_status->status.ok()) return;
765 if (is_list != 0) {
766 if (!SetOpAttrList(ctx, op, key, py_value, type, nullptr, out_status))
767 return;
768 } else {
769 if (!SetOpAttrScalar(ctx, op, key, py_value, type, nullptr, out_status))
770 return;
771 }
772 }
773 }
774
775 // This function will set the op attrs required. If an attr has the value of
776 // None, then it will read the AttrDef to get the default value and set that
777 // instead. Any failure in this function will simply fall back to the slow
778 // path.
SetOpAttrWithDefaults(TFE_Context * ctx,TFE_Op * op,const tensorflow::OpDef::AttrDef & attr,const char * attr_name,PyObject * attr_value,tensorflow::gtl::FlatMap<string,tensorflow::int64> * attr_list_sizes,TF_Status * status)779 void SetOpAttrWithDefaults(
780 TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
781 const char* attr_name, PyObject* attr_value,
782 tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
783 TF_Status* status) {
784 unsigned char is_list = 0;
785 const TF_AttrType type = TFE_OpGetAttrType(op, attr_name, &is_list, status);
786 if (!status->status.ok()) return;
787 if (attr_value == Py_None) {
788 if (is_list != 0) {
789 SetOpAttrListDefault(ctx, op, attr, attr_name, type, attr_list_sizes,
790 status);
791 } else {
792 SetOpAttrScalarDefault(ctx, op, attr.default_value(), attr_name,
793 attr_list_sizes, status);
794 }
795 } else {
796 if (is_list != 0) {
797 SetOpAttrList(ctx, op, attr_name, attr_value, type, attr_list_sizes,
798 status);
799 } else {
800 SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes,
801 status);
802 }
803 }
804 }
805
GetPythonObjectFromInt(int num)806 PyObject* GetPythonObjectFromInt(int num) {
807 #if PY_MAJOR_VERSION >= 3
808 return PyLong_FromLong(num);
809 #else
810 return PyInt_FromLong(num);
811 #endif
812 }
813
814 // Python subclass of Exception that is created on not ok Status.
815 tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
816 PyObject* exception_class TF_GUARDED_BY(exception_class_mutex) = nullptr;
817
818 // Python subclass of Exception that is created to signal fallback.
819 PyObject* fallback_exception_class = nullptr;
820
821 // Python function that returns input gradients given output gradients.
822 PyObject* gradient_function = nullptr;
823
824 // Python function that returns output gradients given input gradients.
825 PyObject* forward_gradient_function = nullptr;
826
827 static std::atomic<int64_t> _uid;
828
829 } // namespace
830
GetStatus()831 TF_Status* GetStatus() {
832 TF_Status* maybe_status = ReleaseThreadLocalStatus();
833 if (maybe_status) {
834 TF_SetStatus(maybe_status, TF_OK, "");
835 return maybe_status;
836 } else {
837 return TF_NewStatus();
838 }
839 }
840
ReturnStatus(TF_Status * status)841 void ReturnStatus(TF_Status* status) {
842 TF_SetStatus(status, TF_OK, "");
843 thread_local_tf_status.reset(status);
844 }
845
TFE_Py_Execute(TFE_Context * ctx,const char * device_name,const char * op_name,TFE_InputTensorHandles * inputs,PyObject * attrs,TFE_OutputTensorHandles * outputs,TF_Status * out_status)846 void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
847 const char* op_name, TFE_InputTensorHandles* inputs,
848 PyObject* attrs, TFE_OutputTensorHandles* outputs,
849 TF_Status* out_status) {
850 TFE_Py_ExecuteCancelable(ctx, device_name, op_name, inputs, attrs,
851 /*cancellation_manager=*/nullptr, outputs,
852 out_status);
853 }
854
TFE_Py_ExecuteCancelable(TFE_Context * ctx,const char * device_name,const char * op_name,TFE_InputTensorHandles * inputs,PyObject * attrs,TFE_CancellationManager * cancellation_manager,TFE_OutputTensorHandles * outputs,TF_Status * out_status)855 void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name,
856 const char* op_name,
857 TFE_InputTensorHandles* inputs, PyObject* attrs,
858 TFE_CancellationManager* cancellation_manager,
859 TFE_OutputTensorHandles* outputs,
860 TF_Status* out_status) {
861 tensorflow::profiler::TraceMe activity(
862 "TFE_Py_ExecuteCancelable", tensorflow::profiler::TraceMeLevel::kInfo);
863
864 TFE_Op* op = GetOp(ctx, op_name, device_name, out_status);
865
866 auto cleaner = tensorflow::gtl::MakeCleanup([ctx, op] { ReturnOp(ctx, op); });
867 if (!out_status->status.ok()) return;
868
869 tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace(
870 tensorflow::StackTrace::kStackTraceInitialSize));
871
872 for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) {
873 TFE_OpAddInput(op, inputs->at(i), out_status);
874 }
875 if (cancellation_manager && out_status->status.ok()) {
876 TFE_OpSetCancellationManager(op, cancellation_manager, out_status);
877 }
878 if (out_status->status.ok()) {
879 SetOpAttrs(ctx, op, attrs, 0, out_status);
880 }
881 Py_BEGIN_ALLOW_THREADS;
882
883 int num_outputs = outputs->size();
884
885 if (out_status->status.ok()) {
886 TFE_Execute(op, outputs->data(), &num_outputs, out_status);
887 }
888
889 if (out_status->status.ok()) {
890 outputs->resize(num_outputs);
891 } else {
892 TF_SetStatus(out_status, TF_GetCode(out_status),
893 tensorflow::strings::StrCat(TF_Message(out_status),
894 " [Op:", op_name, "]")
895 .c_str());
896 }
897
898 Py_END_ALLOW_THREADS;
899 }
900
TFE_Py_RegisterExceptionClass(PyObject * e)901 PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
902 tensorflow::mutex_lock l(exception_class_mutex);
903 if (exception_class != nullptr) {
904 Py_DECREF(exception_class);
905 }
906 if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
907 exception_class = nullptr;
908 PyErr_SetString(PyExc_TypeError,
909 "TFE_Py_RegisterExceptionClass: "
910 "Registered class should be subclass of Exception.");
911 return nullptr;
912 }
913
914 Py_INCREF(e);
915 exception_class = e;
916 Py_RETURN_NONE;
917 }
918
TFE_Py_RegisterFallbackExceptionClass(PyObject * e)919 PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) {
920 if (fallback_exception_class != nullptr) {
921 Py_DECREF(fallback_exception_class);
922 }
923 if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
924 fallback_exception_class = nullptr;
925 PyErr_SetString(PyExc_TypeError,
926 "TFE_Py_RegisterFallbackExceptionClass: "
927 "Registered class should be subclass of Exception.");
928 return nullptr;
929 } else {
930 Py_INCREF(e);
931 fallback_exception_class = e;
932 Py_RETURN_NONE;
933 }
934 }
935
TFE_Py_RegisterGradientFunction(PyObject * e)936 PyObject* TFE_Py_RegisterGradientFunction(PyObject* e) {
937 if (gradient_function != nullptr) {
938 Py_DECREF(gradient_function);
939 }
940 if (!PyCallable_Check(e)) {
941 gradient_function = nullptr;
942 PyErr_SetString(PyExc_TypeError,
943 "TFE_Py_RegisterGradientFunction: "
944 "Registered object should be function.");
945 return nullptr;
946 } else {
947 Py_INCREF(e);
948 gradient_function = e;
949 Py_RETURN_NONE;
950 }
951 }
952
TFE_Py_RegisterJVPFunction(PyObject * e)953 PyObject* TFE_Py_RegisterJVPFunction(PyObject* e) {
954 if (forward_gradient_function != nullptr) {
955 Py_DECREF(forward_gradient_function);
956 }
957 if (!PyCallable_Check(e)) {
958 forward_gradient_function = nullptr;
959 PyErr_SetString(PyExc_TypeError,
960 "TFE_Py_RegisterJVPFunction: "
961 "Registered object should be function.");
962 return nullptr;
963 } else {
964 Py_INCREF(e);
965 forward_gradient_function = e;
966 Py_RETURN_NONE;
967 }
968 }
969
RaiseFallbackException(const char * message)970 void RaiseFallbackException(const char* message) {
971 if (fallback_exception_class != nullptr) {
972 PyErr_SetString(fallback_exception_class, message);
973 return;
974 }
975
976 PyErr_SetString(
977 PyExc_RuntimeError,
978 tensorflow::strings::StrCat(
979 "Fallback exception type not set, attempting to fallback due to ",
980 message)
981 .data());
982 }
983
984 // Format and return `status`' error message with the attached stack trace if
985 // available. `status` must have an error.
FormatErrorStatusStackTrace(const tensorflow::Status & status)986 std::string FormatErrorStatusStackTrace(const tensorflow::Status& status) {
987 tensorflow::DCheckPyGilState();
988 DCHECK(!status.ok());
989
990 if (status.stack_trace().empty()) return status.error_message();
991
992 const std::vector<tensorflow::StackFrame>& stack_trace = status.stack_trace();
993
994 PyObject* linecache = PyImport_ImportModule("linecache");
995 PyObject* getline =
996 PyObject_GetAttr(linecache, PyUnicode_FromString("getline"));
997 DCHECK(getline);
998
999 std::ostringstream result;
1000 result << "Exception originated from\n\n";
1001
1002 for (const tensorflow::StackFrame& stack_frame : stack_trace) {
1003 PyObject* line_str_obj = PyObject_CallFunction(
1004 getline, const_cast<char*>("si"), stack_frame.file_name.c_str(),
1005 stack_frame.line_number);
1006 tensorflow::StringPiece line_str = TFE_GetPythonString(line_str_obj);
1007 tensorflow::str_util::RemoveWhitespaceContext(&line_str);
1008 result << " File \"" << stack_frame.file_name << "\", line "
1009 << stack_frame.line_number << ", in " << stack_frame.function_name
1010 << '\n';
1011
1012 if (!line_str.empty()) result << " " << line_str << '\n';
1013 Py_XDECREF(line_str_obj);
1014 }
1015
1016 Py_DecRef(getline);
1017 Py_DecRef(linecache);
1018
1019 result << '\n' << status.error_message();
1020 return result.str();
1021 }
1022
MaybeRaiseExceptionFromTFStatus(TF_Status * status,PyObject * exception)1023 int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
1024 if (status->status.ok()) return 0;
1025 const char* msg = TF_Message(status);
1026 if (exception == nullptr) {
1027 tensorflow::mutex_lock l(exception_class_mutex);
1028 if (exception_class != nullptr) {
1029 tensorflow::Safe_PyObjectPtr val(Py_BuildValue(
1030 "si", FormatErrorStatusStackTrace(status->status).c_str(),
1031 TF_GetCode(status)));
1032 if (PyErr_Occurred()) {
1033 // NOTE: This hides the actual error (i.e. the reason `status` was not
1034 // TF_OK), but there is nothing we can do at this point since we can't
1035 // generate a reasonable error from the status.
1036 // Consider adding a message explaining this.
1037 return -1;
1038 }
1039 PyErr_SetObject(exception_class, val.get());
1040 return -1;
1041 } else {
1042 exception = PyExc_RuntimeError;
1043 }
1044 }
1045 // May be update already set exception.
1046 PyErr_SetString(exception, msg);
1047 return -1;
1048 }
1049
MaybeRaiseExceptionFromStatus(const tensorflow::Status & status,PyObject * exception)1050 int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
1051 PyObject* exception) {
1052 if (status.ok()) return 0;
1053 const char* msg = status.error_message().c_str();
1054 if (exception == nullptr) {
1055 tensorflow::mutex_lock l(exception_class_mutex);
1056 if (exception_class != nullptr) {
1057 tensorflow::Safe_PyObjectPtr val(Py_BuildValue(
1058 "si", FormatErrorStatusStackTrace(status).c_str(), status.code()));
1059 PyErr_SetObject(exception_class, val.get());
1060 return -1;
1061 } else {
1062 exception = PyExc_RuntimeError;
1063 }
1064 }
1065 // May be update already set exception.
1066 PyErr_SetString(exception, msg);
1067 return -1;
1068 }
1069
TFE_GetPythonString(PyObject * o)1070 const char* TFE_GetPythonString(PyObject* o) {
1071 #if PY_MAJOR_VERSION >= 3
1072 if (PyBytes_Check(o)) {
1073 return PyBytes_AsString(o);
1074 } else {
1075 return PyUnicode_AsUTF8(o);
1076 }
1077 #else
1078 return PyBytes_AsString(o);
1079 #endif
1080 }
1081
get_uid()1082 int64_t get_uid() { return _uid++; }
1083
TFE_Py_UID()1084 PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }
1085
TFE_DeleteContextCapsule(PyObject * context)1086 void TFE_DeleteContextCapsule(PyObject* context) {
1087 TFE_Context* ctx =
1088 reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr));
1089 auto op = ReleaseThreadLocalOp(ctx);
1090 op.reset();
1091 TFE_DeleteContext(ctx);
1092 }
1093
MakeInt(PyObject * integer)1094 static tensorflow::int64 MakeInt(PyObject* integer) {
1095 #if PY_MAJOR_VERSION >= 3
1096 return PyLong_AsLong(integer);
1097 #else
1098 return PyInt_AsLong(integer);
1099 #endif
1100 }
1101
FastTensorId(PyObject * tensor)1102 static tensorflow::int64 FastTensorId(PyObject* tensor) {
1103 if (EagerTensor_CheckExact(tensor)) {
1104 return PyEagerTensor_ID(tensor);
1105 }
1106 PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
1107 if (id_field == nullptr) {
1108 return -1;
1109 }
1110 tensorflow::int64 id = MakeInt(id_field);
1111 Py_DECREF(id_field);
1112 return id;
1113 }
1114
1115 namespace tensorflow {
PyTensor_DataType(PyObject * tensor)1116 DataType PyTensor_DataType(PyObject* tensor) {
1117 if (EagerTensor_CheckExact(tensor)) {
1118 return PyEagerTensor_Dtype(tensor);
1119 } else {
1120 #if PY_MAJOR_VERSION < 3
1121 // Python 2.x:
1122 static PyObject* dtype_attr = PyString_InternFromString("dtype");
1123 static PyObject* type_enum_attr = PyString_InternFromString("_type_enum");
1124 #else
1125 // Python 3.x:
1126 static PyObject* dtype_attr = PyUnicode_InternFromString("dtype");
1127 static PyObject* type_enum_attr = PyUnicode_InternFromString("_type_enum");
1128 #endif
1129 Safe_PyObjectPtr dtype_field(PyObject_GetAttr(tensor, dtype_attr));
1130 if (!dtype_field) {
1131 return DT_INVALID;
1132 }
1133
1134 Safe_PyObjectPtr enum_field(
1135 PyObject_GetAttr(dtype_field.get(), type_enum_attr));
1136 if (!enum_field) {
1137 return DT_INVALID;
1138 }
1139
1140 return static_cast<DataType>(MakeInt(enum_field.get()));
1141 }
1142 }
1143 } // namespace tensorflow
1144
1145 class PyTapeTensor {
1146 public:
PyTapeTensor(tensorflow::int64 id,tensorflow::DataType dtype,const tensorflow::TensorShape & shape)1147 PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
1148 const tensorflow::TensorShape& shape)
1149 : id_(id), dtype_(dtype), shape_(shape) {}
PyTapeTensor(tensorflow::int64 id,tensorflow::DataType dtype,PyObject * shape)1150 PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
1151 PyObject* shape)
1152 : id_(id), dtype_(dtype), shape_(shape) {
1153 Py_INCREF(absl::get<1>(shape_));
1154 }
PyTapeTensor(const PyTapeTensor & other)1155 PyTapeTensor(const PyTapeTensor& other) {
1156 id_ = other.id_;
1157 dtype_ = other.dtype_;
1158 shape_ = other.shape_;
1159 if (shape_.index() == 1) {
1160 Py_INCREF(absl::get<1>(shape_));
1161 }
1162 }
1163
~PyTapeTensor()1164 ~PyTapeTensor() {
1165 if (shape_.index() == 1) {
1166 Py_DECREF(absl::get<1>(shape_));
1167 }
1168 }
1169 PyObject* GetShape() const;
GetPyDType() const1170 PyObject* GetPyDType() const { return PyLong_FromLong(dtype_); }
GetID() const1171 tensorflow::int64 GetID() const { return id_; }
GetDType() const1172 tensorflow::DataType GetDType() const { return dtype_; }
1173
1174 PyObject* OnesLike() const;
1175 PyObject* ZerosLike() const;
1176
1177 private:
1178 tensorflow::int64 id_;
1179 tensorflow::DataType dtype_;
1180
1181 // Note that if shape_.index() == 1, meaning shape_ contains a PyObject, that
1182 // PyObject is the tensor itself. This is used to support tf.shape(tensor) for
1183 // partially-defined shapes and tf.zeros_like(tensor) for variant-dtype
1184 // tensors.
1185 absl::variant<tensorflow::TensorShape, PyObject*> shape_;
1186 };
1187
1188 static PyTapeTensor TapeTensorFromTensor(PyObject* tensor);
1189
1190 class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
1191 PyTapeTensor> {
1192 public:
PyVSpace(PyObject * py_vspace)1193 explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
1194 Py_INCREF(py_vspace_);
1195 }
1196
Initialize()1197 tensorflow::Status Initialize() {
1198 num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
1199 if (num_elements_ == nullptr) {
1200 return tensorflow::errors::InvalidArgument("invalid vspace");
1201 }
1202 aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
1203 if (aggregate_fn_ == nullptr) {
1204 return tensorflow::errors::InvalidArgument("invalid vspace");
1205 }
1206 zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn");
1207 if (zeros_fn_ == nullptr) {
1208 return tensorflow::errors::InvalidArgument("invalid vspace");
1209 }
1210 zeros_like_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_like_fn");
1211 if (zeros_like_fn_ == nullptr) {
1212 return tensorflow::errors::InvalidArgument("invalid vspace");
1213 }
1214 ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn");
1215 if (ones_fn_ == nullptr) {
1216 return tensorflow::errors::InvalidArgument("invalid vspace");
1217 }
1218 ones_like_fn_ = PyObject_GetAttrString(py_vspace_, "ones_like_fn");
1219 if (ones_like_fn_ == nullptr) {
1220 return tensorflow::errors::InvalidArgument("invalid vspace");
1221 }
1222 graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn");
1223 if (graph_shape_fn_ == nullptr) {
1224 return tensorflow::errors::InvalidArgument("invalid vspace");
1225 }
1226 return tensorflow::Status::OK();
1227 }
1228
~PyVSpace()1229 ~PyVSpace() override {
1230 Py_XDECREF(num_elements_);
1231 Py_XDECREF(aggregate_fn_);
1232 Py_XDECREF(zeros_fn_);
1233 Py_XDECREF(zeros_like_fn_);
1234 Py_XDECREF(ones_fn_);
1235 Py_XDECREF(ones_like_fn_);
1236 Py_XDECREF(graph_shape_fn_);
1237
1238 Py_DECREF(py_vspace_);
1239 }
1240
NumElements(PyObject * tensor) const1241 tensorflow::int64 NumElements(PyObject* tensor) const final {
1242 if (EagerTensor_CheckExact(tensor)) {
1243 return PyEagerTensor_NumElements(tensor);
1244 }
1245 PyObject* arglist =
1246 Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
1247 PyObject* result = PyEval_CallObject(num_elements_, arglist);
1248 Py_DECREF(arglist);
1249 if (result == nullptr) {
1250 // The caller detects whether a python exception has been raised.
1251 return -1;
1252 }
1253 tensorflow::int64 r = MakeInt(result);
1254 Py_DECREF(result);
1255 return r;
1256 }
1257
AggregateGradients(tensorflow::gtl::ArraySlice<PyObject * > gradient_tensors) const1258 PyObject* AggregateGradients(
1259 tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
1260 PyObject* list = PyList_New(gradient_tensors.size());
1261 for (int i = 0; i < gradient_tensors.size(); ++i) {
1262 // Note: stealing a reference to the gradient tensors.
1263 CHECK(gradient_tensors[i] != nullptr);
1264 CHECK(gradient_tensors[i] != Py_None);
1265 PyList_SET_ITEM(list, i,
1266 reinterpret_cast<PyObject*>(gradient_tensors[i]));
1267 }
1268 PyObject* arglist = Py_BuildValue("(O)", list);
1269 CHECK(arglist != nullptr);
1270 PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
1271 Py_DECREF(arglist);
1272 Py_DECREF(list);
1273 return result;
1274 }
1275
TensorId(PyObject * tensor) const1276 tensorflow::int64 TensorId(PyObject* tensor) const final {
1277 return FastTensorId(tensor);
1278 }
1279
MarkAsResult(PyObject * gradient) const1280 void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
1281
Ones(PyObject * shape,PyObject * dtype) const1282 PyObject* Ones(PyObject* shape, PyObject* dtype) const {
1283 if (PyErr_Occurred()) {
1284 return nullptr;
1285 }
1286 PyObject* arg_list = Py_BuildValue("OO", shape, dtype);
1287 PyObject* result = PyEval_CallObject(ones_fn_, arg_list);
1288 Py_DECREF(arg_list);
1289 return result;
1290 }
1291
OnesLike(PyObject * tensor) const1292 PyObject* OnesLike(PyObject* tensor) const {
1293 if (PyErr_Occurred()) {
1294 return nullptr;
1295 }
1296 return PyObject_CallFunctionObjArgs(ones_like_fn_, tensor, NULL);
1297 }
1298
1299 // Builds a tensor filled with ones with the same shape and dtype as `t`.
BuildOnesLike(const PyTapeTensor & t,PyObject ** result) const1300 Status BuildOnesLike(const PyTapeTensor& t,
1301 PyObject** result) const override {
1302 *result = t.OnesLike();
1303 return Status::OK();
1304 }
1305
Zeros(PyObject * shape,PyObject * dtype) const1306 PyObject* Zeros(PyObject* shape, PyObject* dtype) const {
1307 if (PyErr_Occurred()) {
1308 return nullptr;
1309 }
1310 PyObject* arg_list = Py_BuildValue("OO", shape, dtype);
1311 PyObject* result = PyEval_CallObject(zeros_fn_, arg_list);
1312 Py_DECREF(arg_list);
1313 return result;
1314 }
1315
ZerosLike(PyObject * tensor) const1316 PyObject* ZerosLike(PyObject* tensor) const {
1317 if (PyErr_Occurred()) {
1318 return nullptr;
1319 }
1320 return PyObject_CallFunctionObjArgs(zeros_like_fn_, tensor, NULL);
1321 }
1322
GraphShape(PyObject * tensor) const1323 PyObject* GraphShape(PyObject* tensor) const {
1324 PyObject* arg_list = Py_BuildValue("(O)", tensor);
1325 PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list);
1326 Py_DECREF(arg_list);
1327 return result;
1328 }
1329
CallBackwardFunction(const string & op_type,PyBackwardFunction * backward_function,const std::vector<tensorflow::int64> & unneeded_gradients,tensorflow::gtl::ArraySlice<PyObject * > output_gradients,absl::Span<PyObject * > result) const1330 tensorflow::Status CallBackwardFunction(
1331 const string& op_type, PyBackwardFunction* backward_function,
1332 const std::vector<tensorflow::int64>& unneeded_gradients,
1333 tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
1334 absl::Span<PyObject*> result) const final {
1335 PyObject* grads = PyTuple_New(output_gradients.size());
1336 for (int i = 0; i < output_gradients.size(); ++i) {
1337 if (output_gradients[i] == nullptr) {
1338 Py_INCREF(Py_None);
1339 PyTuple_SET_ITEM(grads, i, Py_None);
1340 } else {
1341 PyTuple_SET_ITEM(grads, i,
1342 reinterpret_cast<PyObject*>(output_gradients[i]));
1343 }
1344 }
1345 PyObject* py_result = (*backward_function)(grads, unneeded_gradients);
1346 Py_DECREF(grads);
1347 if (py_result == nullptr) {
1348 return tensorflow::errors::Internal("gradient function threw exceptions");
1349 }
1350 PyObject* seq =
1351 PySequence_Fast(py_result, "expected a sequence of gradients");
1352 if (seq == nullptr) {
1353 return tensorflow::errors::InvalidArgument(
1354 "gradient function did not return a list");
1355 }
1356 int len = PySequence_Fast_GET_SIZE(seq);
1357 if (len != result.size()) {
1358 return tensorflow::errors::Internal(
1359 "Recorded operation '", op_type,
1360 "' returned too few gradients. Expected ", result.size(),
1361 " but received ", len);
1362 }
1363 PyObject** seq_array = PySequence_Fast_ITEMS(seq);
1364 VLOG(1) << "Gradient length is " << len;
1365 for (int i = 0; i < len; ++i) {
1366 PyObject* item = seq_array[i];
1367 if (item == Py_None) {
1368 result[i] = nullptr;
1369 } else {
1370 Py_INCREF(item);
1371 result[i] = item;
1372 }
1373 }
1374 Py_DECREF(seq);
1375 Py_DECREF(py_result);
1376 return tensorflow::Status::OK();
1377 }
1378
DeleteGradient(PyObject * tensor) const1379 void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
1380
TapeTensorFromGradient(PyObject * tensor) const1381 PyTapeTensor TapeTensorFromGradient(PyObject* tensor) const final {
1382 return TapeTensorFromTensor(tensor);
1383 }
1384
1385 private:
1386 PyObject* py_vspace_;
1387
1388 PyObject* num_elements_;
1389 PyObject* aggregate_fn_;
1390 PyObject* zeros_fn_;
1391 PyObject* zeros_like_fn_;
1392 PyObject* ones_fn_;
1393 PyObject* ones_like_fn_;
1394 PyObject* graph_shape_fn_;
1395 };
1396 PyVSpace* py_vspace = nullptr;
1397
1398 bool HasAccumulator();
1399
TFE_Py_RegisterVSpace(PyObject * e)1400 PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
1401 if (py_vspace != nullptr) {
1402 if (HasAccumulator()) {
1403 // Accumulators reference py_vspace, so we can't swap it out while one is
1404 // active. This is unlikely to ever happen.
1405 MaybeRaiseExceptionFromStatus(
1406 tensorflow::errors::Internal(
1407 "Can't change the vspace implementation while a "
1408 "forward accumulator is active."),
1409 nullptr);
1410 }
1411 delete py_vspace;
1412 }
1413
1414 py_vspace = new PyVSpace(e);
1415 auto status = py_vspace->Initialize();
1416 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
1417 delete py_vspace;
1418 return nullptr;
1419 }
1420
1421 Py_RETURN_NONE;
1422 }
1423
GetShape() const1424 PyObject* PyTapeTensor::GetShape() const {
1425 if (shape_.index() == 0) {
1426 auto& shape = absl::get<0>(shape_);
1427 PyObject* py_shape = PyTuple_New(shape.dims());
1428 for (int i = 0; i < shape.dims(); ++i) {
1429 PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
1430 }
1431
1432 return py_shape;
1433 }
1434
1435 return py_vspace->GraphShape(absl::get<1>(shape_));
1436 }
1437
OnesLike() const1438 PyObject* PyTapeTensor::OnesLike() const {
1439 if (shape_.index() == 1) {
1440 PyObject* tensor = absl::get<1>(shape_);
1441 return py_vspace->OnesLike(tensor);
1442 }
1443 PyObject* py_shape = GetShape();
1444 PyObject* dtype_field = GetPyDType();
1445 PyObject* result = py_vspace->Ones(py_shape, dtype_field);
1446 Py_DECREF(dtype_field);
1447 Py_DECREF(py_shape);
1448 return result;
1449 }
1450
ZerosLike() const1451 PyObject* PyTapeTensor::ZerosLike() const {
1452 if (shape_.index() == 1) {
1453 PyObject* tensor = absl::get<1>(shape_);
1454 return py_vspace->ZerosLike(tensor);
1455 }
1456 PyObject* py_shape = GetShape();
1457 PyObject* dtype_field = GetPyDType();
1458 PyObject* result = py_vspace->Zeros(py_shape, dtype_field);
1459 Py_DECREF(dtype_field);
1460 Py_DECREF(py_shape);
1461 return result;
1462 }
1463
1464 // Keeps track of all variables that have been accessed during execution.
1465 class VariableWatcher {
1466 public:
VariableWatcher()1467 VariableWatcher() {}
1468
~VariableWatcher()1469 ~VariableWatcher() {
1470 for (const IdAndVariable& v : watched_variables_) {
1471 Py_DECREF(v.variable);
1472 }
1473 }
1474
WatchVariable(PyObject * v)1475 tensorflow::int64 WatchVariable(PyObject* v) {
1476 tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
1477 if (handle == nullptr) {
1478 return -1;
1479 }
1480 tensorflow::int64 id = FastTensorId(handle.get());
1481
1482 tensorflow::mutex_lock l(watched_variables_mu_);
1483 auto insert_result = watched_variables_.emplace(id, v);
1484
1485 if (insert_result.second) {
1486 // Only increment the reference count if we aren't already watching this
1487 // variable.
1488 Py_INCREF(v);
1489 }
1490
1491 return id;
1492 }
1493
GetVariablesAsPyTuple()1494 PyObject* GetVariablesAsPyTuple() {
1495 tensorflow::mutex_lock l(watched_variables_mu_);
1496 PyObject* result = PyTuple_New(watched_variables_.size());
1497 Py_ssize_t pos = 0;
1498 for (const IdAndVariable& id_and_variable : watched_variables_) {
1499 PyTuple_SET_ITEM(result, pos++, id_and_variable.variable);
1500 Py_INCREF(id_and_variable.variable);
1501 }
1502 return result;
1503 }
1504
1505 private:
1506 // We store an IdAndVariable in the map since the map needs to be locked
1507 // during insert, but should not call back into python during insert to avoid
1508 // deadlocking with the GIL.
1509 struct IdAndVariable {
1510 tensorflow::int64 id;
1511 PyObject* variable;
1512
1513     IdAndVariable(tensorflow::int64 id, PyObject* variable)
1514 : id(id), variable(variable) {}
1515 };
1516 struct CompareById {
1517     bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) const {
1518 return lhs.id < rhs.id;
1519 }
1520 };
1521
1522 tensorflow::mutex watched_variables_mu_;
1523 std::set<IdAndVariable, CompareById> watched_variables_
1524 TF_GUARDED_BY(watched_variables_mu_);
1525 };
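// An illustrative usage sketch, for exposition only (assumes `v` is a Python
// ResourceVariable exposing a `handle` attribute):
//
//   VariableWatcher watcher;
//   tensorflow::int64 id = watcher.WatchVariable(v);  // INCREFs v on first watch
//   if (id == -1) { /* Python error already set: no "handle" attribute */ }
//   PyObject* vars = watcher.GetVariablesAsPyTuple();  // new refs, ordered by id
//   Py_DECREF(vars);
//
// Watching the same variable again is a no-op apart from the id lookup, since
// the set is keyed by the handle's tensor id.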
1526
1527 class GradientTape
1528 : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
1529 PyTapeTensor> {
1530 public:
1531   explicit GradientTape(bool persistent, bool watch_accessed_variables)
1532 : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
1533 PyTapeTensor>(persistent),
1534 watch_accessed_variables_(watch_accessed_variables) {}
1535
1536   virtual ~GradientTape() {}
1537
1538   void VariableAccessed(PyObject* v) {
1539 if (watch_accessed_variables_) {
1540 WatchVariable(v);
1541 }
1542 }
1543
1544   void WatchVariable(PyObject* v) {
1545 tensorflow::int64 id = variable_watcher_.WatchVariable(v);
1546
1547 if (!PyErr_Occurred()) {
1548 this->Watch(id);
1549 }
1550 }
1551
1552   PyObject* GetVariablesAsPyTuple() {
1553 return variable_watcher_.GetVariablesAsPyTuple();
1554 }
1555
1556 private:
1557 bool watch_accessed_variables_;
1558 VariableWatcher variable_watcher_;
1559 };
1560
1561 typedef tensorflow::eager::ForwardAccumulator<PyObject, PyBackwardFunction,
1562 PyTapeTensor>
1563 ForwardAccumulator;
1564
1565 // Incremented when a GradientTape or accumulator is newly added to a set, and
1566 // used to enforce an ordering between them.
1567 std::atomic_uint_fast64_t tape_nesting_id_counter(0);
1568
1569 typedef struct {
1570 PyObject_HEAD
1571 /* Type-specific fields go here. */
1572 GradientTape* tape;
1573 // A nesting order between GradientTapes and ForwardAccumulators, used to
1574 // ensure that GradientTapes do not watch the products of outer
1575 // ForwardAccumulators.
1576 tensorflow::int64 nesting_id;
1577 } TFE_Py_Tape;
1578
1579 static void TFE_Py_Tape_Delete(PyObject* tape) {
1580 delete reinterpret_cast<TFE_Py_Tape*>(tape)->tape;
1581 Py_TYPE(tape)->tp_free(tape);
1582 }
1583
1584 static PyTypeObject TFE_Py_Tape_Type = {
1585 PyVarObject_HEAD_INIT(nullptr, 0) "tfe.Tape", /* tp_name */
1586 sizeof(TFE_Py_Tape), /* tp_basicsize */
1587 0, /* tp_itemsize */
1588 &TFE_Py_Tape_Delete, /* tp_dealloc */
1589 #if PY_VERSION_HEX < 0x03080000
1590 nullptr, /* tp_print */
1591 #else
1592 0, /* tp_vectorcall_offset */
1593 #endif
1594 nullptr, /* tp_getattr */
1595 nullptr, /* tp_setattr */
1596 nullptr, /* tp_reserved */
1597 nullptr, /* tp_repr */
1598 nullptr, /* tp_as_number */
1599 nullptr, /* tp_as_sequence */
1600 nullptr, /* tp_as_mapping */
1601 nullptr, /* tp_hash */
1602 nullptr, /* tp_call */
1603 nullptr, /* tp_str */
1604 nullptr, /* tp_getattro */
1605 nullptr, /* tp_setattro */
1606 nullptr, /* tp_as_buffer */
1607 Py_TPFLAGS_DEFAULT, /* tp_flags */
1608 "TFE_Py_Tape objects", /* tp_doc */
1609 };
1610
1611 typedef struct {
1612 PyObject_HEAD
1613 /* Type-specific fields go here. */
1614 ForwardAccumulator* accumulator;
1615 // A nesting order between GradientTapes and ForwardAccumulators, used to
1616 // ensure that GradientTapes do not watch the products of outer
1617 // ForwardAccumulators.
1618 tensorflow::int64 nesting_id;
1619 } TFE_Py_ForwardAccumulator;
1620
1621 static void TFE_Py_ForwardAccumulatorDelete(PyObject* accumulator) {
1622 delete reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)->accumulator;
1623 Py_TYPE(accumulator)->tp_free(accumulator);
1624 }
1625
1626 static PyTypeObject TFE_Py_ForwardAccumulator_Type = {
1627 PyVarObject_HEAD_INIT(nullptr, 0) "ForwardAccumulator", /* tp_name */
1628 sizeof(TFE_Py_ForwardAccumulator), /* tp_basicsize */
1629 0, /* tp_itemsize */
1630 &TFE_Py_ForwardAccumulatorDelete, /* tp_dealloc */
1631 #if PY_VERSION_HEX < 0x03080000
1632 nullptr, /* tp_print */
1633 #else
1634 0, /* tp_vectorcall_offset */
1635 #endif
1636 nullptr, /* tp_getattr */
1637 nullptr, /* tp_setattr */
1638 nullptr, /* tp_reserved */
1639 nullptr, /* tp_repr */
1640 nullptr, /* tp_as_number */
1641 nullptr, /* tp_as_sequence */
1642 nullptr, /* tp_as_mapping */
1643 nullptr, /* tp_hash */
1644 nullptr, /* tp_call */
1645 nullptr, /* tp_str */
1646 nullptr, /* tp_getattro */
1647 nullptr, /* tp_setattro */
1648 nullptr, /* tp_as_buffer */
1649 Py_TPFLAGS_DEFAULT, /* tp_flags */
1650 "TFE_Py_ForwardAccumulator objects", /* tp_doc */
1651 };
1652
1653 typedef struct {
1654 PyObject_HEAD
1655 /* Type-specific fields go here. */
1656 VariableWatcher* variable_watcher;
1657 } TFE_Py_VariableWatcher;
1658
1659 static void TFE_Py_VariableWatcher_Delete(PyObject* variable_watcher) {
1660 delete reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
1661 ->variable_watcher;
1662 Py_TYPE(variable_watcher)->tp_free(variable_watcher);
1663 }
1664
1665 static PyTypeObject TFE_Py_VariableWatcher_Type = {
1666 PyVarObject_HEAD_INIT(nullptr, 0) "tfe.VariableWatcher", /* tp_name */
1667 sizeof(TFE_Py_VariableWatcher), /* tp_basicsize */
1668 0, /* tp_itemsize */
1669 &TFE_Py_VariableWatcher_Delete, /* tp_dealloc */
1670 #if PY_VERSION_HEX < 0x03080000
1671 nullptr, /* tp_print */
1672 #else
1673 0, /* tp_vectorcall_offset */
1674 #endif
1675 nullptr, /* tp_getattr */
1676 nullptr, /* tp_setattr */
1677 nullptr, /* tp_reserved */
1678 nullptr, /* tp_repr */
1679 nullptr, /* tp_as_number */
1680 nullptr, /* tp_as_sequence */
1681 nullptr, /* tp_as_mapping */
1682 nullptr, /* tp_hash */
1683 nullptr, /* tp_call */
1684 nullptr, /* tp_str */
1685 nullptr, /* tp_getattro */
1686 nullptr, /* tp_setattro */
1687 nullptr, /* tp_as_buffer */
1688 Py_TPFLAGS_DEFAULT, /* tp_flags */
1689 "TFE_Py_VariableWatcher objects", /* tp_doc */
1690 };
1691
1692 // Note: in the current design no mutex is needed here because of the python
1693 // GIL, which is always held when any TFE_Py_* methods are called. We should
1694 // revisit this if/when we decide not to hold the GIL while manipulating the
1695 // stack.
1696 tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
1697 thread_local std::unique_ptr<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>
1698 tape_set = nullptr;
1699 if (tape_set == nullptr) {
1700 tape_set.reset(new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>);
1701 }
1702 return tape_set.get();
1703 }
1704
1705 tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>*
1706 GetVariableWatcherSet() {
1707 thread_local std::unique_ptr<
1708 tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>
1709 variable_watcher_set = nullptr;
1710 if (variable_watcher_set == nullptr) {
1711 variable_watcher_set.reset(
1712 new tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>);
1713 }
1714 return variable_watcher_set.get();
1715 }
1716
1717 // A linked hash set, where iteration is in insertion order.
1718 //
1719 // Nested accumulators rely on op recording happening in insertion order, so an
1720 // unordered data structure like CompactPointerSet is not suitable. Outer
1721 // accumulators need to observe operations first so they know to watch the inner
1722 // accumulator's jvp computation.
1723 //
1724 // Not thread safe.
1725 class AccumulatorSet {
1726 public:
1727 // Returns true if `element` was newly inserted, false if it already exists.
1728   bool insert(TFE_Py_ForwardAccumulator* element) {
1729 if (map_.find(element) != map_.end()) {
1730 return false;
1731 }
1732 ListType::iterator it = ordered_.insert(ordered_.end(), element);
1733 map_.insert(std::make_pair(element, it));
1734 return true;
1735 }
1736
1737   void erase(TFE_Py_ForwardAccumulator* element) {
1738 MapType::iterator existing = map_.find(element);
1739 if (existing == map_.end()) {
1740 return;
1741 }
1742 ListType::iterator list_position = existing->second;
1743 map_.erase(existing);
1744 ordered_.erase(list_position);
1745 }
1746
1747   bool empty() const { return ordered_.empty(); }
1748
1749   size_t size() const { return ordered_.size(); }
1750
1751 private:
1752 typedef std::list<TFE_Py_ForwardAccumulator*> ListType;
1753 typedef tensorflow::gtl::FlatMap<TFE_Py_ForwardAccumulator*,
1754 ListType::iterator>
1755 MapType;
1756
1757 public:
1758 typedef ListType::const_iterator const_iterator;
1759 typedef ListType::const_reverse_iterator const_reverse_iterator;
1760
1761   const_iterator begin() const { return ordered_.begin(); }
1762   const_iterator end() const { return ordered_.end(); }
1763
1764   const_reverse_iterator rbegin() const { return ordered_.rbegin(); }
1765   const_reverse_iterator rend() const { return ordered_.rend(); }
1766
1767 private:
1768 MapType map_;
1769 ListType ordered_;
1770 };
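// An illustrative sketch, for exposition only, of why insertion order matters.
// With two nested accumulators (outer pushed before inner):
//
//   AccumulatorSet accumulators;
//   accumulators.insert(outer);
//   accumulators.insert(inner);
//   // *accumulators.begin()  == outer: observes ops first, so it can watch the
//   //                                  inner accumulator's jvp computations.
//   // *accumulators.rbegin() == inner: consulted first when packing jvps.
//
// An unordered container such as CompactPointerSet would give no such
// guarantee, hence the list + FlatMap combination above.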
1771
1772 AccumulatorSet* GetAccumulatorSet() {
1773 thread_local std::unique_ptr<AccumulatorSet> accumulator_set{nullptr};
1774 if (accumulator_set == nullptr) {
1775 accumulator_set.reset(new AccumulatorSet);
1776 }
1777 return accumulator_set.get();
1778 }
1779
1780 inline bool HasAccumulator() { return !GetAccumulatorSet()->empty(); }
1781
1782 inline bool HasGradientTape() { return !GetTapeSet()->empty(); }
1783
1784 inline bool HasAccumulatorOrTape() {
1785 return HasGradientTape() || HasAccumulator();
1786 }
1787
1788 // A safe copy of a set, used for tapes and accumulators. The copy is not
1789 // affected by other python threads changing the set of active tapes.
1790 template <typename ContainerType>
1791 class SafeSetCopy {
1792 public:
1793   explicit SafeSetCopy(const ContainerType& to_copy) : set_copy_(to_copy) {
1794 for (auto* member : set_copy_) {
1795 Py_INCREF(member);
1796 }
1797 }
1798
1799   ~SafeSetCopy() {
1800 for (auto* member : set_copy_) {
1801 Py_DECREF(member);
1802 }
1803 }
1804
1805   typename ContainerType::const_iterator begin() const {
1806 return set_copy_.begin();
1807 }
1808
1809   typename ContainerType::const_iterator end() const { return set_copy_.end(); }
1810
1811   bool empty() const { return set_copy_.empty(); }
1812   size_t size() const { return set_copy_.size(); }
1813
1814 protected:
1815 ContainerType set_copy_;
1816 };
1817
1818 class SafeTapeSet
1819 : public SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>> {
1820 public:
1821   SafeTapeSet()
1822 : SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>(
1823 *GetTapeSet()) {}
1824 };
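// A brief usage sketch, for exposition only. Iterating the live set returned by
// GetTapeSet() is unsafe when the loop body calls back into Python, since that
// Python code may add or remove tapes and invalidate the iterator. The safe
// copies snapshot the set and hold a reference to each member for the duration
// of the loop:
//
//   for (TFE_Py_Tape* tape : SafeTapeSet()) {
//     tape->tape->VariableAccessed(variable);  // may run arbitrary Python code
//   }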
1825
1826 class SafeAccumulatorSet : public SafeSetCopy<AccumulatorSet> {
1827 public:
1828   SafeAccumulatorSet() : SafeSetCopy<AccumulatorSet>(*GetAccumulatorSet()) {}
1829
1830   typename AccumulatorSet::const_reverse_iterator rbegin() const {
1831 return set_copy_.rbegin();
1832 }
1833
1834   typename AccumulatorSet::const_reverse_iterator rend() const {
1835 return set_copy_.rend();
1836 }
1837 };
1838
1839 class SafeVariableWatcherSet
1840 : public SafeSetCopy<
1841 tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> {
1842 public:
1843   SafeVariableWatcherSet()
1844 : SafeSetCopy<
1845 tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>(
1846 *GetVariableWatcherSet()) {}
1847 };
1848
1849 bool* ThreadTapeIsStopped() {
1850 thread_local bool thread_tape_is_stopped{false};
1851 return &thread_tape_is_stopped;
1852 }
1853
1854 void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }
1855
1856 void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }
1857
1858 PyObject* TFE_Py_TapeSetIsStopped() {
1859 if (*ThreadTapeIsStopped()) {
1860 Py_RETURN_TRUE;
1861 }
1862 Py_RETURN_FALSE;
1863 }
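// An illustrative sketch, for exposition only. The flag is thread-local and is
// a plain bool rather than a counter, so callers are expected to restore the
// previous value instead of unconditionally restarting (the Python-side
// stop_recording helper follows this pattern):
//
//   bool was_stopped = *ThreadTapeIsStopped();
//   TFE_Py_TapeSetStopOnThread();
//   // ... run ops that no tape or accumulator should record ...
//   if (!was_stopped) TFE_Py_TapeSetRestartOnThread();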
1864
1865 PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
1866 PyObject* watch_accessed_variables) {
1867 TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
1868 if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
1869 TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
1870 tape->tape = new GradientTape(persistent == Py_True,
1871 watch_accessed_variables == Py_True);
1872 Py_INCREF(tape);
1873 tape->nesting_id = tape_nesting_id_counter.fetch_add(1);
1874 GetTapeSet()->insert(tape);
1875 return reinterpret_cast<PyObject*>(tape);
1876 }
1877
1878 void TFE_Py_TapeSetAdd(PyObject* tape) {
1879 Py_INCREF(tape);
1880 TFE_Py_Tape* tfe_tape = reinterpret_cast<TFE_Py_Tape*>(tape);
1881 if (!GetTapeSet()->insert(tfe_tape).second) {
1882 // Already exists in the tape set.
1883 Py_DECREF(tape);
1884 } else {
1885 tfe_tape->nesting_id = tape_nesting_id_counter.fetch_add(1);
1886 }
1887 }
1888
1889 PyObject* TFE_Py_TapeSetIsEmpty() {
1890 if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
1891 Py_RETURN_TRUE;
1892 }
1893 Py_RETURN_FALSE;
1894 }
1895
1896 void TFE_Py_TapeSetRemove(PyObject* tape) {
1897 auto* stack = GetTapeSet();
1898 stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
1899 // We kept a reference to the tape in the set to ensure it wouldn't get
1900 // deleted under us; cleaning it up here.
1901 Py_DECREF(tape);
1902 }
1903
1904 static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
1905 if (list == Py_None) {
1906 return {};
1907 }
1908 PyObject* seq = PySequence_Fast(list, "expected a sequence");
1909 if (seq == nullptr) {
1910 return {};
1911 }
1912 int len = PySequence_Size(list);
1913 PyObject** seq_array = PySequence_Fast_ITEMS(seq);
1914 std::vector<tensorflow::int64> tensor_ids;
1915 tensor_ids.reserve(len);
1916 for (int i = 0; i < len; ++i) {
1917 PyObject* item = seq_array[i];
1918 #if PY_MAJOR_VERSION >= 3
1919 if (PyLong_Check(item)) {
1920 #else
1921 if (PyLong_Check(item) || PyInt_Check(item)) {
1922 #endif
1923 tensorflow::int64 id = MakeInt(item);
1924 tensor_ids.push_back(id);
1925 } else {
1926 tensor_ids.push_back(-1);
1927 }
1928 }
1929 Py_DECREF(seq);
1930 return tensor_ids;
1931 }
1932
1933 // Fill `tensor_ids` and `dtypes` from `tensors`, none of which may be
1934 // null. Returns true on success and false on a Python exception.
1935 bool TensorShapesAndDtypes(PyObject* tensors,
1936 std::vector<tensorflow::int64>* tensor_ids,
1937 std::vector<tensorflow::DataType>* dtypes) {
1938 tensorflow::Safe_PyObjectPtr seq(
1939 PySequence_Fast(tensors, "expected a sequence"));
1940 if (seq == nullptr) {
1941 return false;
1942 }
1943 int len = PySequence_Fast_GET_SIZE(seq.get());
1944 PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
1945 tensor_ids->reserve(len);
1946 dtypes->reserve(len);
1947 for (int i = 0; i < len; ++i) {
1948 PyObject* item = seq_array[i];
1949 tensor_ids->push_back(FastTensorId(item));
1950 dtypes->push_back(tensorflow::PyTensor_DataType(item));
1951 }
1952 return true;
1953 }
1954
1955 bool TapeCouldPossiblyRecord(PyObject* tensors) {
1956 if (tensors == Py_None) {
1957 return false;
1958 }
1959 if (*ThreadTapeIsStopped()) {
1960 return false;
1961 }
1962 if (!HasAccumulatorOrTape()) {
1963 return false;
1964 }
1965 return true;
1966 }
1967
1968 bool CouldBackprop() { return !*ThreadTapeIsStopped() && HasGradientTape(); }
1969
1970 bool CouldForwardprop() { return !*ThreadTapeIsStopped() && HasAccumulator(); }
1971
1972 PyObject* TFE_Py_TapeSetShouldRecordBackprop(PyObject* tensors) {
1973 if (!TapeCouldPossiblyRecord(tensors) || !CouldBackprop()) {
1974 Py_RETURN_FALSE;
1975 }
1976 // TODO(apassos) consider not building a list and changing the API to check
1977 // each tensor individually.
1978 std::vector<tensorflow::int64> tensor_ids;
1979 std::vector<tensorflow::DataType> dtypes;
1980 if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
1981 return nullptr;
1982 }
1983 auto tape_set = *GetTapeSet();
1984 for (TFE_Py_Tape* tape : tape_set) {
1985 if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
1986 Py_RETURN_TRUE;
1987 }
1988 }
1989
1990 Py_RETURN_FALSE;
1991 }
1992
1993 PyObject* TFE_Py_ForwardAccumulatorPushState() {
1994 auto forward_accumulators = *GetAccumulatorSet();
1995 for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
1996 accumulator->accumulator->PushState();
1997 }
1998 Py_RETURN_NONE;
1999 }
2000
2001 PyObject* TFE_Py_ForwardAccumulatorPopState() {
2002 auto forward_accumulators = *GetAccumulatorSet();
2003 for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
2004 accumulator->accumulator->PopState();
2005 }
2006 Py_RETURN_NONE;
2007 }
2008
2009 PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) {
2010 if (!TapeCouldPossiblyRecord(tensors)) {
2011 return GetPythonObjectFromInt(0);
2012 }
2013 std::vector<tensorflow::int64> tensor_ids;
2014 std::vector<tensorflow::DataType> dtypes;
2015 if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
2016 return nullptr;
2017 }
2018
2019 // If there is a persistent tape watching, or if there are multiple tapes
2020 // watching, we'll return immediately indicating that higher-order tape
2021 // gradients are possible.
2022 bool some_tape_watching = false;
2023 if (CouldBackprop()) {
2024 auto tape_set = *GetTapeSet();
2025 for (TFE_Py_Tape* tape : tape_set) {
2026 if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
2027 if (tape->tape->IsPersistent() || some_tape_watching) {
2028 // Either this is the second tape watching, or this tape is
2029 // persistent: higher-order gradients are possible.
2030 return GetPythonObjectFromInt(2);
2031 }
2032 some_tape_watching = true;
2033 }
2034 }
2035 }
2036 if (CouldForwardprop()) {
2037 auto forward_accumulators = *GetAccumulatorSet();
2038 for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
2039 if (accumulator->accumulator->ShouldRecord(tensor_ids, dtypes)) {
2040 if (some_tape_watching) {
2041 // This is the second tape watching: higher-order gradients are
2042 // possible. Note that there's no equivalent of persistence for
2043 // forward-mode.
2044 return GetPythonObjectFromInt(2);
2045 }
2046 some_tape_watching = true;
2047 }
2048 }
2049 }
2050 if (some_tape_watching) {
2051 // There's exactly one non-persistent tape. The user can request first-order
2052 // gradients but won't be able to get higher-order tape gradients.
2053 return GetPythonObjectFromInt(1);
2054 } else {
2055 // There are no tapes. The user can't request tape gradients.
2056 return GetPythonObjectFromInt(0);
2057 }
2058 }
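// For exposition: the integer returned above encodes what kind of tape
// gradients are possible for `tensors`:
//
//   0 - nothing is watching; no tape gradients are possible.
//   1 - exactly one non-persistent tape is watching; first-order only.
//   2 - a persistent tape, several tapes, or a tape plus a forward accumulator
//       is watching; higher-order gradients are possible.
//
// Python callers use this value to decide how much gradient machinery to set
// up for the executed operation.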
2059
2060 void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
2061 if (!CouldBackprop()) {
2062 return;
2063 }
2064 tensorflow::int64 tensor_id = FastTensorId(tensor);
2065 if (PyErr_Occurred()) {
2066 return;
2067 }
2068 reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
2069 }
2070
2071 bool ListContainsNone(PyObject* list) {
2072 if (list == Py_None) return true;
2073 tensorflow::Safe_PyObjectPtr seq(
2074 PySequence_Fast(list, "expected a sequence"));
2075 if (seq == nullptr) {
2076 return false;
2077 }
2078
2079 int len = PySequence_Size(list);
2080 PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
2081 for (int i = 0; i < len; ++i) {
2082 PyObject* item = seq_array[i];
2083 if (item == Py_None) return true;
2084 }
2085
2086 return false;
2087 }
2088
2089 static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
2090 if (EagerTensor_CheckExact(tensor)) {
2091 tensorflow::ImmediateExecutionTensorHandle* handle =
2092 tensorflow::unwrap(EagerTensor_Handle(tensor));
2093 tensorflow::int64 id = PyEagerTensor_ID(tensor);
2094 tensorflow::DataType dtype =
2095 static_cast<tensorflow::DataType>(handle->DataType());
2096 if (dtype == tensorflow::DT_VARIANT) {
2097 return PyTapeTensor(id, dtype, tensor);
2098 }
2099
2100 tensorflow::TensorShape tensor_shape;
2101 int num_dims;
2102 tensorflow::Status status = handle->NumDims(&num_dims);
2103 if (status.ok()) {
2104 for (int i = 0; i < num_dims; ++i) {
2105 tensorflow::int64 dim_size;
2106 status = handle->Dim(i, &dim_size);
2107 if (!status.ok()) break;
2108 tensor_shape.AddDim(dim_size);
2109 }
2110 }
2111
2112 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
2113 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2114 tensorflow::TensorShape({}));
2115 } else {
2116 return PyTapeTensor(id, dtype, tensor_shape);
2117 }
2118 }
2119 tensorflow::int64 id = FastTensorId(tensor);
2120 if (PyErr_Occurred()) {
2121 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2122 tensorflow::TensorShape({}));
2123 }
2124 PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
2125 PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
2126 Py_DECREF(dtype_object);
2127 tensorflow::DataType dtype =
2128 static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
2129 Py_DECREF(dtype_enum);
2130 if (PyErr_Occurred()) {
2131 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2132 tensorflow::TensorShape({}));
2133 }
2134 static char _shape_tuple[] = "_shape_tuple";
2135 tensorflow::Safe_PyObjectPtr shape_tuple(
2136 PyObject_CallMethod(tensor, _shape_tuple, nullptr));
2137 if (PyErr_Occurred()) {
2138 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
2139 tensorflow::TensorShape({}));
2140 }
2141
2142 if (ListContainsNone(shape_tuple.get()) || dtype == tensorflow::DT_VARIANT) {
2143 return PyTapeTensor(id, dtype, tensor);
2144 }
2145
2146 auto l = MakeIntList(shape_tuple.get());
2147   // Replace -1 (which represents accidental Nones that can occur in graph mode
2148   // and can cause errors in shape construction) with 0.
2149 for (auto& c : l) {
2150 if (c < 0) {
2151 c = 0;
2152 }
2153 }
2154 tensorflow::TensorShape shape(l);
2155 return PyTapeTensor(id, dtype, shape);
2156 }
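// For exposition, a summary of the representations produced above for a
// non-eager tensor (values are illustrative):
//
//   _shape_tuple() == (2, 3)     -> PyTapeTensor(id, dtype, TensorShape([2, 3]))
//   _shape_tuple() contains None -> PyTapeTensor(id, dtype, tensor), with the
//                                   shape resolved later via GraphShape()
//   dtype == DT_VARIANT          -> PyTapeTensor(id, dtype, tensor)
//
// The -1 -> 0 clamping only applies on the first path, where a negative entry
// from MakeIntList would otherwise make TensorShape construction fail.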
2157
2158 // Populates output_info from output_seq, which must come from PySequence_Fast.
2159 //
2160 // Does not take ownership of output_seq. Returns true on success and false if a
2161 // Python exception has been set.
2162 bool TapeTensorsFromTensorSequence(PyObject* output_seq,
2163 std::vector<PyTapeTensor>* output_info) {
2164 Py_ssize_t output_len = PySequence_Fast_GET_SIZE(output_seq);
2165 PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq);
2166 output_info->reserve(output_len);
2167 for (Py_ssize_t i = 0; i < output_len; ++i) {
2168 output_info->push_back(TapeTensorFromTensor(output_seq_array[i]));
2169 if (PyErr_Occurred() != nullptr) {
2170 return false;
2171 }
2172 }
2173 return true;
2174 }
2175
2176 std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
2177 PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
2178 if (seq == nullptr) {
2179 return {};
2180 }
2181 int len = PySequence_Fast_GET_SIZE(seq);
2182 PyObject** seq_array = PySequence_Fast_ITEMS(seq);
2183 std::vector<tensorflow::int64> list;
2184 list.reserve(len);
2185 for (int i = 0; i < len; ++i) {
2186 PyObject* tensor = seq_array[i];
2187 list.push_back(FastTensorId(tensor));
2188 if (PyErr_Occurred()) {
2189 Py_DECREF(seq);
2190 return list;
2191 }
2192 }
2193 Py_DECREF(seq);
2194 return list;
2195 }
2196
2197 void TFE_Py_TapeVariableAccessed(PyObject* variable) {
2198 if (!CouldBackprop()) {
2199 return;
2200 }
2201 for (TFE_Py_Tape* tape : SafeTapeSet()) {
2202 tape->tape->VariableAccessed(variable);
2203 }
2204 }
2205
2206 void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) {
2207 if (!CouldBackprop()) {
2208 return;
2209 }
2210 reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable);
2211 }
2212
2213 PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
2214 return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
2215 }
2216
2217 PyObject* TFE_Py_VariableWatcherNew() {
2218 TFE_Py_VariableWatcher_Type.tp_new = PyType_GenericNew;
2219 if (PyType_Ready(&TFE_Py_VariableWatcher_Type) < 0) return nullptr;
2220 TFE_Py_VariableWatcher* variable_watcher =
2221 PyObject_NEW(TFE_Py_VariableWatcher, &TFE_Py_VariableWatcher_Type);
2222 variable_watcher->variable_watcher = new VariableWatcher();
2223 Py_INCREF(variable_watcher);
2224 GetVariableWatcherSet()->insert(variable_watcher);
2225 return reinterpret_cast<PyObject*>(variable_watcher);
2226 }
2227
2228 void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher) {
2229 auto* stack = GetVariableWatcherSet();
2230 stack->erase(reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher));
2231 // We kept a reference to the variable watcher in the set to ensure it
2232 // wouldn't get deleted under us; cleaning it up here.
2233 Py_DECREF(variable_watcher);
2234 }
2235
2236 void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable) {
2237 for (TFE_Py_VariableWatcher* variable_watcher : SafeVariableWatcherSet()) {
2238 variable_watcher->variable_watcher->WatchVariable(variable);
2239 }
2240 }
2241
2242 PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher) {
2243 return reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
2244 ->variable_watcher->GetVariablesAsPyTuple();
2245 }
2246
2247 namespace {
2248 std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
2249 PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
2250 if (seq == nullptr) {
2251 return {};
2252 }
2253 int len = PySequence_Fast_GET_SIZE(seq);
2254 PyObject** seq_array = PySequence_Fast_ITEMS(seq);
2255 std::vector<tensorflow::DataType> list;
2256 list.reserve(len);
2257 for (int i = 0; i < len; ++i) {
2258 PyObject* tensor = seq_array[i];
2259 list.push_back(tensorflow::PyTensor_DataType(tensor));
2260 }
2261 Py_DECREF(seq);
2262 return list;
2263 }
2264
2265 PyObject* ForwardAccumulatorDeleteGradient(PyObject* tensor_id,
2266 PyObject* weak_tensor_ref) {
2267 tensorflow::int64 parsed_tensor_id = MakeInt(tensor_id);
2268 for (TFE_Py_ForwardAccumulator* accumulator : *GetAccumulatorSet()) {
2269 accumulator->accumulator->DeleteGradient(parsed_tensor_id);
2270 }
2271 Py_DECREF(weak_tensor_ref);
2272 Py_DECREF(tensor_id);
2273 Py_INCREF(Py_None);
2274 return Py_None;
2275 }
2276
2277 static PyMethodDef forward_accumulator_delete_gradient_method_def = {
2278 "ForwardAccumulatorDeleteGradient", ForwardAccumulatorDeleteGradient,
2279 METH_O, "ForwardAccumulatorDeleteGradient"};
2280
2281 void RegisterForwardAccumulatorCleanup(PyObject* tensor,
2282 tensorflow::int64 tensor_id) {
2283 tensorflow::Safe_PyObjectPtr callback(
2284 PyCFunction_New(&forward_accumulator_delete_gradient_method_def,
2285 PyLong_FromLong(tensor_id)));
2286 // We need to keep a reference to the weakref active if we want our callback
2287 // called. The callback itself now owns the weakref object and the tensor ID
2288 // object.
2289 PyWeakref_NewRef(tensor, callback.get());
2290 }
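// For exposition: the weakref created above is deliberately kept alive (its new
// reference is never released), which is what keeps the callback registered.
// When `tensor` is eventually collected, ForwardAccumulatorDeleteGradient runs,
// drops the stored JVP for that tensor id in every active accumulator, and then
// releases both the weakref and the tensor-id object the callback owned.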
2291
2292 void TapeSetRecordBackprop(
2293 const string& op_type, const std::vector<PyTapeTensor>& output_info,
2294 const std::vector<tensorflow::int64>& input_ids,
2295 const std::vector<tensorflow::DataType>& input_dtypes,
2296 const std::function<PyBackwardFunction*()>& backward_function_getter,
2297 const std::function<void(PyBackwardFunction*)>& backward_function_killer,
2298 tensorflow::uint64 max_gradient_tape_id) {
2299 if (!CouldBackprop()) {
2300 return;
2301 }
2302 for (TFE_Py_Tape* tape : SafeTapeSet()) {
2303 if (tape->nesting_id < max_gradient_tape_id) {
2304 tape->tape->RecordOperation(op_type, output_info, input_ids, input_dtypes,
2305 backward_function_getter,
2306 backward_function_killer);
2307 }
2308 }
2309 }
2310
2311 bool TapeSetRecordForwardprop(
2312 const string& op_type, PyObject* output_seq,
2313 const std::vector<PyTapeTensor>& output_info, PyObject* input_tensors,
2314 const std::vector<tensorflow::int64>& input_ids,
2315 const std::vector<tensorflow::DataType>& input_dtypes,
2316 const std::function<PyBackwardFunction*()>& backward_function_getter,
2317 const std::function<void(PyBackwardFunction*)>& backward_function_killer,
2318 const tensorflow::eager::ForwardFunction<PyObject>* forward_function,
2319 PyObject* forwardprop_output_indices,
2320 tensorflow::uint64* max_gradient_tape_id) {
2321 *max_gradient_tape_id = std::numeric_limits<tensorflow::uint64>::max();
2322 if (!CouldForwardprop()) {
2323 return true;
2324 }
2325 auto accumulator_set = SafeAccumulatorSet();
2326 tensorflow::Safe_PyObjectPtr input_seq(
2327 PySequence_Fast(input_tensors, "expected a sequence of tensors"));
2328 if (input_seq == nullptr || PyErr_Occurred()) return false;
2329 Py_ssize_t input_len = PySequence_Fast_GET_SIZE(input_seq.get());
2330 PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq);
2331 for (int i = 0; i < output_info.size(); ++i) {
2332 RegisterForwardAccumulatorCleanup(output_seq_array[i],
2333 output_info[i].GetID());
2334 }
2335 if (forwardprop_output_indices != nullptr &&
2336 forwardprop_output_indices != Py_None) {
2337 tensorflow::Safe_PyObjectPtr indices_fast(PySequence_Fast(
2338 forwardprop_output_indices, "Expected a sequence of indices"));
2339 if (indices_fast == nullptr || PyErr_Occurred()) {
2340 return false;
2341 }
2342 if (PySequence_Fast_GET_SIZE(indices_fast.get()) !=
2343 accumulator_set.size()) {
2344 MaybeRaiseExceptionFromStatus(
2345 tensorflow::errors::Internal(
2346 "Accumulators were added or removed from the active set "
2347 "between packing and unpacking."),
2348 nullptr);
2349 }
2350 PyObject** indices_fast_array = PySequence_Fast_ITEMS(indices_fast.get());
2351 Py_ssize_t accumulator_index = 0;
2352 for (AccumulatorSet::const_reverse_iterator it = accumulator_set.rbegin();
2353 it != accumulator_set.rend(); ++it, ++accumulator_index) {
2354 tensorflow::Safe_PyObjectPtr jvp_index_seq(
2355 PySequence_Fast(indices_fast_array[accumulator_index],
2356 "Expected a sequence of jvp indices."));
2357 if (jvp_index_seq == nullptr || PyErr_Occurred()) {
2358 return false;
2359 }
2360 Py_ssize_t num_jvps = PySequence_Fast_GET_SIZE(jvp_index_seq.get());
2361 PyObject** jvp_index_seq_array =
2362 PySequence_Fast_ITEMS(jvp_index_seq.get());
2363 for (Py_ssize_t jvp_index = 0; jvp_index < num_jvps; ++jvp_index) {
2364 PyObject* tuple = jvp_index_seq_array[jvp_index];
2365 tensorflow::int64 primal_tensor_id =
2366 output_info[MakeInt(PyTuple_GetItem(tuple, 0))].GetID();
2367 (*it)->accumulator->Watch(
2368 primal_tensor_id,
2369 output_seq_array[MakeInt(PyTuple_GetItem(tuple, 1))]);
2370 }
2371 }
2372 } else {
2373 std::vector<PyTapeTensor> input_info;
2374 input_info.reserve(input_len);
2375 PyObject** input_seq_array = PySequence_Fast_ITEMS(input_seq.get());
2376 for (Py_ssize_t i = 0; i < input_len; ++i) {
2377 input_info.push_back(TapeTensorFromTensor(input_seq_array[i]));
2378 }
2379 for (TFE_Py_ForwardAccumulator* accumulator : accumulator_set) {
2380 tensorflow::Status status = accumulator->accumulator->Accumulate(
2381 op_type, input_info, output_info, input_ids, input_dtypes,
2382 forward_function, backward_function_getter, backward_function_killer);
2383 if (PyErr_Occurred()) return false; // Don't swallow Python exceptions.
2384 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
2385 return false;
2386 }
2387 if (accumulator->accumulator->BusyAccumulating()) {
2388 // Ensure inner accumulators don't see outer accumulators' jvps. This
2389 // mostly happens on its own, with some potentially surprising
2390 // exceptions, so the blanket policy is for consistency.
2391 *max_gradient_tape_id = accumulator->nesting_id;
2392 break;
2393 }
2394 }
2395 }
2396 return true;
2397 }
2398
2399 PyObject* TangentsAsPyTuple(const std::vector<PyObject*>& input_tangents) {
2400 PyObject* py_input_tangents = PyTuple_New(input_tangents.size());
2401 for (int i = 0; i < input_tangents.size(); ++i) {
2402 PyObject* element;
2403 if (input_tangents[i] == nullptr) {
2404 element = Py_None;
2405 } else {
2406 element = input_tangents[i];
2407 }
2408 Py_INCREF(element);
2409 PyTuple_SET_ITEM(py_input_tangents, i, element);
2410 }
2411 return py_input_tangents;
2412 }
2413
2414 tensorflow::Status ParseTangentOutputs(
2415 PyObject* user_output, std::vector<PyObject*>* output_tangents) {
2416 if (user_output == Py_None) {
2417 // No connected gradients.
2418 return tensorflow::Status::OK();
2419 }
2420 tensorflow::Safe_PyObjectPtr fast_result(
2421 PySequence_Fast(user_output, "expected a sequence of forward gradients"));
2422 if (fast_result == nullptr) {
2423 return tensorflow::errors::InvalidArgument(
2424 "forward gradient function did not return a sequence.");
2425 }
2426 int len = PySequence_Fast_GET_SIZE(fast_result.get());
2427 PyObject** fast_result_array = PySequence_Fast_ITEMS(fast_result.get());
2428 output_tangents->reserve(len);
2429 for (int i = 0; i < len; ++i) {
2430 PyObject* item = fast_result_array[i];
2431 if (item == Py_None) {
2432 output_tangents->push_back(nullptr);
2433 } else {
2434 Py_INCREF(item);
2435 output_tangents->push_back(item);
2436 }
2437 }
2438 return tensorflow::Status::OK();
2439 }
2440
2441 // Calls the registered forward_gradient_function, computing `output_tangents`
2442 // from `input_tangents`. `output_tangents` must not be null.
2443 //
2444 // `op_name`, `attrs`, `inputs`, and `results` describe the operation for which
2445 // the forward function is being called.
2446 tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs,
2447 PyObject* inputs, PyObject* results,
2448 const std::vector<PyObject*>& input_tangents,
2449 std::vector<PyObject*>* output_tangents,
2450 bool use_batch) {
2451 if (forward_gradient_function == nullptr) {
2452 return tensorflow::errors::Internal(
2453 "No forward gradient function registered.");
2454 }
2455 tensorflow::Safe_PyObjectPtr py_input_tangents(
2456 TangentsAsPyTuple(input_tangents));
2457
2458 // Normalize the input sequence to a tuple so it works with function
2459 // caching; otherwise it may be an opaque _InputList object.
2460 tensorflow::Safe_PyObjectPtr input_tuple(PySequence_Tuple(inputs));
2461 PyObject* to_batch = (use_batch) ? Py_True : Py_False;
2462 tensorflow::Safe_PyObjectPtr callback_args(
2463 Py_BuildValue("OOOOOO", op_name, attrs, input_tuple.get(), results,
2464 py_input_tangents.get(), to_batch));
2465 tensorflow::Safe_PyObjectPtr py_result(
2466 PyObject_CallObject(forward_gradient_function, callback_args.get()));
2467 if (py_result == nullptr || PyErr_Occurred()) {
2468 return tensorflow::errors::Internal(
2469         "forward gradient function threw an exception");
2470 }
2471 return ParseTangentOutputs(py_result.get(), output_tangents);
2472 }
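// For exposition: the registered Python callable is invoked with six positional
// arguments,
//
//   forward_gradient_function(op_name, attrs, inputs_tuple, results,
//                             input_tangents_tuple, use_batch)
//
// and is expected to return either None (no connected gradients) or a sequence
// with one entry per output, where a None entry means "no tangent for that
// output". (The parameter names here are descriptive only.)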
2473
2474 // Like CallJVPFunction, but calls a pre-bound forward function.
2475 // These are passed in from a record_gradient argument.
2476 tensorflow::Status CallOpSpecificJVPFunction(
2477 PyObject* op_specific_forward_function,
2478 const std::vector<PyObject*>& input_tangents,
2479 std::vector<PyObject*>* output_tangents) {
2480 tensorflow::Safe_PyObjectPtr py_input_tangents(
2481 TangentsAsPyTuple(input_tangents));
2482
2483 tensorflow::Safe_PyObjectPtr py_result(PyObject_CallObject(
2484 op_specific_forward_function, py_input_tangents.get()));
2485 if (py_result == nullptr || PyErr_Occurred()) {
2486 return tensorflow::errors::Internal(
2487         "forward gradient function threw an exception");
2488 }
2489 return ParseTangentOutputs(py_result.get(), output_tangents);
2490 }
2491
2492 bool ParseOpTypeString(PyObject* op_type, string* op_type_string) {
2493 if (PyBytes_Check(op_type)) {
2494 *op_type_string = PyBytes_AsString(op_type);
2495 } else if (PyUnicode_Check(op_type)) {
2496 #if PY_MAJOR_VERSION >= 3
2497 *op_type_string = PyUnicode_AsUTF8(op_type);
2498 #else
2499 PyObject* py_str = PyUnicode_AsUTF8String(op_type);
2500 if (py_str == nullptr) {
2501 return false;
2502 }
2503 *op_type_string = PyBytes_AS_STRING(py_str);
2504 Py_DECREF(py_str);
2505 #endif
2506 } else {
2507 PyErr_SetString(PyExc_RuntimeError, "op_type should be a string.");
2508 return false;
2509 }
2510 return true;
2511 }
2512
2513 bool TapeSetRecordOperation(
2514 PyObject* op_type, PyObject* input_tensors, PyObject* output_tensors,
2515 const std::vector<tensorflow::int64>& input_ids,
2516 const std::vector<tensorflow::DataType>& input_dtypes,
2517 const std::function<PyBackwardFunction*()>& backward_function_getter,
2518 const std::function<void(PyBackwardFunction*)>& backward_function_killer,
2519 const tensorflow::eager::ForwardFunction<PyObject>* forward_function) {
2520 std::vector<PyTapeTensor> output_info;
2521 tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2522 output_tensors, "expected a sequence of integer tensor ids"));
2523 if (PyErr_Occurred() ||
2524 !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2525 return false;
2526 }
2527 string op_type_str;
2528 if (!ParseOpTypeString(op_type, &op_type_str)) {
2529 return false;
2530 }
2531 tensorflow::uint64 max_gradient_tape_id;
2532 if (!TapeSetRecordForwardprop(
2533 op_type_str, output_seq.get(), output_info, input_tensors, input_ids,
2534 input_dtypes, backward_function_getter, backward_function_killer,
2535 forward_function, nullptr /* No special-cased jvps. */,
2536 &max_gradient_tape_id)) {
2537 return false;
2538 }
2539 TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes,
2540 backward_function_getter, backward_function_killer,
2541 max_gradient_tape_id);
2542 return true;
2543 }
2544 } // namespace
2545
2546 PyObject* TFE_Py_TapeSetRecordOperation(PyObject* op_type,
2547 PyObject* output_tensors,
2548 PyObject* input_tensors,
2549 PyObject* backward_function,
2550 PyObject* forward_function) {
2551 if (!HasAccumulatorOrTape() || *ThreadTapeIsStopped()) {
2552 Py_RETURN_NONE;
2553 }
2554 std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
2555 if (PyErr_Occurred()) return nullptr;
2556
2557 std::vector<tensorflow::DataType> input_dtypes =
2558 MakeTensorDtypeList(input_tensors);
2559 if (PyErr_Occurred()) return nullptr;
2560
2561 std::function<PyBackwardFunction*()> backward_function_getter(
2562 [backward_function]() {
2563 Py_INCREF(backward_function);
2564 PyBackwardFunction* function = new PyBackwardFunction(
2565 [backward_function](PyObject* out_grads,
2566 const std::vector<tensorflow::int64>& unused) {
2567 return PyObject_CallObject(backward_function, out_grads);
2568 });
2569 return function;
2570 });
2571 std::function<void(PyBackwardFunction*)> backward_function_killer(
2572 [backward_function](PyBackwardFunction* py_backward_function) {
2573 Py_DECREF(backward_function);
2574 delete py_backward_function;
2575 });
2576
2577 if (forward_function == Py_None) {
2578 if (!TapeSetRecordOperation(
2579 op_type, input_tensors, output_tensors, input_ids, input_dtypes,
2580 backward_function_getter, backward_function_killer,
2581 nullptr /* No special-cased forward function */)) {
2582 return nullptr;
2583 }
2584 } else {
2585 tensorflow::eager::ForwardFunction<PyObject> wrapped_forward_function(
2586 [forward_function](const std::vector<PyObject*>& input_tangents,
2587 std::vector<PyObject*>* output_tangents,
2588 bool use_batch = false) {
2589 return CallOpSpecificJVPFunction(forward_function, input_tangents,
2590 output_tangents);
2591 });
2592 if (!TapeSetRecordOperation(
2593 op_type, input_tensors, output_tensors, input_ids, input_dtypes,
2594 backward_function_getter, backward_function_killer,
2595 &wrapped_forward_function)) {
2596 return nullptr;
2597 }
2598 }
2599 Py_RETURN_NONE;
2600 }
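// For exposition, the ownership contract set up above: every tape or
// accumulator that actually records the op calls backward_function_getter(),
// which INCREFs the Python callable and wraps it in a PyBackwardFunction; when
// the corresponding entry is erased (or the tape is deleted), the matching
// backward_function_killer() DECREFs the callable and deletes the wrapper. The
// callable therefore stays alive exactly as long as some recorded entry might
// still need it.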
2601
2602 PyObject* TFE_Py_TapeSetRecordOperationForwardprop(
2603 PyObject* op_type, PyObject* output_tensors, PyObject* input_tensors,
2604 PyObject* backward_function, PyObject* forwardprop_output_indices) {
2605 if (!HasAccumulator() || *ThreadTapeIsStopped()) {
2606 Py_RETURN_NONE;
2607 }
2608 std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
2609 if (PyErr_Occurred()) return nullptr;
2610
2611 std::vector<tensorflow::DataType> input_dtypes =
2612 MakeTensorDtypeList(input_tensors);
2613 if (PyErr_Occurred()) return nullptr;
2614
2615 std::function<PyBackwardFunction*()> backward_function_getter(
2616 [backward_function]() {
2617 Py_INCREF(backward_function);
2618 PyBackwardFunction* function = new PyBackwardFunction(
2619 [backward_function](PyObject* out_grads,
2620 const std::vector<tensorflow::int64>& unused) {
2621 return PyObject_CallObject(backward_function, out_grads);
2622 });
2623 return function;
2624 });
2625 std::function<void(PyBackwardFunction*)> backward_function_killer(
2626 [backward_function](PyBackwardFunction* py_backward_function) {
2627 Py_DECREF(backward_function);
2628 delete py_backward_function;
2629 });
2630 std::vector<PyTapeTensor> output_info;
2631 tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2632 output_tensors, "expected a sequence of integer tensor ids"));
2633 if (PyErr_Occurred() ||
2634 !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2635 return nullptr;
2636 }
2637 string op_type_str;
2638 if (!ParseOpTypeString(op_type, &op_type_str)) {
2639 return nullptr;
2640 }
2641 tensorflow::uint64 max_gradient_tape_id;
2642 if (!TapeSetRecordForwardprop(
2643 op_type_str, output_seq.get(), output_info, input_tensors, input_ids,
2644 input_dtypes, backward_function_getter, backward_function_killer,
2645 nullptr /* no special-cased forward function */,
2646 forwardprop_output_indices, &max_gradient_tape_id)) {
2647 return nullptr;
2648 }
2649 Py_RETURN_NONE;
2650 }
2651
2652 PyObject* TFE_Py_TapeSetRecordOperationBackprop(PyObject* op_type,
2653 PyObject* output_tensors,
2654 PyObject* input_tensors,
2655 PyObject* backward_function) {
2656 if (!CouldBackprop()) {
2657 Py_RETURN_NONE;
2658 }
2659 std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
2660 if (PyErr_Occurred()) return nullptr;
2661
2662 std::vector<tensorflow::DataType> input_dtypes =
2663 MakeTensorDtypeList(input_tensors);
2664 if (PyErr_Occurred()) return nullptr;
2665
2666 std::function<PyBackwardFunction*()> backward_function_getter(
2667 [backward_function]() {
2668 Py_INCREF(backward_function);
2669 PyBackwardFunction* function = new PyBackwardFunction(
2670 [backward_function](PyObject* out_grads,
2671 const std::vector<tensorflow::int64>& unused) {
2672 return PyObject_CallObject(backward_function, out_grads);
2673 });
2674 return function;
2675 });
2676 std::function<void(PyBackwardFunction*)> backward_function_killer(
2677 [backward_function](PyBackwardFunction* py_backward_function) {
2678 Py_DECREF(backward_function);
2679 delete py_backward_function;
2680 });
2681 std::vector<PyTapeTensor> output_info;
2682 tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
2683 output_tensors, "expected a sequence of integer tensor ids"));
2684 if (PyErr_Occurred() ||
2685 !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
2686 return nullptr;
2687 }
2688 string op_type_str;
2689 if (!ParseOpTypeString(op_type, &op_type_str)) {
2690 return nullptr;
2691 }
2692 TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes,
2693 backward_function_getter, backward_function_killer,
2694 // No filtering based on relative ordering with forward
2695 // accumulators.
2696 std::numeric_limits<tensorflow::uint64>::max());
2697 Py_RETURN_NONE;
2698 }
2699
2700 void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
2701 for (TFE_Py_Tape* tape : *GetTapeSet()) {
2702 tape->tape->DeleteTrace(tensor_id);
2703 }
2704 }
2705
2706 std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
2707 PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
2708 if (seq == nullptr) {
2709 return {};
2710 }
2711 int len = PySequence_Fast_GET_SIZE(seq);
2712 PyObject** seq_array = PySequence_Fast_ITEMS(seq);
2713 std::vector<PyObject*> list(seq_array, seq_array + len);
2714 Py_DECREF(seq);
2715 return list;
2716 }
2717
2718 PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
2719 PyObject* sources, PyObject* output_gradients,
2720 PyObject* sources_raw,
2721 PyObject* unconnected_gradients,
2722 TF_Status* status) {
2723 TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
2724 if (!tape_obj->tape->IsPersistent()) {
2725 auto* tape_set = GetTapeSet();
2726 if (tape_set->find(tape_obj) != tape_set->end()) {
2727 PyErr_SetString(PyExc_RuntimeError,
2728 "gradient() cannot be invoked within the "
2729 "GradientTape context (i.e., while operations are being "
2730 "recorded). Either move the call to gradient() to be "
2731 "outside the 'with tf.GradientTape' block, or "
2732 "use a persistent tape: "
2733                     "'with tf.GradientTape(persistent=True)'");
2734 return nullptr;
2735 }
2736 }
2737
2738 std::vector<tensorflow::int64> target_vec = MakeTensorIDList(target);
2739 if (PyErr_Occurred()) {
2740 return nullptr;
2741 }
2742 std::vector<tensorflow::int64> sources_vec = MakeTensorIDList(sources);
2743 if (PyErr_Occurred()) {
2744 return nullptr;
2745 }
2746 tensorflow::gtl::FlatSet<tensorflow::int64> sources_set(sources_vec.begin(),
2747 sources_vec.end());
2748
2749 tensorflow::Safe_PyObjectPtr seq =
2750 tensorflow::make_safe(PySequence_Fast(target, "expected a sequence"));
2751 int len = PySequence_Fast_GET_SIZE(seq.get());
2752 PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
2753 std::unordered_map<tensorflow::int64, PyTapeTensor>
2754 source_tensors_that_are_targets;
2755 for (int i = 0; i < len; ++i) {
2756 tensorflow::int64 target_id = target_vec[i];
2757 if (sources_set.find(target_id) != sources_set.end()) {
2758 auto tensor = seq_array[i];
2759 source_tensors_that_are_targets.insert(
2760 std::make_pair(target_id, TapeTensorFromTensor(tensor)));
2761 }
2762 if (PyErr_Occurred()) {
2763 return nullptr;
2764 }
2765 }
2766 if (PyErr_Occurred()) {
2767 return nullptr;
2768 }
2769
2770 std::vector<PyObject*> outgrad_vec;
2771 if (output_gradients != Py_None) {
2772 outgrad_vec = MakeTensorList(output_gradients);
2773 if (PyErr_Occurred()) {
2774 return nullptr;
2775 }
2776 for (PyObject* tensor : outgrad_vec) {
2777 // Calling the backward function will eat a reference to the tensors in
2778 // outgrad_vec, so we need to increase their reference count.
2779 Py_INCREF(tensor);
2780 }
2781 }
2782 std::vector<PyObject*> result(sources_vec.size());
2783 status->status = tape_obj->tape->ComputeGradient(
2784 *py_vspace, target_vec, sources_vec, source_tensors_that_are_targets,
2785 outgrad_vec, absl::MakeSpan(result));
2786 if (!status->status.ok()) {
2787 if (PyErr_Occurred()) {
2788 // Do not propagate the erroneous status as that would swallow the
2789 // exception which caused the problem.
2790 status->status = tensorflow::Status::OK();
2791 }
2792 return nullptr;
2793 }
2794
2795 bool unconnected_gradients_zero =
2796 strcmp(TFE_GetPythonString(unconnected_gradients), "zero") == 0;
2797 std::vector<PyObject*> sources_obj;
2798 if (unconnected_gradients_zero) {
2799 // Uses the "raw" sources here so it can properly make a zeros tensor even
2800 // if there are resource variables as sources.
2801 sources_obj = MakeTensorList(sources_raw);
2802 }
2803
2804 if (!result.empty()) {
2805 PyObject* py_result = PyList_New(result.size());
2806 tensorflow::gtl::FlatSet<PyObject*> seen_results(result.size());
2807 for (int i = 0; i < result.size(); ++i) {
2808 if (result[i] == nullptr) {
2809 if (unconnected_gradients_zero) {
2810 // generate a zeros tensor in the shape of sources[i]
2811 tensorflow::DataType dtype =
2812 tensorflow::PyTensor_DataType(sources_obj[i]);
2813 PyTapeTensor tensor =
2814 PyTapeTensor(sources_vec[i], dtype, sources_obj[i]);
2815 result[i] = tensor.ZerosLike();
2816 } else {
2817 Py_INCREF(Py_None);
2818 result[i] = Py_None;
2819 }
2820 } else if (seen_results.find(result[i]) != seen_results.end()) {
2821 Py_INCREF(result[i]);
2822 }
2823 seen_results.insert(result[i]);
2824 PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]));
2825 }
2826 return py_result;
2827 }
2828 return PyList_New(0);
2829 }
2830
2831 PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch) {
2832 TFE_Py_ForwardAccumulator_Type.tp_new = PyType_GenericNew;
2833 if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr;
2834 TFE_Py_ForwardAccumulator* accumulator =
2835 PyObject_NEW(TFE_Py_ForwardAccumulator, &TFE_Py_ForwardAccumulator_Type);
2836 if (py_vspace == nullptr) {
2837 MaybeRaiseExceptionFromStatus(
2838 tensorflow::errors::Internal(
2839 "ForwardAccumulator requires a PyVSpace to be registered."),
2840 nullptr);
2841 }
2842 accumulator->accumulator = new ForwardAccumulator(*py_vspace, use_batch);
2843 return reinterpret_cast<PyObject*>(accumulator);
2844 }
2845
2846 PyObject* TFE_Py_ForwardAccumulatorSetAdd(PyObject* accumulator) {
2847 TFE_Py_ForwardAccumulator* c_accumulator(
2848 reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
2849 c_accumulator->nesting_id = tape_nesting_id_counter.fetch_add(1);
2850 if (GetAccumulatorSet()->insert(c_accumulator)) {
2851 Py_INCREF(accumulator);
2852 Py_RETURN_NONE;
2853 } else {
2854 MaybeRaiseExceptionFromStatus(
2855 tensorflow::errors::Internal(
2856 "A ForwardAccumulator was added to the active set twice."),
2857 nullptr);
2858 return nullptr;
2859 }
2860 }
2861
2862 void TFE_Py_ForwardAccumulatorSetRemove(PyObject* accumulator) {
2863 GetAccumulatorSet()->erase(
2864 reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
2865 Py_DECREF(accumulator);
2866 }
2867
2868 void TFE_Py_ForwardAccumulatorWatch(PyObject* accumulator, PyObject* tensor,
2869 PyObject* tangent) {
2870 tensorflow::int64 tensor_id = FastTensorId(tensor);
2871 reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
2872 ->accumulator->Watch(tensor_id, tangent);
2873 RegisterForwardAccumulatorCleanup(tensor, tensor_id);
2874 }
2875
2876 // Returns a new reference to the JVP Tensor.
2877 PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator,
2878 PyObject* tensor) {
2879 PyObject* jvp = reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
2880 ->accumulator->FetchJVP(FastTensorId(tensor));
2881 if (jvp == nullptr) {
2882 jvp = Py_None;
2883 }
2884 Py_INCREF(jvp);
2885 return jvp;
2886 }
2887
2888 PyObject* TFE_Py_PackJVPs(PyObject* tensors) {
2889 if (!TapeCouldPossiblyRecord(tensors)) {
2890 tensorflow::Safe_PyObjectPtr empty_tuple(PyTuple_New(0));
2891 tensorflow::Safe_PyObjectPtr empty_list(PyList_New(0));
2892 return PyTuple_Pack(2, empty_tuple.get(), empty_list.get());
2893 }
2894 auto accumulators = *GetAccumulatorSet();
2895 tensorflow::Safe_PyObjectPtr tensors_fast(
2896 PySequence_Fast(tensors, "Expected a sequence of input Tensors."));
2897 if (tensors_fast == nullptr || PyErr_Occurred()) {
2898 return nullptr;
2899 }
2900 std::vector<tensorflow::int64> augmented_input_ids;
2901 int len = PySequence_Fast_GET_SIZE(tensors_fast.get());
2902 PyObject** tensors_fast_array = PySequence_Fast_ITEMS(tensors_fast.get());
2903 for (Py_ssize_t position = 0; position < len; ++position) {
2904 PyObject* input = tensors_fast_array[position];
2905 if (input == Py_None) {
2906 continue;
2907 }
2908 tensorflow::DataType input_dtype(tensorflow::PyTensor_DataType(input));
2909 if (input_dtype == tensorflow::DT_INVALID) {
2910 return nullptr;
2911 }
2912 augmented_input_ids.push_back(FastTensorId(input));
2913 }
2914 if (PyErr_Occurred()) {
2915 return nullptr;
2916 }
2917 // Find the innermost accumulator such that all outer accumulators are
2918 // recording. Any more deeply nested accumulators will not have their JVPs
2919 // saved.
2920 AccumulatorSet::const_iterator innermost_all_recording = accumulators.begin();
2921 for (; innermost_all_recording != accumulators.end();
2922 ++innermost_all_recording) {
2923 if ((*innermost_all_recording)->accumulator->BusyAccumulating()) {
2924 break;
2925 }
2926 }
2927 AccumulatorSet::const_reverse_iterator reverse_innermost_all_recording(
2928 innermost_all_recording);
2929
2930 bool saving_jvps = false;
2931 tensorflow::Safe_PyObjectPtr all_indices(PyTuple_New(accumulators.size()));
2932 std::vector<PyObject*> new_tensors;
2933 Py_ssize_t accumulator_index = 0;
2934 // Start with the innermost accumulators to give outer accumulators a chance
2935 // to find their higher-order JVPs.
2936 for (AccumulatorSet::const_reverse_iterator it = accumulators.rbegin();
2937 it != accumulators.rend(); ++it, ++accumulator_index) {
2938 std::vector<tensorflow::int64> new_input_ids;
2939 std::vector<std::pair<tensorflow::int64, tensorflow::int64>>
2940 accumulator_indices;
2941 if (it == reverse_innermost_all_recording) {
2942 saving_jvps = true;
2943 }
2944 if (saving_jvps) {
2945 for (int input_index = 0; input_index < augmented_input_ids.size();
2946 ++input_index) {
2947 tensorflow::int64 existing_input = augmented_input_ids[input_index];
2948 PyObject* jvp = (*it)->accumulator->FetchJVP(existing_input);
2949 if (jvp != nullptr) {
2950 new_tensors.push_back(jvp);
2951 new_input_ids.push_back(FastTensorId(jvp));
2952 accumulator_indices.emplace_back(
2953 input_index,
2954 augmented_input_ids.size() + new_input_ids.size() - 1);
2955 }
2956 }
2957 }
2958 tensorflow::Safe_PyObjectPtr accumulator_indices_py(
2959 PyTuple_New(accumulator_indices.size()));
2960 for (int i = 0; i < accumulator_indices.size(); ++i) {
2961 tensorflow::Safe_PyObjectPtr from_index(
2962 GetPythonObjectFromInt(accumulator_indices[i].first));
2963 tensorflow::Safe_PyObjectPtr to_index(
2964 GetPythonObjectFromInt(accumulator_indices[i].second));
2965 PyTuple_SetItem(accumulator_indices_py.get(), i,
2966 PyTuple_Pack(2, from_index.get(), to_index.get()));
2967 }
2968 PyTuple_SetItem(all_indices.get(), accumulator_index,
2969 accumulator_indices_py.release());
2970 augmented_input_ids.insert(augmented_input_ids.end(), new_input_ids.begin(),
2971 new_input_ids.end());
2972 }
2973
2974 tensorflow::Safe_PyObjectPtr new_tensors_py(PyList_New(new_tensors.size()));
2975 for (int i = 0; i < new_tensors.size(); ++i) {
2976 PyObject* jvp = new_tensors[i];
2977 Py_INCREF(jvp);
2978 PyList_SET_ITEM(new_tensors_py.get(), i, jvp);
2979 }
2980 return PyTuple_Pack(2, all_indices.get(), new_tensors_py.get());
2981 }
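// For exposition: the pair returned above is (indices, new_tensors), where
// new_tensors lists the fetched JVPs in discovery order (innermost accumulator
// first) and indices holds one tuple per accumulator, each a sequence of
// (input_position, jvp_position) pairs into the augmented input list.
// TFE_Py_TapeSetRecordOperationForwardprop later walks the accumulators in the
// same reverse order, so the two stay aligned as long as the active set is not
// modified in between (the Internal error in TapeSetRecordForwardprop guards
// against exactly that).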
2982
2983 namespace {
2984
2985 // Indices for the "args" tuple that's passed to TFE_Py_FastPathExecute_C.
2986 enum FastPathExecuteArgIndex {
2987 FAST_PATH_EXECUTE_ARG_CONTEXT = 0,
2988 FAST_PATH_EXECUTE_ARG_OP_NAME = 1,
2989 FAST_PATH_EXECUTE_ARG_NAME = 2,
2990 FAST_PATH_EXECUTE_ARG_INPUT_START = 3
2991 };
2992
2993 PyObject* GetPythonObjectFromString(tensorflow::StringPiece s) {
2994 #if PY_MAJOR_VERSION >= 3
2995 return PyUnicode_FromStringAndSize(s.data(), s.size());
2996 #else
2997 return PyBytes_FromStringAndSize(s.data(), s.size());
2998 #endif
2999 }
3000
3001 bool CheckResourceVariable(PyObject* item) {
3002 if (tensorflow::swig::IsResourceVariable(item)) {
3003 tensorflow::Safe_PyObjectPtr handle(
3004 PyObject_GetAttrString(item, "_handle"));
3005 return EagerTensor_CheckExact(handle.get());
3006 }
3007
3008 return false;
3009 }
3010
3011 bool IsNumberType(PyObject* item) {
3012 #if PY_MAJOR_VERSION >= 3
3013 return PyFloat_Check(item) || PyLong_Check(item);
3014 #else
3015 return PyFloat_Check(item) || PyInt_Check(item) || PyLong_Check(item);
3016 #endif
3017 }
3018
3019 bool CheckOneInput(PyObject* item) {
3020 if (EagerTensor_CheckExact(item) || CheckResourceVariable(item) ||
3021 PyArray_Check(item) || IsNumberType(item)) {
3022 return true;
3023 }
3024
3025 // Sequences are not properly handled. Sequences with purely python numeric
3026 // types work, but sequences with mixes of EagerTensors and python numeric
3027 // types don't work.
3028 // TODO(nareshmodi): fix
3029 return false;
3030 }
3031
3032 bool CheckInputsOk(PyObject* seq, int start_index,
3033 const tensorflow::OpDef& op_def) {
3034 for (int i = 0; i < op_def.input_arg_size(); i++) {
3035 PyObject* item = PyTuple_GET_ITEM(seq, i + start_index);
3036 if (!op_def.input_arg(i).number_attr().empty() ||
3037 !op_def.input_arg(i).type_list_attr().empty()) {
3038 // This item should be a seq input.
3039 if (!PySequence_Check(item)) {
3040 VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
3041 << "\", Input \"" << op_def.input_arg(i).name()
3042 << "\" since we expected a sequence, but got "
3043 << item->ob_type->tp_name;
3044 return false;
3045 }
3046 tensorflow::Safe_PyObjectPtr fast_item(
3047 PySequence_Fast(item, "Could not parse sequence."));
3048 if (fast_item.get() == nullptr) {
3049 return false;
3050 }
3051 int len = PySequence_Fast_GET_SIZE(fast_item.get());
3052 PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get());
3053 for (Py_ssize_t j = 0; j < len; j++) {
3054 PyObject* inner_item = fast_item_array[j];
3055 if (!CheckOneInput(inner_item)) {
3056 VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
3057 << "\", Input \"" << op_def.input_arg(i).name()
3058 << "\", Index " << j
3059 << " since we expected an EagerTensor/ResourceVariable, "
3060 "but got "
3061 << inner_item->ob_type->tp_name;
3062 return false;
3063 }
3064 }
3065 } else if (!CheckOneInput(item)) {
3066 VLOG(1)
3067 << "Falling back to slow path for Op \"" << op_def.name()
3068 << "\", Input \"" << op_def.input_arg(i).name()
3069 << "\" since we expected an EagerTensor/ResourceVariable, but got "
3070 << item->ob_type->tp_name;
3071 return false;
3072 }
3073 }
3074
3075 return true;
3076 }
3077
3078 tensorflow::DataType MaybeGetDType(PyObject* item) {
3079 if (EagerTensor_CheckExact(item) || CheckResourceVariable(item)) {
3080 return tensorflow::PyTensor_DataType(item);
3081 }
3082
3083 return tensorflow::DT_INVALID;
3084 }
3085
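// Returns the dtype implied for type attr `attr`: first consults the per-op
// dtype cache, then scans the other inputs that share this attr for an
// EagerTensor/ResourceVariable dtype, and finally falls back to the attr's
// default. Returns DT_INVALID if nothing is found.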
3086 tensorflow::DataType MaybeGetDTypeForAttr(const string& attr,
3087 FastPathOpExecInfo* op_exec_info) {
3088 auto cached_it = op_exec_info->cached_dtypes.find(attr);
3089 if (cached_it != op_exec_info->cached_dtypes.end()) {
3090 return cached_it->second;
3091 }
3092
3093 auto it = op_exec_info->attr_to_inputs_map->find(attr);
3094 if (it == op_exec_info->attr_to_inputs_map->end()) {
3095 // No other inputs - this should never happen.
3096 return tensorflow::DT_INVALID;
3097 }
3098
3099 for (const auto& input_info : it->second) {
3100 PyObject* item = PyTuple_GET_ITEM(
3101 op_exec_info->args, FAST_PATH_EXECUTE_ARG_INPUT_START + input_info.i);
3102 if (input_info.is_list) {
3103 tensorflow::Safe_PyObjectPtr fast_item(
3104 PySequence_Fast(item, "Unable to allocate"));
3105 int len = PySequence_Fast_GET_SIZE(fast_item.get());
3106 PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get());
3107 for (int i = 0; i < len; i++) {
3108 auto dtype = MaybeGetDType(fast_item_array[i]);
3109 if (dtype != tensorflow::DT_INVALID) return dtype;
3110 }
3111 } else {
3112 auto dtype = MaybeGetDType(item);
3113 if (dtype != tensorflow::DT_INVALID) return dtype;
3114 }
3115 }
3116
3117 auto default_it = op_exec_info->default_dtypes->find(attr);
3118 if (default_it != op_exec_info->default_dtypes->end()) {
3119 return default_it->second;
3120 }
3121
3122 return tensorflow::DT_INVALID;
3123 }
3124
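// Returns a new tuple with the contents of `seq`, except that the elements at
// the positions listed in `indices` are replaced with Py_None. Used to drop op
// inputs/outputs that the registered gradient function does not need.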
3125 PyObject* CopySequenceSettingIndicesToNull(
3126 PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) {
3127 tensorflow::Safe_PyObjectPtr fast_seq(
3128 PySequence_Fast(seq, "unable to allocate"));
3129 int len = PySequence_Fast_GET_SIZE(fast_seq.get());
3130 PyObject** fast_seq_array = PySequence_Fast_ITEMS(fast_seq.get());
3131 PyObject* result = PyTuple_New(len);
3132 for (int i = 0; i < len; i++) {
3133 PyObject* item;
3134 if (indices.find(i) != indices.end()) {
3135 item = Py_None;
3136 } else {
3137 item = fast_seq_array[i];
3138 }
3139 Py_INCREF(item);
3140 PyTuple_SET_ITEM(result, i, item);
3141 }
3142 return result;
3143 }
3144
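// Records `op_name(inputs) -> results` on every active tape and forward
// accumulator so that backward gradients and JVPs can be computed later.
// Returns Py_None on success (or immediately when nothing is recording), and
// nullptr with a Python error set on failure.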
3145 PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
3146 PyObject* results,
3147 PyObject* forward_pass_name_scope = nullptr) {
3148 std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
3149 if (PyErr_Occurred()) return nullptr;
3150 std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
3151 if (PyErr_Occurred()) return nullptr;
3152
3153 bool should_record = false;
3154 for (TFE_Py_Tape* tape : SafeTapeSet()) {
3155 if (tape->tape->ShouldRecord(input_ids, input_dtypes)) {
3156 should_record = true;
3157 break;
3158 }
3159 }
3160 if (!should_record) {
3161 for (TFE_Py_ForwardAccumulator* accumulator : SafeAccumulatorSet()) {
3162 if (accumulator->accumulator->ShouldRecord(input_ids, input_dtypes)) {
3163 should_record = true;
3164 break;
3165 }
3166 }
3167 }
3168 if (!should_record) Py_RETURN_NONE;
3169
3170 string c_op_name = TFE_GetPythonString(op_name);
3171
3172 PyObject* op_outputs;
3173 bool op_outputs_tuple_created = false;
3174
3175 if (const auto unused_output_indices =
3176 OpGradientUnusedOutputIndices(c_op_name)) {
3177 if (unused_output_indices->empty()) {
3178 op_outputs = Py_None;
3179 } else {
3180 op_outputs_tuple_created = true;
3181 op_outputs =
3182 CopySequenceSettingIndicesToNull(results, *unused_output_indices);
3183 }
3184 } else {
3185 op_outputs = results;
3186 }
3187
3188 PyObject* op_inputs;
3189 bool op_inputs_tuple_created = false;
3190
3191 if (const auto unused_input_indices =
3192 OpGradientUnusedInputIndices(c_op_name)) {
3193 if (unused_input_indices->empty()) {
3194 op_inputs = Py_None;
3195 } else {
3196 op_inputs_tuple_created = true;
3197 op_inputs =
3198 CopySequenceSettingIndicesToNull(inputs, *unused_input_indices);
3199 }
3200 } else {
3201 op_inputs = inputs;
3202 }
3203
3204 tensorflow::eager::ForwardFunction<PyObject> py_forward_function(
3205 [op_name, attrs, inputs, results](
3206 const std::vector<PyObject*>& input_tangents,
3207 std::vector<PyObject*>* output_tangents, bool use_batch) {
3208 return CallJVPFunction(op_name, attrs, inputs, results, input_tangents,
3209 output_tangents, use_batch);
3210 });
3211 tensorflow::eager::ForwardFunction<PyObject>* forward_function;
3212 if (c_op_name == "While" || c_op_name == "StatelessWhile" ||
3213 c_op_name == "If" || c_op_name == "StatelessIf") {
3214 // Control flow contains non-hashable attributes. Handling them in Python is
3215 // a headache, so instead we'll stay as close to GradientTape's handling as
3216 // possible (a null forward function means the accumulator forwards to a
3217 // tape).
3218 //
3219 // This is safe to do since we'll only see control flow when graph building,
3220 // in which case we can rely on pruning.
3221 forward_function = nullptr;
3222 } else {
3223 forward_function = &py_forward_function;
3224 }
3225
3226 PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));
3227
3228 if (!forward_pass_name_scope) forward_pass_name_scope = Py_None;
3229
3230 TapeSetRecordOperation(
3231 op_name, inputs, results, input_ids, input_dtypes,
3232 [op_name, attrs, num_inputs, op_inputs, op_outputs,
3233 forward_pass_name_scope]() {
3234 Py_INCREF(op_name);
3235 Py_INCREF(attrs);
3236 Py_INCREF(num_inputs);
3237 Py_INCREF(op_inputs);
3238 Py_INCREF(op_outputs);
3239 Py_INCREF(forward_pass_name_scope);
3240 PyBackwardFunction* function = new PyBackwardFunction(
3241 [op_name, attrs, num_inputs, op_inputs, op_outputs,
3242 forward_pass_name_scope](
3243 PyObject* output_grads,
3244 const std::vector<tensorflow::int64>& unneeded_gradients) {
3245 if (PyErr_Occurred()) {
3246 return static_cast<PyObject*>(nullptr);
3247 }
3248 tensorflow::Safe_PyObjectPtr skip_input_indices;
3249 if (!unneeded_gradients.empty()) {
3250 skip_input_indices.reset(
3251 PyTuple_New(unneeded_gradients.size()));
3252 for (int i = 0; i < unneeded_gradients.size(); i++) {
3253 PyTuple_SET_ITEM(
3254 skip_input_indices.get(), i,
3255 GetPythonObjectFromInt(unneeded_gradients[i]));
3256 }
3257 } else {
3258 Py_INCREF(Py_None);
3259 skip_input_indices.reset(Py_None);
3260 }
3261 tensorflow::Safe_PyObjectPtr callback_args(Py_BuildValue(
3262 "OOOOOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs,
3263 output_grads, skip_input_indices.get(),
3264 forward_pass_name_scope));
3265
3266 tensorflow::Safe_PyObjectPtr result(
3267 PyObject_CallObject(gradient_function, callback_args.get()));
3268
3269 if (PyErr_Occurred()) return static_cast<PyObject*>(nullptr);
3270
3271 return tensorflow::swig::Flatten(result.get());
3272 });
3273 return function;
3274 },
3275 [op_name, attrs, num_inputs, op_inputs, op_outputs,
3276 forward_pass_name_scope](PyBackwardFunction* backward_function) {
3277 Py_DECREF(op_name);
3278 Py_DECREF(attrs);
3279 Py_DECREF(num_inputs);
3280 Py_DECREF(op_inputs);
3281 Py_DECREF(op_outputs);
3282 Py_DECREF(forward_pass_name_scope);
3283
3284 delete backward_function;
3285 },
3286 forward_function);
3287
3288 Py_DECREF(num_inputs);
3289 if (op_outputs_tuple_created) Py_DECREF(op_outputs);
3290 if (op_inputs_tuple_created) Py_DECREF(op_inputs);
3291
3292 if (PyErr_Occurred()) {
3293 return nullptr;
3294 }
3295
3296 Py_RETURN_NONE;
3297 }
3298
3299 void MaybeNotifyVariableAccessed(PyObject* input) {
3300 DCHECK(CheckResourceVariable(input));
3301 DCHECK(PyObject_HasAttrString(input, "_trainable"));
3302
3303 tensorflow::Safe_PyObjectPtr trainable(
3304 PyObject_GetAttrString(input, "_trainable"));
3305 if (trainable.get() == Py_False) return;
3306 TFE_Py_TapeVariableAccessed(input);
3307 TFE_Py_VariableWatcherVariableAccessed(input);
3308 }
3309
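// Reads the current value of the ResourceVariable `input` by executing a
// ReadVariableOp, storing the resulting EagerTensor in `output`. Notifies
// active tapes/variable watchers of the access, and records a gradient for the
// read when a gradient callback is requested.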
3310 bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
3311 PyObject* input, tensorflow::Safe_PyObjectPtr* output,
3312 TF_Status* status) {
3313 MaybeNotifyVariableAccessed(input);
3314
3315 TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status);
3316 auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); });
3317 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
3318
3319 TFE_OpSetDevice(op, parent_op_exec_info.device_name, status);
3320 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
3321
3322 // Set dtype
3323 DCHECK(PyObject_HasAttrString(input, "_dtype"));
3324 tensorflow::Safe_PyObjectPtr dtype(PyObject_GetAttrString(input, "_dtype"));
3325 int value;
3326 if (!ParseTypeValue("_dtype", dtype.get(), status, &value)) {
3327 return false;
3328 }
3329 TFE_OpSetAttrType(op, "dtype", static_cast<TF_DataType>(value));
3330
3331 // Get handle
3332 tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(input, "_handle"));
3333 if (!EagerTensor_CheckExact(handle.get())) return false;
3334 TFE_OpAddInput(op, EagerTensor_Handle(handle.get()), status);
3335 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
3336
3337 int num_retvals = 1;
3338 TFE_TensorHandle* output_handle;
3339 TFE_Execute(op, &output_handle, &num_retvals, status);
3340 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
3341
3342 // Always create the py object (and correctly DECREF it) from the returned
3343 // value, else the data will leak.
3344 output->reset(EagerTensorFromHandle(output_handle));
3345
3346 // TODO(nareshmodi): Should we run post exec callbacks here?
3347 if (parent_op_exec_info.run_gradient_callback) {
3348 tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(1));
3349 PyTuple_SET_ITEM(inputs.get(), 0, handle.release());
3350
3351 tensorflow::Safe_PyObjectPtr outputs(PyTuple_New(1));
3352 Py_INCREF(output->get()); // stay alive after since tuple steals.
3353 PyTuple_SET_ITEM(outputs.get(), 0, output->get());
3354
3355 tensorflow::Safe_PyObjectPtr op_string(
3356 GetPythonObjectFromString("ReadVariableOp"));
3357 if (!RecordGradient(op_string.get(), inputs.get(), Py_None,
3358 outputs.get())) {
3359 return false;
3360 }
3361 }
3362
3363 return true;
3364 }
3365
// Supports 3 cases at the moment:
// i) input is an EagerTensor.
// ii) input is a ResourceVariable - in this case the variable is read via
//     ReadVariableOp and the resulting EagerTensor is used.
// iii) input is an arbitrary python list/tuple (note, this handling doesn't
//      support packing).
3372 //
// NOTE: dtype_hint_getter returns the dtype to use as a conversion hint for
// this input, or DT_INVALID when no hint is available.
//
// NOTE: On failure this function sets a Python error directly and returns
// false. TF_Status is only passed in so that we don't have to reallocate one.
3379 bool ConvertToTensor(
3380 const FastPathOpExecInfo& op_exec_info, PyObject* input,
3381 tensorflow::Safe_PyObjectPtr* output_handle,
3382 // This gets a hint for this particular input.
3383 const std::function<tensorflow::DataType()>& dtype_hint_getter,
3384 // This sets the dtype after conversion is complete.
3385 const std::function<void(const tensorflow::DataType dtype)>& dtype_setter,
3386 TF_Status* status) {
3387 if (EagerTensor_CheckExact(input)) {
3388 Py_INCREF(input);
3389 output_handle->reset(input);
3390 return true;
3391 } else if (CheckResourceVariable(input)) {
3392 return ReadVariableOp(op_exec_info, input, output_handle, status);
3393 }
3394
3395 // The hint comes from a supposedly similarly typed tensor.
3396 tensorflow::DataType dtype_hint = dtype_hint_getter();
3397
3398 TFE_TensorHandle* handle = tensorflow::ConvertToEagerTensor(
3399 op_exec_info.ctx, input, dtype_hint, op_exec_info.device_name);
3400 if (handle == nullptr) {
3401 return MaybeRaiseExceptionFromTFStatus(status, nullptr);
3402 }
3403
3404 output_handle->reset(EagerTensorFromHandle(handle));
3405 dtype_setter(
3406 static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(handle)));
3407
3408 return true;
3409 }
3410
3411 // Adds input and type attr to the op, and to the list of flattened
3412 // inputs/attrs.
3413 bool AddInputToOp(FastPathOpExecInfo* op_exec_info, PyObject* input,
3414 const bool add_type_attr,
3415 const tensorflow::OpDef::ArgDef& input_arg,
3416 std::vector<tensorflow::Safe_PyObjectPtr>* flattened_attrs,
3417 std::vector<tensorflow::Safe_PyObjectPtr>* flattened_inputs,
3418 TFE_Op* op, TF_Status* status) {
3419 // py_eager_tensor's ownership is transferred to flattened_inputs if it is
3420 // required, else the object is destroyed and DECREF'd when the object goes
3421 // out of scope in this function.
3422 tensorflow::Safe_PyObjectPtr py_eager_tensor = nullptr;
3423
3424 if (!ConvertToTensor(
3425 *op_exec_info, input, &py_eager_tensor,
3426 [&]() {
3427 if (input_arg.type() != tensorflow::DataType::DT_INVALID) {
3428 return input_arg.type();
3429 }
3430 return MaybeGetDTypeForAttr(input_arg.type_attr(), op_exec_info);
3431 },
3432 [&](const tensorflow::DataType dtype) {
3433 op_exec_info->cached_dtypes[input_arg.type_attr()] = dtype;
3434 },
3435 status)) {
3436 return false;
3437 }
3438
3439 TFE_TensorHandle* input_handle = EagerTensor_Handle(py_eager_tensor.get());
3440
3441 if (add_type_attr && !input_arg.type_attr().empty()) {
3442 auto dtype = TFE_TensorHandleDataType(input_handle);
3443 TFE_OpSetAttrType(op, input_arg.type_attr().data(), dtype);
3444 if (flattened_attrs != nullptr) {
3445 flattened_attrs->emplace_back(
3446 GetPythonObjectFromString(input_arg.type_attr()));
3447 flattened_attrs->emplace_back(PyLong_FromLong(dtype));
3448 }
3449 }
3450
3451 if (flattened_inputs != nullptr) {
3452 flattened_inputs->emplace_back(std::move(py_eager_tensor));
3453 }
3454
3455 TFE_OpAddInput(op, input_handle, status);
3456 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3457 return false;
3458 }
3459
3460 return true;
3461 }
3462
3463 const char* GetDeviceName(PyObject* py_device_name) {
3464 if (py_device_name != Py_None) {
3465 return TFE_GetPythonString(py_device_name);
3466 }
3467 return nullptr;
3468 }
3469
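// Returns true if `seq` looks like a valid sequence for attr `attr_name`;
// otherwise sets a Python TypeError/ValueError and returns false (ndarrays are
// only accepted when they have rank 1).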
3470 bool RaiseIfNotPySequence(PyObject* seq, const string& attr_name) {
3471 if (!PySequence_Check(seq)) {
3472 PyErr_SetString(PyExc_TypeError,
3473 Printf("expected a sequence for attr %s, got %s instead",
3474 attr_name.data(), seq->ob_type->tp_name)
3475 .data());
3476
3477 return false;
3478 }
3479 if (PyArray_Check(seq) &&
3480 PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)) != 1) {
3481 PyErr_SetString(PyExc_ValueError,
3482 Printf("expected a sequence for attr %s, got an ndarray "
3483 "with rank %d instead",
3484 attr_name.data(),
3485 PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)))
3486 .data());
3487 return false;
3488 }
3489 return true;
3490 }
3491
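// Runs the callbacks requested for this op: records a gradient if a
// tape/accumulator is active, then invokes any registered post-execution
// callbacks with (op_name, inputs, attrs, results, name). Returns false (with
// a Python error set) if any callback fails.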
3492 bool RunCallbacks(
3493 const FastPathOpExecInfo& op_exec_info, PyObject* args,
3494 int num_inferred_attrs,
3495 const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_inputs,
3496 const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_attrs,
3497 PyObject* flattened_result) {
3498 DCHECK(op_exec_info.run_callbacks);
3499
3500 tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(flattened_inputs.size()));
3501 for (int i = 0; i < flattened_inputs.size(); i++) {
3502 PyObject* input = flattened_inputs[i].get();
3503 Py_INCREF(input);
3504 PyTuple_SET_ITEM(inputs.get(), i, input);
3505 }
3506
3507 int num_non_inferred_attrs = PyTuple_GET_SIZE(args) - num_inferred_attrs;
3508 int num_attrs = flattened_attrs.size() + num_non_inferred_attrs;
3509 tensorflow::Safe_PyObjectPtr attrs(PyTuple_New(num_attrs));
3510
3511 for (int i = 0; i < num_non_inferred_attrs; i++) {
3512 auto* attr = PyTuple_GET_ITEM(args, num_inferred_attrs + i);
3513 Py_INCREF(attr);
3514 PyTuple_SET_ITEM(attrs.get(), i, attr);
3515 }
3516
3517 for (int i = num_non_inferred_attrs; i < num_attrs; i++) {
3518 PyObject* attr_or_name =
3519 flattened_attrs.at(i - num_non_inferred_attrs).get();
3520 Py_INCREF(attr_or_name);
3521 PyTuple_SET_ITEM(attrs.get(), i, attr_or_name);
3522 }
3523
3524 if (op_exec_info.run_gradient_callback) {
3525 if (!RecordGradient(op_exec_info.op_name, inputs.get(), attrs.get(),
3526 flattened_result)) {
3527 return false;
3528 }
3529 }
3530
3531 if (op_exec_info.run_post_exec_callbacks) {
3532 tensorflow::Safe_PyObjectPtr callback_args(
3533 Py_BuildValue("OOOOO", op_exec_info.op_name, inputs.get(), attrs.get(),
3534 flattened_result, op_exec_info.name));
3535 for (Py_ssize_t i = 0; i < PyList_Size(op_exec_info.callbacks); i++) {
3536 PyObject* callback_fn = PyList_GET_ITEM(op_exec_info.callbacks, i);
3537 if (!PyCallable_Check(callback_fn)) {
3538 PyErr_SetString(
3539 PyExc_TypeError,
3540 Printf("expected a function for "
3541 "post execution callback in index %ld, got %s instead",
3542 i, callback_fn->ob_type->tp_name)
3543 .c_str());
3544 return false;
3545 }
3546 PyObject* callback_result =
3547 PyObject_CallObject(callback_fn, callback_args.get());
3548 if (!callback_result) {
3549 return false;
3550 }
3551 Py_DECREF(callback_result);
3552 }
3553 }
3554
3555 return true;
3556 }
3557
3558 } // namespace
3559
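// Executes an eager op on the "fast path". `args` is the flat tuple described
// by FastPathExecuteArgIndex above: (context, op_name, name, inputs...,
// non-inferred attr name/value pairs...). Whenever an input or attr cannot be
// handled here, a fallback exception is raised so the caller can retry on the
// slower Python path.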
3560 PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
3561 tensorflow::profiler::TraceMe activity(
3562 "TFE_Py_FastPathExecute_C", tensorflow::profiler::TraceMeLevel::kInfo);
3563 Py_ssize_t args_size = PyTuple_GET_SIZE(args);
3564 if (args_size < FAST_PATH_EXECUTE_ARG_INPUT_START) {
3565 PyErr_SetString(
3566 PyExc_ValueError,
3567 Printf("There must be at least %d items in the input tuple.",
3568 FAST_PATH_EXECUTE_ARG_INPUT_START)
3569 .c_str());
3570 return nullptr;
3571 }
3572
3573 FastPathOpExecInfo op_exec_info;
3574
3575 PyObject* py_eager_context =
3576 PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_CONTEXT);
3577
3578 // TODO(edoper): Use interned string here
3579 PyObject* eager_context_handle =
3580 PyObject_GetAttrString(py_eager_context, "_context_handle");
3581
3582 TFE_Context* ctx = reinterpret_cast<TFE_Context*>(
3583 PyCapsule_GetPointer(eager_context_handle, nullptr));
3584 op_exec_info.ctx = ctx;
3585 op_exec_info.args = args;
3586
3587 if (ctx == nullptr) {
// The context hasn't been initialized yet; it will be initialized on the
// slow path.
3589 RaiseFallbackException(
3590 "This function does not handle the case of the path where "
3591 "all inputs are not already EagerTensors.");
3592 return nullptr;
3593 }
3594
3595 auto* tld = tensorflow::GetEagerContextThreadLocalData(py_eager_context);
3596 if (tld == nullptr) {
3597 return nullptr;
3598 }
3599 op_exec_info.device_name = GetDeviceName(tld->device_name.get());
3600 op_exec_info.callbacks = tld->op_callbacks.get();
3601
3602 op_exec_info.op_name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_OP_NAME);
3603 op_exec_info.name = PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_NAME);
3604
3605 // TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks
3606 // (similar to benchmark_tf_gradient_function_*). Also consider using an
3607 // InlinedVector for flattened_attrs and flattened_inputs if the benchmarks
3608 // point out problems with heap allocs.
3609 op_exec_info.run_gradient_callback =
3610 !*ThreadTapeIsStopped() && HasAccumulatorOrTape();
3611 op_exec_info.run_post_exec_callbacks =
3612 op_exec_info.callbacks != Py_None &&
3613 PyList_Size(op_exec_info.callbacks) > 0;
3614 op_exec_info.run_callbacks = op_exec_info.run_gradient_callback ||
3615 op_exec_info.run_post_exec_callbacks;
3616
3617 TF_Status* status = GetStatus();
3618 const char* op_name = TFE_GetPythonString(op_exec_info.op_name);
3619 if (op_name == nullptr) {
3620 PyErr_SetString(PyExc_TypeError,
3621 Printf("expected a string for op_name, got %s instead",
3622 op_exec_info.op_name->ob_type->tp_name)
3623 .c_str());
3624 return nullptr;
3625 }
3626
3627 TFE_Op* op = GetOp(ctx, op_name, op_exec_info.device_name, status);
3628
3629 auto cleaner = tensorflow::gtl::MakeCleanup([status, ctx, op] {
3630 ReturnStatus(status);
3631 ReturnOp(ctx, op);
3632 });
3633
3634 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3635 return nullptr;
3636 }
3637
3638 tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace(
3639 tensorflow::StackTrace::kStackTraceInitialSize));
3640
3641 const tensorflow::OpDef* op_def = tensorflow::unwrap(op)->OpDef();
3642 if (op_def == nullptr) return nullptr;
3643
3644 if (args_size <
3645 FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size()) {
3646 PyErr_SetString(
3647 PyExc_ValueError,
3648 Printf("Tuple size smaller than intended. Expected to be at least %d, "
3649 "was %ld",
3650 FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
3651 args_size)
3652 .c_str());
3653 return nullptr;
3654 }
3655
3656 if (!CheckInputsOk(args, FAST_PATH_EXECUTE_ARG_INPUT_START, *op_def)) {
3657 RaiseFallbackException(
3658 "This function does not handle the case of the path where "
3659 "all inputs are not already EagerTensors.");
3660 return nullptr;
3661 }
3662
3663 op_exec_info.attr_to_inputs_map = GetAttrToInputsMapHoldingGIL(*op_def);
3664 op_exec_info.default_dtypes = GetAttrToDefaultsMapHoldingGIL(*op_def);
3665
3666 // Mapping of attr name to size - used to calculate the number of values
3667 // to be expected by the TFE_Execute run.
3668 tensorflow::gtl::FlatMap<string, tensorflow::int64> attr_list_sizes;
3669
3670 // Set non-inferred attrs, including setting defaults if the attr is passed in
3671 // as None.
3672 for (int i = FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size();
3673 i < args_size; i += 2) {
3674 PyObject* py_attr_name = PyTuple_GET_ITEM(args, i);
3675 const char* attr_name = TFE_GetPythonString(py_attr_name);
3676 PyObject* py_attr_value = PyTuple_GET_ITEM(args, i + 1);
3677
3678 // Not creating an index since most of the time there are not more than a
3679 // few attrs.
3680 // TODO(nareshmodi): Maybe include the index as part of the
3681 // OpRegistrationData.
3682 for (const auto& attr : op_def->attr()) {
3683 if (tensorflow::StringPiece(attr_name) == attr.name()) {
3684 SetOpAttrWithDefaults(ctx, op, attr, attr_name, py_attr_value,
3685 &attr_list_sizes, status);
3686
3687 if (!status->status.ok()) {
3688 VLOG(1) << "Falling back to slow path for Op \"" << op_def->name()
3689 << "\" since we are unable to set the value for attr \""
3690 << attr.name() << "\" due to: " << TF_Message(status);
3691 RaiseFallbackException(TF_Message(status));
3692 return nullptr;
3693 }
3694
3695 break;
3696 }
3697 }
3698 }
3699
// Flattened attrs and inputs as required by the record_gradient call. The
// attrs here only contain inferred attrs (non-inferred attrs are added
// directly from the input args).
// All items in flattened_attrs and flattened_inputs are held as
// Safe_PyObjectPtr - anything that steals a reference to one of them must
// INCREF it first.
3706 // TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work
3707 // directly.
3708 std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_attrs =
3709 nullptr;
3710 std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_inputs =
3711 nullptr;
3712
3713 // TODO(nareshmodi): Encapsulate callbacks information into a struct.
3714 if (op_exec_info.run_callbacks) {
3715 flattened_attrs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
3716 flattened_inputs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
3717 }
3718
3719 // Add inferred attrs and inputs.
3720 // The following code might set duplicate type attrs. This will result in
3721 // the CacheKey for the generated AttrBuilder possibly differing from
3722 // those where the type attrs are correctly set. Inconsistent CacheKeys
3723 // for ops means that there might be unnecessarily duplicated kernels.
3724 // TODO(nareshmodi): Fix this.
3725 for (int i = 0; i < op_def->input_arg_size(); i++) {
3726 const auto& input_arg = op_def->input_arg(i);
3727
3728 PyObject* input =
3729 PyTuple_GET_ITEM(args, FAST_PATH_EXECUTE_ARG_INPUT_START + i);
3730 if (!input_arg.number_attr().empty()) {
3731 // The item is a homogeneous list.
3732 if (!RaiseIfNotPySequence(input, input_arg.number_attr())) return nullptr;
3733 tensorflow::Safe_PyObjectPtr fast_input(
3734 PySequence_Fast(input, "Could not parse sequence."));
3735 if (fast_input.get() == nullptr) {
3736 return nullptr;
3737 }
3738 Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
3739 PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get());
3740
3741 TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len);
3742 if (op_exec_info.run_callbacks) {
3743 flattened_attrs->emplace_back(
3744 GetPythonObjectFromString(input_arg.number_attr()));
3745 flattened_attrs->emplace_back(PyLong_FromLong(len));
3746 }
3747 attr_list_sizes[input_arg.number_attr()] = len;
3748
3749 if (len > 0) {
3750 // First item adds the type attr.
3751 if (!AddInputToOp(&op_exec_info, fast_input_array[0], true, input_arg,
3752 flattened_attrs.get(), flattened_inputs.get(), op,
3753 status)) {
3754 return nullptr;
3755 }
3756
3757 for (Py_ssize_t j = 1; j < len; j++) {
3758 // Since the list is homogeneous, we don't need to re-add the attr.
3759 if (!AddInputToOp(&op_exec_info, fast_input_array[j], false,
3760 input_arg, nullptr /* flattened_attrs */,
3761 flattened_inputs.get(), op, status)) {
3762 return nullptr;
3763 }
3764 }
3765 }
3766 } else if (!input_arg.type_list_attr().empty()) {
3767 // The item is a heterogeneous list.
3768 if (!RaiseIfNotPySequence(input, input_arg.type_list_attr())) {
3769 return nullptr;
3770 }
3771 tensorflow::Safe_PyObjectPtr fast_input(
3772 PySequence_Fast(input, "Could not parse sequence."));
3773 if (fast_input.get() == nullptr) {
3774 return nullptr;
3775 }
3776 const string& attr_name = input_arg.type_list_attr();
3777 Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
3778 PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get());
3779 tensorflow::gtl::InlinedVector<TF_DataType, 4> attr_value(len);
3780 PyObject* py_attr_value = nullptr;
3781 if (op_exec_info.run_callbacks) {
3782 py_attr_value = PyTuple_New(len);
3783 }
3784 for (Py_ssize_t j = 0; j < len; j++) {
3785 PyObject* py_input = fast_input_array[j];
3786 tensorflow::Safe_PyObjectPtr py_eager_tensor;
3787 if (!ConvertToTensor(
3788 op_exec_info, py_input, &py_eager_tensor,
3789 []() { return tensorflow::DT_INVALID; },
3790 [](const tensorflow::DataType dtype) {}, status)) {
3791 return nullptr;
3792 }
3793
3794 TFE_TensorHandle* input_handle =
3795 EagerTensor_Handle(py_eager_tensor.get());
3796
3797 attr_value[j] = TFE_TensorHandleDataType(input_handle);
3798
3799 TFE_OpAddInput(op, input_handle, status);
3800 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
3801 return nullptr;
3802 }
3803
3804 if (op_exec_info.run_callbacks) {
3805 flattened_inputs->emplace_back(std::move(py_eager_tensor));
3806
3807 PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j]));
3808 }
3809 }
3810 if (op_exec_info.run_callbacks) {
3811 flattened_attrs->emplace_back(GetPythonObjectFromString(attr_name));
3812 flattened_attrs->emplace_back(py_attr_value);
3813 }
3814 TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(),
3815 attr_value.size());
3816 attr_list_sizes[attr_name] = len;
3817 } else {
3818 // The item is a single item.
3819 if (!AddInputToOp(&op_exec_info, input, true, input_arg,
3820 flattened_attrs.get(), flattened_inputs.get(), op,
3821 status)) {
3822 return nullptr;
3823 }
3824 }
3825 }
3826
3827 int64_t num_outputs = 0;
3828 for (int i = 0; i < op_def->output_arg_size(); i++) {
3829 const auto& output_arg = op_def->output_arg(i);
3830 int64_t delta = 1;
3831 if (!output_arg.number_attr().empty()) {
3832 delta = attr_list_sizes[output_arg.number_attr()];
3833 } else if (!output_arg.type_list_attr().empty()) {
3834 delta = attr_list_sizes[output_arg.type_list_attr()];
3835 }
3836 if (delta < 0) {
3837 RaiseFallbackException(
3838 "Attributes suggest that the size of an output list is less than 0");
3839 return nullptr;
3840 }
3841 num_outputs += delta;
3842 }
3843
// If the number of retvals doesn't fit in an int32, we error out.
3845 if (static_cast<int64_t>(static_cast<int32_t>(num_outputs)) != num_outputs) {
3846 PyErr_SetString(
3847 PyExc_ValueError,
3848 Printf("Number of outputs is too big: %ld", num_outputs).c_str());
3849 return nullptr;
3850 }
3851 int num_retvals = num_outputs;
3852
3853 tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);
3854
3855 Py_BEGIN_ALLOW_THREADS;
3856 TFE_Execute(op, retvals.data(), &num_retvals, status);
3857 Py_END_ALLOW_THREADS;
3858
3859 if (!status->status.ok()) {
// Augment the status with the op_name for easier debugging, similar to
// TFE_Py_Execute.
3862 std::vector<tensorflow::StackFrame> stack_trace =
3863 status->status.stack_trace();
3864 status->status = tensorflow::Status(
3865 status->status.code(),
3866 tensorflow::strings::StrCat(
3867 TF_Message(status),
3868 " [Op:", TFE_GetPythonString(op_exec_info.op_name), "]"),
3869 std::move(stack_trace));
3870
3871 MaybeRaiseExceptionFromTFStatus(status, nullptr);
3872 return nullptr;
3873 }
3874
3875 tensorflow::Safe_PyObjectPtr flat_result(PyList_New(num_retvals));
3876 for (int i = 0; i < num_retvals; ++i) {
3877 PyList_SET_ITEM(flat_result.get(), i, EagerTensorFromHandle(retvals[i]));
3878 }
3879
3880 if (op_exec_info.run_callbacks) {
3881 if (!RunCallbacks(
3882 op_exec_info, args,
3883 FAST_PATH_EXECUTE_ARG_INPUT_START + op_def->input_arg_size(),
3884 *flattened_inputs, *flattened_attrs, flat_result.get())) {
3885 return nullptr;
3886 }
3887 }
3888
3889 // Unflatten results.
3890 if (op_def->output_arg_size() == 0) {
3891 Py_RETURN_NONE;
3892 }
3893
3894 if (op_def->output_arg_size() == 1) {
3895 if (!op_def->output_arg(0).number_attr().empty() ||
3896 !op_def->output_arg(0).type_list_attr().empty()) {
3897 return flat_result.release();
3898 } else {
3899 auto* result = PyList_GET_ITEM(flat_result.get(), 0);
3900 Py_INCREF(result);
3901 return result;
3902 }
3903 }
3904
// Group the flat results per output arg (the Python caller may turn these
// into a namedtuple).
3906 PyObject* result = PyList_New(op_def->output_arg_size());
3907 int flat_result_index = 0;
3908 for (int i = 0; i < op_def->output_arg_size(); i++) {
3909 if (!op_def->output_arg(i).number_attr().empty()) {
3910 int list_length = attr_list_sizes[op_def->output_arg(i).number_attr()];
3911 PyObject* inner_list = PyList_New(list_length);
3912 for (int j = 0; j < list_length; j++) {
3913 PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
3914 Py_INCREF(obj);
3915 PyList_SET_ITEM(inner_list, j, obj);
3916 }
3917 PyList_SET_ITEM(result, i, inner_list);
3918 } else if (!op_def->output_arg(i).type_list_attr().empty()) {
3919 int list_length = attr_list_sizes[op_def->output_arg(i).type_list_attr()];
3920 PyObject* inner_list = PyList_New(list_length);
3921 for (int j = 0; j < list_length; j++) {
3922 PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
3923 Py_INCREF(obj);
3924 PyList_SET_ITEM(inner_list, j, obj);
3925 }
3926 PyList_SET_ITEM(result, i, inner_list);
3927 } else {
3928 PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
3929 Py_INCREF(obj);
3930 PyList_SET_ITEM(result, i, obj);
3931 }
3932 }
3933 return result;
3934 }
3935
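// Python-visible wrapper around RecordGradient: a no-op (returning None) when
// recording is stopped or when no tape/accumulator is active.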
3936 PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
3937 PyObject* attrs, PyObject* results,
3938 PyObject* forward_pass_name_scope) {
3939 if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
3940 Py_RETURN_NONE;
3941 }
3942
3943 return RecordGradient(op_name, inputs, attrs, results,
3944 forward_pass_name_scope);
3945 }
3946
3947 namespace {
3948 const char kTensor[] = "T";
3949 const char kList[] = "L";
3950 const char kListEnd[] = "l";
3951 const char kTuple[] = "U";
3952 const char kTupleEnd[] = "u";
3953 const char kDict[] = "D";
3954 const char kRaw[] = "R";
3955 const char kShape[] = "s";
3956 const char kShapeDelim[] = "-";
3957 const char kDType[] = "d";
3958 const char kNone[] = "n";
3959 const char kCompositeTensor[] = "C";
3960 const char kAttrs[] = "A";
3961 const char kAttrsEnd[] = "a";
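
// Illustrative only (the exact byte layout is an internal detail, not a
// contract): with the markers above, a float32 EagerTensor of shape [2, 3]
// encodes roughly as "T" + "d1" + "s" + "2-3-", and a two-element list of such
// tensors as "L" + <elem> + <elem> + "l".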
3962
3963 struct EncodeResult {
3964 string str;
3965 std::vector<PyObject*> objects;
3966
3967 PyObject* ToPyTuple() {
3968 PyObject* result = PyTuple_New(2);
3969
3970 PyTuple_SET_ITEM(result, 0, GetPythonObjectFromString(str));
3971
3972 if (objects.empty()) {
3973 Py_INCREF(Py_None);
3974 PyTuple_SET_ITEM(result, 1, Py_None);
3975 } else {
3976 PyObject* objects_tuple = PyTuple_New(objects.size());
3977
3978 for (int i = 0; i < objects.size(); i++) {
3979 PyTuple_SET_ITEM(objects_tuple, i, objects[i]);
3980 }
3981
3982 PyTuple_SET_ITEM(result, 1, objects_tuple);
3983 }
3984
3985 return result;
3986 }
3987 };
3988
3989 tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg,
3990 bool include_tensor_ranks_only,
3991 EncodeResult* result) {
3992 if (EagerTensor_CheckExact(arg)) {
3993 tensorflow::ImmediateExecutionTensorHandle* handle =
3994 tensorflow::unwrap(EagerTensor_Handle(arg));
3995
3996 absl::StrAppend(&result->str, kDType,
3997 static_cast<tensorflow::DataType>(handle->DataType()));
3998 absl::StrAppend(&result->str, kShape);
3999
4000 int num_dims;
4001 tensorflow::Status status = handle->NumDims(&num_dims);
4002 if (!status.ok()) return status;
4003
4004 if (include_tensor_ranks_only) {
4005 absl::StrAppend(&result->str, num_dims);
4006 } else {
4007 for (int i = 0; i < num_dims; ++i) {
4008 tensorflow::int64 dim_size;
4009 status = handle->Dim(i, &dim_size);
4010 if (!status.ok()) return status;
4011 absl::StrAppend(&result->str, dim_size, kShapeDelim);
4012 }
4013 }
4014 return tensorflow::Status::OK();
4015 }
4016
4017 tensorflow::Safe_PyObjectPtr dtype_object(
4018 PyObject_GetAttrString(arg, "dtype"));
4019
4020 if (dtype_object == nullptr) {
4021 return tensorflow::errors::InvalidArgument(
4022 "ops.Tensor object doesn't have dtype() attr.");
4023 }
4024
4025 tensorflow::Safe_PyObjectPtr dtype_enum(
4026 PyObject_GetAttrString(dtype_object.get(), "_type_enum"));
4027
4028 if (dtype_enum == nullptr) {
4029 return tensorflow::errors::InvalidArgument(
4030 "ops.Tensor's dtype object doesn't have _type_enum() attr.");
4031 }
4032
4033 tensorflow::DataType dtype =
4034 static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get()));
4035
4036 absl::StrAppend(&result->str, kDType, dtype);
4037
4038 static char _shape_tuple[] = "_shape_tuple";
4039 tensorflow::Safe_PyObjectPtr shape_tuple(
4040 PyObject_CallMethod(arg, _shape_tuple, nullptr));
4041
4042 if (shape_tuple == nullptr) {
4043 return tensorflow::errors::InvalidArgument(
4044 "ops.Tensor object doesn't have _shape_tuple() method.");
4045 }
4046
4047 if (shape_tuple.get() == Py_None) {
4048 // Unknown shape, encode that directly.
4049 absl::StrAppend(&result->str, kNone);
4050 return tensorflow::Status::OK();
4051 }
4052
4053 absl::StrAppend(&result->str, kShape);
4054 tensorflow::Safe_PyObjectPtr shape_seq(PySequence_Fast(
4055 shape_tuple.get(), "shape_tuple didn't return a sequence"));
4056
4057 int len = PySequence_Fast_GET_SIZE(shape_seq.get());
4058 PyObject** shape_seq_array = PySequence_Fast_ITEMS(shape_seq.get());
4059
4060 if (include_tensor_ranks_only) {
4061 absl::StrAppend(&result->str, len);
4062 } else {
4063 for (int i = 0; i < len; ++i) {
4064 PyObject* item = shape_seq_array[i];
4065 if (item == Py_None) {
4066 absl::StrAppend(&result->str, kNone);
4067 } else {
4068 absl::StrAppend(&result->str, MakeInt(item));
4069 }
4070 }
4071 }
4072 return tensorflow::Status::OK();
4073 }
4074
4075 tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
4076 bool include_tensor_ranks_only,
4077 EncodeResult* result);
4078
// Encodes a list/tuple: appends `type`, then the encoding of each element
// (kNone for Py_None elements), then `end_type`. Callers are expected to have
// already checked that `arg` is a sequence of the right kind.
4080 tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type,
4081 const char* end_type,
4082 bool include_tensor_ranks_only,
4083 EncodeResult* result) {
4084 tensorflow::Safe_PyObjectPtr arg_seq(
4085 PySequence_Fast(arg, "unable to create seq from list/tuple"));
4086
4087 absl::StrAppend(&result->str, type);
4088 int len = PySequence_Fast_GET_SIZE(arg_seq.get());
4089 PyObject** arg_seq_array = PySequence_Fast_ITEMS(arg_seq.get());
4090 for (int i = 0; i < len; ++i) {
4091 PyObject* item = arg_seq_array[i];
4092 if (item == Py_None) {
4093 absl::StrAppend(&result->str, kNone);
4094 } else {
4095 TF_RETURN_IF_ERROR(
4096 TFE_Py_EncodeArgHelper(item, include_tensor_ranks_only, result));
4097 }
4098 }
4099 absl::StrAppend(&result->str, end_type);
4100
4101 return tensorflow::Status::OK();
4102 }
4103
4104 tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
4105 bool include_tensor_ranks_only,
4106 EncodeResult* result) {
4107 if (tensorflow::swig::IsTensor(arg)) {
4108 absl::StrAppend(&result->str, kTensor);
4109 TF_RETURN_IF_ERROR(
4110 TFE_Py_EncodeTensor(arg, include_tensor_ranks_only, result));
4111 } else if (PyList_Check(arg)) {
4112 TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(
4113 arg, kList, kListEnd, include_tensor_ranks_only, result));
4114 } else if (tensorflow::swig::IsTuple(arg)) {
4115 TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(
4116 arg, kTuple, kTupleEnd, include_tensor_ranks_only, result));
4117 } else if (tensorflow::swig::IsMapping(arg)) {
4118 tensorflow::Safe_PyObjectPtr keys(tensorflow::swig::MappingKeys(arg));
4119 if (PyList_Sort(keys.get()) == -1) {
4120 return tensorflow::errors::Internal("Unable to sort keys");
4121 }
4122
4123 absl::StrAppend(&result->str, kDict);
4124 int len = PyList_Size(keys.get());
4125
4126 for (int i = 0; i < len; i++) {
4127 PyObject* key = PyList_GetItem(keys.get(), i);
4128 TF_RETURN_IF_ERROR(
4129 TFE_Py_EncodeArgHelper(key, include_tensor_ranks_only, result));
4130 tensorflow::Safe_PyObjectPtr value(PyObject_GetItem(arg, key));
4131 TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(
4132 value.get(), include_tensor_ranks_only, result));
4133 }
4134 } else if (tensorflow::swig::IsCompositeTensor(arg)) {
4135 absl::StrAppend(&result->str, kCompositeTensor);
4136
4137 // Add the typespec to the list of objects. (Do *not* use a weakref,
4138 // since the type spec is often a temporary object.)
4139 PyObject* type_spec(PyObject_GetAttrString(arg, "_type_spec"));
4140 if (type_spec == nullptr) {
4141 return tensorflow::errors::InvalidArgument(
4142 "Error while reading CompositeTensor._type_spec.");
4143 }
4144 result->objects.push_back(type_spec);
4145 } else if (tensorflow::swig::IsTypeSpec(arg)) {
4146 // Add the typespec (not a weakref) in case it's a temporary object.
4147 absl::StrAppend(&result->str, kRaw);
4148 Py_INCREF(arg);
4149 result->objects.push_back(arg);
4150 } else if (tensorflow::swig::IsAttrs(arg)) {
4151 absl::StrAppend(&result->str, kAttrs);
4152 tensorflow::Safe_PyObjectPtr attrs(
4153 PyObject_GetAttrString(arg, "__attrs_attrs__"));
4154 tensorflow::Safe_PyObjectPtr iter(PyObject_GetIter(attrs.get()));
4155 for (tensorflow::Safe_PyObjectPtr item(PyIter_Next(iter.get())); item;
4156 item.reset(PyIter_Next(iter.get()))) {
4157 tensorflow::Safe_PyObjectPtr name(
4158 PyObject_GetAttrString(item.get(), "name"));
4159 tensorflow::Safe_PyObjectPtr attr_arg(PyObject_GetAttr(arg, name.get()));
4160 TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(
4161 attr_arg.get(), include_tensor_ranks_only, result));
4162 }
4163 absl::StrAppend(&result->str, kAttrsEnd);
4164 } else {
4165 PyObject* object = PyWeakref_NewRef(arg, nullptr);
4166
4167 if (object == nullptr) {
4168 PyErr_Clear();
4169
4170 object = arg;
4171 Py_INCREF(object);
4172 }
4173
4174 absl::StrAppend(&result->str, kRaw);
4175 result->objects.push_back(object);
4176 }
4177
4178 return tensorflow::Status::OK();
4179 }
4180
4181 } // namespace
4182
4183 // `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes
4184 // are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes
4185 // are used for both performance reasons, as much TensorFlow code specializes
4186 // on known shapes to produce slimmer graphs, and correctness, as some
4187 // high-level APIs require shapes to be fully-known.
4188 //
4189 // `include_tensor_ranks_only` allows caching on arguments excluding shape info,
4190 // so that a slow path using relaxed shape can rely on a cache key that excludes
4191 // shapes.
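//
// For example (illustrative): tracing with a float32 tensor of shape [2, 3]
// and again with shape [4, 3] produces two different encodings (hence two
// traces), whereas with include_tensor_ranks_only=true both encode only the
// rank (2) and can share a cache entry.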
4192 PyObject* TFE_Py_EncodeArg(PyObject* arg, bool include_tensor_ranks_only) {
4193 EncodeResult result;
4194 const auto status =
4195 TFE_Py_EncodeArgHelper(arg, include_tensor_ranks_only, &result);
4196 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
4197 return nullptr;
4198 }
4199
4200 return result.ToPyTuple();
4201 }
4202
// Prints incoming messages directly to Python's stdout using Python's C API.
// This is necessary in Jupyter notebooks and Colabs, where messages written to
// the C stdout don't go to the notebook cell outputs, but calls to Python's
// stdout do.
4207 void PrintToPythonStdout(const char* msg) {
4208 if (Py_IsInitialized()) {
4209 PyGILState_STATE py_threadstate;
4210 py_threadstate = PyGILState_Ensure();
4211
4212 string string_msg = msg;
4213 // PySys_WriteStdout truncates strings over 1000 bytes, so
4214 // we write the message in chunks small enough to not be truncated.
4215 int CHUNK_SIZE = 900;
4216 auto len = string_msg.length();
4217 for (int i = 0; i < len; i += CHUNK_SIZE) {
4218 PySys_WriteStdout("%s", string_msg.substr(i, CHUNK_SIZE).c_str());
4219 }
4220
// Force a flush to make sure printed newlines aren't interleaved in some
// Colab environments.
4223 PyRun_SimpleString("import sys; sys.stdout.flush()");
4224
4225 PyGILState_Release(py_threadstate);
4226 }
4227 }
4228
// Registers PrintToPythonStdout as a log listener so that logging output
// shows up in Colabs and Jupyter notebooks.
4231 void TFE_Py_EnableInteractivePythonLogging() {
4232 static bool enabled_interactive_logging = false;
4233 if (!enabled_interactive_logging) {
4234 enabled_interactive_logging = true;
4235 TF_RegisterLogListener(PrintToPythonStdout);
4236 }
4237 }
4238
4239 namespace {
// Weak reference to the Python Context object that is currently active.
4241 PyObject* weak_eager_context = nullptr;
4242 } // namespace
4243
4244 PyObject* TFE_Py_SetEagerContext(PyObject* py_context) {
4245 Py_XDECREF(weak_eager_context);
4246 weak_eager_context = PyWeakref_NewRef(py_context, nullptr);
4247 if (weak_eager_context == nullptr) {
4248 return nullptr;
4249 }
4250 Py_RETURN_NONE;
4251 }
4252
4253 PyObject* GetPyEagerContext() {
4254 if (weak_eager_context == nullptr) {
4255 PyErr_SetString(PyExc_RuntimeError, "Python eager context is not set");
4256 return nullptr;
4257 }
4258 PyObject* py_context = PyWeakref_GET_OBJECT(weak_eager_context);
4259 if (py_context == Py_None) {
4260 PyErr_SetString(PyExc_RuntimeError, "Eager context has been destroyed");
4261 return nullptr;
4262 }
4263 Py_INCREF(py_context);
4264 return py_context;
4265 }
4266
4267 namespace {
4268
4269 // Default values for thread_local_data fields.
4270 struct EagerContextThreadLocalDataDefaults {
4271 tensorflow::Safe_PyObjectPtr is_eager;
4272 tensorflow::Safe_PyObjectPtr device_spec;
4273 };
4274
4275 // Maps each py_eager_context object to its thread_local_data.
4276 //
4277 // Note: we need to use the python Context object as the key here (and not
4278 // its handle object), because the handle object isn't created until the
4279 // context is initialized; but thread_local_data is potentially accessed
4280 // before then.
4281 using EagerContextThreadLocalDataMap = absl::flat_hash_map<
4282 PyObject*, std::unique_ptr<tensorflow::EagerContextThreadLocalData>>;
4283 thread_local EagerContextThreadLocalDataMap*
4284 eager_context_thread_local_data_map = nullptr;
4285
4286 // Maps each py_eager_context object to default values.
4287 using EagerContextThreadLocalDataDefaultsMap =
4288 absl::flat_hash_map<PyObject*, EagerContextThreadLocalDataDefaults>;
4289 EagerContextThreadLocalDataDefaultsMap*
4290 eager_context_thread_local_data_defaults = nullptr;
4291
4292 } // namespace
4293
4294 namespace tensorflow {
4295
4296 void MakeEagerContextThreadLocalData(PyObject* py_eager_context,
4297 PyObject* is_eager,
4298 PyObject* device_spec) {
4299 DCheckPyGilState();
4300 if (eager_context_thread_local_data_defaults == nullptr) {
4301 absl::LeakCheckDisabler disabler;
4302 eager_context_thread_local_data_defaults =
4303 new EagerContextThreadLocalDataDefaultsMap();
4304 }
4305 if (eager_context_thread_local_data_defaults->count(py_eager_context) > 0) {
4306 PyErr_SetString(PyExc_AssertionError,
4307 "MakeEagerContextThreadLocalData may not be called "
4308 "twice on the same eager Context object.");
4309 }
4310
4311 auto& defaults =
4312 (*eager_context_thread_local_data_defaults)[py_eager_context];
4313 Py_INCREF(is_eager);
4314 defaults.is_eager.reset(is_eager);
4315 Py_INCREF(device_spec);
4316 defaults.device_spec.reset(device_spec);
4317 }
4318
4319 EagerContextThreadLocalData* GetEagerContextThreadLocalData(
4320 PyObject* py_eager_context) {
4321 if (eager_context_thread_local_data_defaults == nullptr) {
4322 PyErr_SetString(PyExc_AssertionError,
4323 "MakeEagerContextThreadLocalData must be called "
4324 "before GetEagerContextThreadLocalData.");
4325 return nullptr;
4326 }
4327 auto defaults =
4328 eager_context_thread_local_data_defaults->find(py_eager_context);
4329 if (defaults == eager_context_thread_local_data_defaults->end()) {
4330 PyErr_SetString(PyExc_AssertionError,
4331 "MakeEagerContextThreadLocalData must be called "
4332 "before GetEagerContextThreadLocalData.");
4333 return nullptr;
4334 }
4335
4336 if (eager_context_thread_local_data_map == nullptr) {
4337 absl::LeakCheckDisabler disabler;
4338 eager_context_thread_local_data_map = new EagerContextThreadLocalDataMap();
4339 }
4340 auto& thread_local_data =
4341 (*eager_context_thread_local_data_map)[py_eager_context];
4342
4343 if (!thread_local_data) {
4344 thread_local_data.reset(new EagerContextThreadLocalData());
4345
4346 Safe_PyObjectPtr is_eager(PyObject_CallFunctionObjArgs(
4347 defaults->second.is_eager.get(), nullptr));
4348 if (!is_eager) return nullptr;
4349 thread_local_data->is_eager = PyObject_IsTrue(is_eager.get());
4350
4351 #if PY_MAJOR_VERSION >= 3
4352 PyObject* scope_name = PyUnicode_FromString("");
4353 #else
4354 PyObject* scope_name = PyString_FromString("");
4355 #endif
4356 thread_local_data->scope_name.reset(scope_name);
4357
4358 #if PY_MAJOR_VERSION >= 3
4359 PyObject* device_name = PyUnicode_FromString("");
4360 #else
4361 PyObject* device_name = PyString_FromString("");
4362 #endif
4363 thread_local_data->device_name.reset(device_name);
4364
4365 Py_INCREF(defaults->second.device_spec.get());
4366 thread_local_data->device_spec.reset(defaults->second.device_spec.get());
4367
4368 Py_INCREF(Py_None);
4369 thread_local_data->function_call_options.reset(Py_None);
4370
4371 Py_INCREF(Py_None);
4372 thread_local_data->executor.reset(Py_None);
4373
4374 thread_local_data->op_callbacks.reset(PyList_New(0));
4375 }
4376 return thread_local_data.get();
4377 }
4378
4379 void DestroyEagerContextThreadLocalData(PyObject* py_eager_context) {
4380 DCheckPyGilState();
4381 if (eager_context_thread_local_data_defaults) {
4382 eager_context_thread_local_data_defaults->erase(py_eager_context);
4383 }
4384 if (eager_context_thread_local_data_map) {
4385 eager_context_thread_local_data_map->erase(py_eager_context);
4386 }
4387 }
4388
4389 } // namespace tensorflow
4390