1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <cstring>
17 #include <thread>
18
19 #include "tensorflow/python/eager/pywrap_tfe.h"
20
21 #include "absl/strings/str_cat.h"
22 #include "absl/types/variant.h"
23 #include "tensorflow/c/c_api.h"
24 #include "tensorflow/c/c_api_internal.h"
25 #include "tensorflow/c/eager/c_api_internal.h"
26 #include "tensorflow/c/eager/tape.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/gtl/cleanup.h"
29 #include "tensorflow/core/lib/gtl/compactptrset.h"
30 #include "tensorflow/core/lib/gtl/flatmap.h"
31 #include "tensorflow/core/lib/gtl/flatset.h"
32 #include "tensorflow/core/lib/strings/strcat.h"
33 #include "tensorflow/core/lib/strings/stringprintf.h"
34 #include "tensorflow/core/platform/mutex.h"
35 #include "tensorflow/core/platform/protobuf.h"
36 #include "tensorflow/core/platform/types.h"
37 #include "tensorflow/python/eager/pywrap_tensor.h"
38 #include "tensorflow/python/lib/core/safe_ptr.h"
39 #include "tensorflow/python/util/util.h"
40
41 using tensorflow::string;
42 using tensorflow::strings::Printf;
43
44 namespace {
45
46 struct InputInfo {
47 InputInfo(int i, bool is_list) : i(i), is_list(is_list) {}
48
49 int i;
50 bool is_list = false;
51 };
52
53 // Takes in output gradients, returns input gradients.
54 typedef std::function<PyObject*(PyObject*)> PyBackwardFunction;
55
56 using AttrToInputsMap =
57 tensorflow::gtl::FlatMap<string,
58 tensorflow::gtl::InlinedVector<InputInfo, 4>>;
59
60 tensorflow::mutex all_attr_to_input_maps_lock(tensorflow::LINKER_INITIALIZED);
61 tensorflow::gtl::FlatMap<string, AttrToInputsMap*>* GetAllAttrToInputsMaps() {
62 static auto* all_attr_to_input_maps =
63 new tensorflow::gtl::FlatMap<string, AttrToInputsMap*>;
64 return all_attr_to_input_maps;
65 }
66
67 AttrToInputsMap* GetAttrToInputsMap(const tensorflow::OpDef& op_def) {
68 tensorflow::mutex_lock l(all_attr_to_input_maps_lock);
69 auto* all_attr_to_input_maps = GetAllAttrToInputsMaps();
70
71 auto* output =
72 tensorflow::gtl::FindPtrOrNull(*all_attr_to_input_maps, op_def.name());
73 if (output != nullptr) {
74 return output;
75 }
76
77 std::unique_ptr<AttrToInputsMap> m(new AttrToInputsMap);
78
79 // Map each type attr name to the list of inputs that carry that attr.
80 for (int i = 0; i < op_def.input_arg_size(); i++) {
81 if (!op_def.input_arg(i).type_attr().empty()) {
82 auto it = m->find(op_def.input_arg(i).type_attr());
83 if (it == m->end()) {
84 it = m->insert({op_def.input_arg(i).type_attr(), {}}).first;
85 }
86 it->second.emplace_back(i, !op_def.input_arg(i).number_attr().empty());
87 }
88 }
89
90 auto* retval = m.get();
91 (*all_attr_to_input_maps)[op_def.name()] = m.release();
92
93 return retval;
94 }
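// For example, for a hypothetical binary op whose OpDef declares inputs
// "a: T" and "b: T", the cached map would be
//   {"T": [InputInfo(0, /*is_list=*/false), InputInfo(1, /*is_list=*/false)]}
// so the dtype for attr "T" can later be inferred from either input.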
95
96 struct FastPathOpExecInfo {
97 TFE_Context* ctx;
98 const char* device_name;
99 // The op def of the main op being executed.
100 const tensorflow::OpDef* op_def;
101
102 bool run_callbacks;
103 bool run_post_exec_callbacks;
104 bool run_gradient_callback;
105
106 // The op name of the main op being executed.
107 PyObject* name;
108 // The op type name of the main op being executed.
109 PyObject* op_name;
110 PyObject* callbacks;
111
112 // All the args passed into the FastPathOpExecInfo.
113 PyObject* args;
114
115 // A dtype can be inferred from another input that shares the same type
116 // attr, so keep a map from attr name to inputs plus a cache of dtypes.
117 const AttrToInputsMap* attr_to_inputs_map;
118 tensorflow::gtl::FlatMap<string, tensorflow::DataType> cached_dtypes;
119 };
120
121 #define PARSE_VALUE(fn_name, type, check_fn, parse_fn) \
122 bool fn_name(const string& key, PyObject* py_value, TF_Status* status, \
123 type* value) { \
124 if (check_fn(py_value)) { \
125 *value = static_cast<type>(parse_fn(py_value)); \
126 return true; \
127 } else { \
128 TF_SetStatus(status, TF_INVALID_ARGUMENT, \
129 tensorflow::strings::StrCat( \
130 "Expecting " #type " value for attr ", key, ", got ", \
131 py_value->ob_type->tp_name) \
132 .c_str()); \
133 return false; \
134 } \
135 }
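// For reference, PARSE_VALUE(ParseFloatValue, float, PyFloat_Check,
// PyFloat_AsDouble) below expands to roughly:
//   bool ParseFloatValue(const string& key, PyObject* py_value,
//                        TF_Status* status, float* value) {
//     if (PyFloat_Check(py_value)) {
//       *value = static_cast<float>(PyFloat_AsDouble(py_value));
//       return true;
//     }
//     // ... otherwise sets TF_INVALID_ARGUMENT and returns false.
//   }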
136
137 #if PY_MAJOR_VERSION >= 3
138 PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
139 PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLong)
140 #else
141 PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
142 #endif
143 PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)
144 #undef PARSE_VALUE
145
146 #if PY_MAJOR_VERSION < 3
147 bool ParseInt64Value(const string& key, PyObject* py_value, TF_Status* status,
148 int64_t* value) {
149 if (PyInt_Check(py_value)) {
150 *value = static_cast<int64_t>(PyInt_AsLong(py_value));
151 return true;
152 } else if (PyLong_Check(py_value)) {
153 *value = static_cast<int64_t>(PyLong_AsLong(py_value));
154 return true;
155 }
156 TF_SetStatus(
157 status, TF_INVALID_ARGUMENT,
158 tensorflow::strings::StrCat("Expecting int or long value for attr ", key,
159 ", got ", py_value->ob_type->tp_name)
160 .c_str());
161 return false;
162 }
163 #endif
164
165 Py_ssize_t TensorShapeNumDims(PyObject* value) {
166 const auto size = PySequence_Size(value);
167 if (size == -1) {
168 // TensorShape.__len__ raises an error when the rank of the shape is
169 // unknown; that error needs to be cleared here.
170 // TODO(nareshmodi): ensure that this is actually a TensorShape.
171 PyErr_Clear();
172 }
173 return size;
174 }
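// For example, a TensorShape of known rank 3 (e.g. [2, None, 4]) yields 3
// via __len__, while a shape of unknown rank makes __len__ raise, so -1 is
// returned after the error is cleared.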
175
176 bool IsInteger(PyObject* py_value) {
177 #if PY_MAJOR_VERSION >= 3
178 return PyLong_Check(py_value);
179 #else
180 return PyInt_Check(py_value);
181 #endif
182 }
183
184 // This function considers a Dimension._value of None to be valid, and sets the
185 // value to be -1 in that case.
186 bool ParseDimensionValue(const string& key, PyObject* py_value,
187 TF_Status* status, int64_t* value) {
188 if (IsInteger(py_value)) {
189 return ParseInt64Value(key, py_value, status, value);
190 }
191
192 tensorflow::Safe_PyObjectPtr dimension_value(
193 PyObject_GetAttrString(py_value, "_value"));
194 if (dimension_value == nullptr) {
195 TF_SetStatus(
196 status, TF_INVALID_ARGUMENT,
197 tensorflow::strings::StrCat("Expecting a Dimension for attr ", key,
198 ", got ", py_value->ob_type->tp_name)
199 .c_str());
200 return false;
201 }
202
203 if (dimension_value.get() == Py_None) {
204 *value = -1;
205 return true;
206 }
207
208 return ParseInt64Value(key, dimension_value.get(), status, value);
209 }
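// Accepted forms, illustratively:
//   5               -> *value == 5
//   Dimension(8)    -> *value == 8 (read via the _value attribute)
//   Dimension(None) -> *value == -1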
210
211 bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status,
212 tensorflow::StringPiece* value) {
213 if (PyBytes_Check(py_value)) {
214 Py_ssize_t size = 0;
215 char* buf = nullptr;
216 if (PyBytes_AsStringAndSize(py_value, &buf, &size) < 0) return false;
217 *value = tensorflow::StringPiece(buf, size);
218 return true;
219 }
220 #if PY_MAJOR_VERSION >= 3
221 if (PyUnicode_Check(py_value)) {
222 Py_ssize_t size = 0;
223 const char* buf = PyUnicode_AsUTF8AndSize(py_value, &size);
224 if (buf == nullptr) return false;
225 *value = tensorflow::StringPiece(buf, size);
226 return true;
227 }
228 #endif
229 TF_SetStatus(
230 status, TF_INVALID_ARGUMENT,
231 tensorflow::strings::StrCat("Expecting a string value for attr ", key,
232 ", got ", py_value->ob_type->tp_name)
233 .c_str());
234 return false;
235 }
236
237 bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status,
238 unsigned char* value) {
239 *value = PyObject_IsTrue(py_value);
240 return true;
241 }
242
243 // The passed in py_value is expected to be an object of the python type
244 // dtypes.DType or an int.
245 bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status,
246 int* value) {
247 if (IsInteger(py_value)) {
248 return ParseIntValue(key, py_value, status, value);
249 }
250
251 tensorflow::Safe_PyObjectPtr py_type_enum(
252 PyObject_GetAttrString(py_value, "_type_enum"));
253 if (py_type_enum == nullptr) {
254 PyErr_Clear();
255 TF_SetStatus(
256 status, TF_INVALID_ARGUMENT,
257 tensorflow::strings::StrCat("Expecting a DType.dtype for attr ", key,
258 ", got ", py_value->ob_type->tp_name)
259 .c_str());
260 return false;
261 }
262
263 return ParseIntValue(key, py_type_enum.get(), status, value);
264 }
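// For example, passing dtypes.float32 reads its _type_enum attribute, which
// is the DataType enum value DT_FLOAT (1); passing the int 1 directly is
// equivalent.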
265
266 bool SetOpAttrList(
267 TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_list,
268 TF_AttrType type,
269 tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
270 TF_Status* status) {
271 if (!PySequence_Check(py_list)) {
272 TF_SetStatus(
273 status, TF_INVALID_ARGUMENT,
274 tensorflow::strings::StrCat("Expecting sequence value for attr ", key,
275 ", got ", py_list->ob_type->tp_name)
276 .c_str());
277 return false;
278 }
279 const int num_values = PySequence_Size(py_list);
280 if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values;
281
282 #define PARSE_LIST(c_type, parse_fn) \
283 std::unique_ptr<c_type[]> values(new c_type[num_values]); \
284 for (int i = 0; i < num_values; ++i) { \
285 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)); \
286 if (!parse_fn(key, py_value.get(), status, &values[i])) return false; \
287 }
288
289 if (type == TF_ATTR_STRING) {
290 std::unique_ptr<const void*[]> values(new const void*[num_values]);
291 std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
292 for (int i = 0; i < num_values; ++i) {
293 tensorflow::StringPiece value;
294 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
295 if (!ParseStringValue(key, py_value.get(), status, &value)) return false;
296 values[i] = value.data();
297 lengths[i] = value.size();
298 }
299 TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
300 } else if (type == TF_ATTR_INT) {
301 PARSE_LIST(int64_t, ParseInt64Value);
302 TFE_OpSetAttrIntList(op, key, values.get(), num_values);
303 } else if (type == TF_ATTR_FLOAT) {
304 PARSE_LIST(float, ParseFloatValue);
305 TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
306 } else if (type == TF_ATTR_BOOL) {
307 PARSE_LIST(unsigned char, ParseBoolValue);
308 TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
309 } else if (type == TF_ATTR_TYPE) {
310 PARSE_LIST(int, ParseTypeValue);
311 TFE_OpSetAttrTypeList(op, key,
312 reinterpret_cast<const TF_DataType*>(values.get()),
313 num_values);
314 } else if (type == TF_ATTR_SHAPE) {
315 // Make one pass through the input counting the total number of
316 // dims across all the input lists.
317 int total_dims = 0;
318 for (int i = 0; i < num_values; ++i) {
319 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
320 if (py_value.get() != Py_None) {
321 if (!PySequence_Check(py_value.get())) {
322 TF_SetStatus(
323 status, TF_INVALID_ARGUMENT,
324 tensorflow::strings::StrCat(
325 "Expecting None or sequence value for element", i,
326 " of attr ", key, ", got ", py_value->ob_type->tp_name)
327 .c_str());
328 return false;
329 }
330 const auto size = TensorShapeNumDims(py_value.get());
331 if (size >= 0) {
332 total_dims += size;
333 }
334 }
335 }
336 // Allocate a buffer that can fit all of the dims together.
337 std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
338 // Copy the input dims into the buffer and set dims to point to
339 // the start of each list's dims.
340 std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
341 std::unique_ptr<int[]> num_dims(new int[num_values]);
342 int64_t* offset = buffer.get();
343 for (int i = 0; i < num_values; ++i) {
344 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
345 if (py_value.get() == Py_None) {
346 dims[i] = nullptr;
347 num_dims[i] = -1;
348 } else {
349 const auto size = TensorShapeNumDims(py_value.get());
350 if (size == -1) {
351 dims[i] = nullptr;
352 num_dims[i] = -1;
353 continue;
354 }
355 dims[i] = offset;
356 num_dims[i] = size;
357 for (int j = 0; j < size; ++j) {
358 tensorflow::Safe_PyObjectPtr inner_py_value(
359 PySequence_ITEM(py_value.get(), j));
360 if (inner_py_value.get() == Py_None) {
361 *offset = -1;
362 } else if (!ParseDimensionValue(key, inner_py_value.get(), status,
363 offset)) {
364 return false;
365 }
366 ++offset;
367 }
368 }
369 }
370 TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
371 status);
372 if (TF_GetCode(status) != TF_OK) return false;
373 } else if (type == TF_ATTR_FUNC) {
374 std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
375 for (int i = 0; i < num_values; ++i) {
376 tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
377 // Allow:
378 // (1) String function name, OR
379 // (2) A Python object with a .name attribute
380 // (A crude test for being a
381 // tensorflow.python.framework.function._DefinedFunction)
382 // (which is what the various "defun" or "Defun" decorators do).
383 // And in the future also allow an object that can encapsulate
384 // the function name and its attribute values.
385 tensorflow::StringPiece func_name;
386 if (!ParseStringValue(key, py_value.get(), status, &func_name)) {
387 PyObject* name_attr = PyObject_GetAttrString(py_value.get(), "name");
388 if (name_attr == nullptr ||
389 !ParseStringValue(key, name_attr, status, &func_name)) {
390 TF_SetStatus(
391 status, TF_INVALID_ARGUMENT,
392 tensorflow::strings::StrCat(
393 "unable to set function value attribute from a ",
394 py_value.get()->ob_type->tp_name,
395 " object. If you think this is an error, please file an "
396 "issue at "
397 "https://github.com/tensorflow/tensorflow/issues/new")
398 .c_str());
399 return false;
400 }
401 }
402 funcs[i] = TFE_NewOp(ctx, func_name.data(), status);
403 if (TF_GetCode(status) != TF_OK) return false;
404 }
405 TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
406 if (TF_GetCode(status) != TF_OK) return false;
407 } else {
408 TF_SetStatus(status, TF_UNIMPLEMENTED,
409 tensorflow::strings::StrCat("Attr ", key,
410 " has unhandled list type ", type)
411 .c_str());
412 return false;
413 }
414 #undef PARSE_LIST
415 return true;
416 }
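// Illustrative flattening for a shape-list attr: py_list = [None, (2, 3)]
// produces
//   dims     = {nullptr, &buffer[0]}  // buffer holds {2, 3}
//   num_dims = {-1, 2}
// before the single TFE_OpSetAttrShapeList call above.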
417
418 TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
419 TF_Status* status) {
420 TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
421 for (const auto& attr : func.attr()) {
422 if (TF_GetCode(status) != TF_OK) return nullptr;
423 SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
424 if (TF_GetCode(status) != TF_OK) return nullptr;
425 }
426 return func_op;
427 }
428
429 void SetOpAttrListDefault(
430 TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
431 const char* key, TF_AttrType type,
432 tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
433 TF_Status* status) {
434 if (type == TF_ATTR_STRING) {
435 int num_values = attr.default_value().list().s_size();
436 std::unique_ptr<const void*[]> values(new const void*[num_values]);
437 std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
438 (*attr_list_sizes)[key] = num_values;
439 for (int i = 0; i < num_values; i++) {
440 const string& v = attr.default_value().list().s(i);
441 values[i] = v.data();
442 lengths[i] = v.size();
443 }
444 TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
445 } else if (type == TF_ATTR_INT) {
446 int num_values = attr.default_value().list().i_size();
447 std::unique_ptr<int64_t[]> values(new int64_t[num_values]);
448 (*attr_list_sizes)[key] = num_values;
449 for (int i = 0; i < num_values; i++) {
450 values[i] = attr.default_value().list().i(i);
451 }
452 TFE_OpSetAttrIntList(op, key, values.get(), num_values);
453 } else if (type == TF_ATTR_FLOAT) {
454 int num_values = attr.default_value().list().f_size();
455 std::unique_ptr<float[]> values(new float[num_values]);
456 (*attr_list_sizes)[key] = num_values;
457 for (int i = 0; i < num_values; i++) {
458 values[i] = attr.default_value().list().f(i);
459 }
460 TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
461 } else if (type == TF_ATTR_BOOL) {
462 int num_values = attr.default_value().list().b_size();
463 std::unique_ptr<unsigned char[]> values(new unsigned char[num_values]);
464 (*attr_list_sizes)[key] = num_values;
465 for (int i = 0; i < num_values; i++) {
466 values[i] = attr.default_value().list().b(i);
467 }
468 TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
469 } else if (type == TF_ATTR_TYPE) {
470 int num_values = attr.default_value().list().type_size();
471 std::unique_ptr<int[]> values(new int[num_values]);
472 (*attr_list_sizes)[key] = num_values;
473 for (int i = 0; i < num_values; i++) {
474 values[i] = attr.default_value().list().type(i);
475 }
476 TFE_OpSetAttrTypeList(op, key,
477 reinterpret_cast<const TF_DataType*>(values.get()),
478 attr.default_value().list().type_size());
479 } else if (type == TF_ATTR_SHAPE) {
480 int num_values = attr.default_value().list().shape_size();
481 (*attr_list_sizes)[key] = num_values;
482 int total_dims = 0;
483 for (int i = 0; i < num_values; ++i) {
484 if (!attr.default_value().list().shape(i).unknown_rank()) {
485 total_dims += attr.default_value().list().shape(i).dim_size();
486 }
487 }
488 // Allocate a buffer that can fit all of the dims together.
489 std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
490 // Copy the input dims into the buffer and set dims to point to
491 // the start of each list's dims.
492 std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
493 std::unique_ptr<int[]> num_dims(new int[num_values]);
494 int64_t* offset = buffer.get();
495 for (int i = 0; i < num_values; ++i) {
496 const auto& shape = attr.default_value().list().shape(i);
497 if (shape.unknown_rank()) {
498 dims[i] = nullptr;
499 num_dims[i] = -1;
500 } else {
501 for (int j = 0; j < shape.dim_size(); j++) {
502 *offset = shape.dim(j).size();
503 ++offset;
504 }
505 }
506 }
507 TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
508 status);
509 } else if (type == TF_ATTR_FUNC) {
510 int num_values = attr.default_value().list().func_size();
511 (*attr_list_sizes)[key] = num_values;
512 std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
513 for (int i = 0; i < num_values; i++) {
514 funcs[i] = GetFunc(ctx, attr.default_value().list().func(i), status);
515 }
516 TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
517 } else {
518 TF_SetStatus(status, TF_UNIMPLEMENTED,
519 "Lists of tensors are not yet implemented for default valued "
520 "attributes for an operation.");
521 }
522 }
523
524 bool SetOpAttrScalar(
525 TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_value,
526 TF_AttrType type,
527 tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
528 TF_Status* status) {
529 if (type == TF_ATTR_STRING) {
530 tensorflow::StringPiece value;
531 if (!ParseStringValue(key, py_value, status, &value)) return false;
532 TFE_OpSetAttrString(op, key, value.data(), value.size());
533 } else if (type == TF_ATTR_INT) {
534 int64_t value;
535 if (!ParseInt64Value(key, py_value, status, &value)) return false;
536 TFE_OpSetAttrInt(op, key, value);
537 // attr_list_sizes is set for all int attributes, since at this point we
538 // don't yet know whether this attribute will be used to calculate the
539 // size of an output list.
540 if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = value;
541 } else if (type == TF_ATTR_FLOAT) {
542 float value;
543 if (!ParseFloatValue(key, py_value, status, &value)) return false;
544 TFE_OpSetAttrFloat(op, key, value);
545 } else if (type == TF_ATTR_BOOL) {
546 unsigned char value;
547 if (!ParseBoolValue(key, py_value, status, &value)) return false;
548 TFE_OpSetAttrBool(op, key, value);
549 } else if (type == TF_ATTR_TYPE) {
550 int value;
551 if (!ParseTypeValue(key, py_value, status, &value)) return false;
552 TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value));
553 } else if (type == TF_ATTR_SHAPE) {
554 if (py_value == Py_None) {
555 TFE_OpSetAttrShape(op, key, nullptr, -1, status);
556 } else {
557 if (!PySequence_Check(py_value)) {
558 TF_SetStatus(status, TF_INVALID_ARGUMENT,
559 tensorflow::strings::StrCat(
560 "Expecting None or sequence value for attr", key,
561 ", got ", py_value->ob_type->tp_name)
562 .c_str());
563 return false;
564 }
565 const auto num_dims = TensorShapeNumDims(py_value);
566 if (num_dims == -1) {
567 TFE_OpSetAttrShape(op, key, nullptr, -1, status);
568 return true;
569 }
570 std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
571 for (int i = 0; i < num_dims; ++i) {
572 tensorflow::Safe_PyObjectPtr inner_py_value(
573 PySequence_ITEM(py_value, i));
574 if (inner_py_value.get() == Py_None) {
575 dims[i] = -1;
576 } else if (!ParseDimensionValue(key, inner_py_value.get(), status,
577 &dims[i])) {
578 return false;
579 }
580 }
581 TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status);
582 }
583 if (TF_GetCode(status) != TF_OK) return false;
584 } else if (type == TF_ATTR_FUNC) {
585 // Allow:
586 // (1) String function name, OR
587 // (2) A Python object with a .name attribute
588 // (A crude test for being a
589 // tensorflow.python.framework.function._DefinedFunction)
590 // (which is what the various "defun" or "Defun" decorators do).
591 // And in the future also allow an object that can encapsulate
592 // the function name and its attribute values.
593 tensorflow::StringPiece func_name;
594 if (!ParseStringValue(key, py_value, status, &func_name)) {
595 PyObject* name_attr = PyObject_GetAttrString(py_value, "name");
596 if (name_attr == nullptr ||
597 !ParseStringValue(key, name_attr, status, &func_name)) {
598 TF_SetStatus(
599 status, TF_INVALID_ARGUMENT,
600 tensorflow::strings::StrCat(
601 "unable to set function value attribute from a ",
602 py_value->ob_type->tp_name,
603 " object. If you think this is an error, please file an issue "
604 "at https://github.com/tensorflow/tensorflow/issues/new")
605 .c_str());
606 return false;
607 }
608 }
609 TF_SetStatus(status, TF_OK, "");
610 TFE_OpSetAttrFunctionName(op, key, func_name.data(), func_name.size());
611 } else {
612 TF_SetStatus(
613 status, TF_UNIMPLEMENTED,
614 tensorflow::strings::StrCat("Attr ", key, " has unhandled type ", type)
615 .c_str());
616 return false;
617 }
618 return true;
619 }
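// For the shape case above: py_value == None sets an unknown-rank shape
// (num_dims == -1), while an illustrative py_value == (2, None) sets
// dims == {2, -1}, where -1 marks an unknown dimension.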
620
621 void SetOpAttrScalarDefault(
622 TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value,
623 const char* attr_name,
624 tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
625 TF_Status* status) {
626 SetOpAttrValueScalar(ctx, op, default_value, attr_name, status);
627 if (default_value.value_case() == tensorflow::AttrValue::kI) {
628 (*attr_list_sizes)[attr_name] = default_value.i();
629 }
630 }
631
632 // start_index is the index at which the Tuple/List attrs will start getting
633 // processed.
634 void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index,
635 TF_Status* out_status) {
636 if (attrs == Py_None) return;
637 Py_ssize_t len = PyTuple_GET_SIZE(attrs) - start_index;
638 if ((len & 1) != 0) {
639 TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
640 "Expecting attrs tuple to have even length.");
641 return;
642 }
643 // Parse attrs
644 for (Py_ssize_t i = 0; i < len; i += 2) {
645 PyObject* py_key = PyTuple_GET_ITEM(attrs, start_index + i);
646 PyObject* py_value = PyTuple_GET_ITEM(attrs, start_index + i + 1);
647 #if PY_MAJOR_VERSION >= 3
648 const char* key = PyBytes_Check(py_key) ? PyBytes_AsString(py_key)
649 : PyUnicode_AsUTF8(py_key);
650 #else
651 const char* key = PyBytes_AsString(py_key);
652 #endif
653 unsigned char is_list = 0;
654 const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status);
655 if (TF_GetCode(out_status) != TF_OK) return;
656 if (is_list != 0) {
657 if (!SetOpAttrList(ctx, op, key, py_value, type, nullptr, out_status))
658 return;
659 } else {
660 if (!SetOpAttrScalar(ctx, op, key, py_value, type, nullptr, out_status))
661 return;
662 }
663 }
664 }
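// The attrs tuple is a flat sequence of key/value pairs, e.g. the
// hypothetical tuple ("transpose_a", False, "T", 1); start_index lets
// callers skip leading non-attr entries in the same tuple.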
665
666 // This function sets the required op attrs. If an attr has the value
667 // None, it reads the AttrDef to get the default value and sets that
668 // instead. Any failure in this function simply falls back to the slow
669 // path.
670 void SetOpAttrWithDefaults(
671 TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
672 const char* attr_name, PyObject* attr_value,
673 tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
674 TF_Status* status) {
675 unsigned char is_list = 0;
676 const TF_AttrType type = TFE_OpGetAttrType(op, attr_name, &is_list, status);
677 if (TF_GetCode(status) != TF_OK) return;
678 if (attr_value == Py_None) {
679 if (is_list != 0) {
680 SetOpAttrListDefault(ctx, op, attr, attr_name, type, attr_list_sizes,
681 status);
682 } else {
683 SetOpAttrScalarDefault(ctx, op, attr.default_value(), attr_name,
684 attr_list_sizes, status);
685 }
686 } else {
687 if (is_list != 0) {
688 SetOpAttrList(ctx, op, attr_name, attr_value, type, attr_list_sizes,
689 status);
690 } else {
691 SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes,
692 status);
693 }
694 }
695 }
696
697 // Python subclass of Exception that is raised when a Status is not OK.
698 tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
699 PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr;
700
701 // Python subclass of Exception that is raised to signal a fallback.
702 PyObject* fallback_exception_class = nullptr;
703
704 // Python function that returns input gradients given output gradients.
705 PyObject* gradient_function = nullptr;
706
707 PyTypeObject* resource_variable_type = nullptr;
708
709 tensorflow::mutex _uid_mutex(tensorflow::LINKER_INITIALIZED);
710 tensorflow::int64 _uid GUARDED_BY(_uid_mutex) = 0;
711
712 } // namespace
713
714 void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
715 const char* op_name, TFE_InputTensorHandles* inputs,
716 PyObject* attrs, TFE_OutputTensorHandles* outputs,
717 TF_Status* out_status) {
718 TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
719 if (TF_GetCode(out_status) != TF_OK) return;
720 TFE_OpSetDevice(op, device_name, out_status);
721 if (TF_GetCode(out_status) == TF_OK) {
722 for (int i = 0; i < inputs->size() && TF_GetCode(out_status) == TF_OK;
723 ++i) {
724 TFE_OpAddInput(op, inputs->at(i), out_status);
725 }
726 }
727 if (TF_GetCode(out_status) == TF_OK) {
728 SetOpAttrs(ctx, op, attrs, 0, out_status);
729 }
730 Py_BEGIN_ALLOW_THREADS;
731 if (TF_GetCode(out_status) == TF_OK) {
732 int num_outputs = outputs->size();
733 TFE_Execute(op, outputs->data(), &num_outputs, out_status);
734 outputs->resize(num_outputs);
735 }
736 if (TF_GetCode(out_status) != TF_OK) {
737 TF_SetStatus(out_status, TF_GetCode(out_status),
738 tensorflow::strings::StrCat(TF_Message(out_status),
739 " [Op:", op_name, "]")
740 .c_str());
741 }
742 TFE_DeleteOp(op);
743 Py_END_ALLOW_THREADS;
744 }
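// On failure, the op name is appended to the status message, so callers see
// e.g. "... [Op:MatMul]" (illustrative op name).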
745
746 PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
747 tensorflow::mutex_lock l(exception_class_mutex);
748 if (exception_class != nullptr) {
749 Py_DECREF(exception_class);
750 }
751 if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
752 exception_class = nullptr;
753 PyErr_SetString(PyExc_TypeError,
754 "TFE_Py_RegisterExceptionClass: "
755 "Registered class should be subclass of Exception.");
756 return nullptr;
757 }
758
759 Py_INCREF(e);
760 exception_class = e;
761 Py_RETURN_NONE;
762 }
763
764 PyObject* TFE_Py_RegisterResourceVariableType(PyObject* e) {
765 if (!PyType_Check(e)) {
766 PyErr_SetString(
767 PyExc_TypeError,
768 "TFE_Py_RegisterResourceVariableType: Need to register a type.");
769 return nullptr;
770 }
771
772 if (resource_variable_type != nullptr) {
773 Py_DECREF(resource_variable_type);
774 }
775
776 Py_INCREF(e);
777 resource_variable_type = reinterpret_cast<PyTypeObject*>(e);
778 Py_RETURN_NONE;
779 }
780
781 PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) {
782 if (fallback_exception_class != nullptr) {
783 Py_DECREF(fallback_exception_class);
784 }
785 if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
786 fallback_exception_class = nullptr;
787 PyErr_SetString(PyExc_TypeError,
788 "TFE_Py_RegisterFallbackExceptionClass: "
789 "Registered class should be subclass of Exception.");
790 return nullptr;
791 } else {
792 Py_INCREF(e);
793 fallback_exception_class = e;
794 Py_RETURN_NONE;
795 }
796 }
797
798 PyObject* TFE_Py_RegisterGradientFunction(PyObject* e) {
799 if (gradient_function != nullptr) {
800 Py_DECREF(gradient_function);
801 }
802 if (!PyCallable_Check(e)) {
803 gradient_function = nullptr;
804 PyErr_SetString(PyExc_TypeError,
805 "TFE_Py_RegisterBackwardFunctionGetter: "
806 "Registered object should be function.");
807 return nullptr;
808 } else {
809 Py_INCREF(e);
810 gradient_function = e;
811 Py_RETURN_NONE;
812 }
813 }
814
815 void RaiseFallbackException(const char* message) {
816 if (fallback_exception_class != nullptr) {
817 PyErr_SetString(fallback_exception_class, message);
818 return;
819 }
820
821 PyErr_SetString(
822 PyExc_RuntimeError,
823 tensorflow::strings::StrCat(
824 "Fallback exception type not set, attempting to fallback due to ",
825 message)
826 .data());
827 }
828
829 int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
830 if (TF_GetCode(status) == TF_OK) return 0;
831 const char* msg = TF_Message(status);
832 if (exception == nullptr) {
833 tensorflow::mutex_lock l(exception_class_mutex);
834 if (exception_class != nullptr) {
835 tensorflow::Safe_PyObjectPtr val(
836 Py_BuildValue("si", msg, TF_GetCode(status)));
837 if (PyErr_Occurred()) {
838 // NOTE: This hides the actual error (i.e. the reason `status` was not
839 // TF_OK), but there is nothing we can do at this point since we can't
840 // generate a reasonable error from the status.
841 // Consider adding a message explaining this.
842 return -1;
843 }
844 PyErr_SetObject(exception_class, val.get());
845 return -1;
846 } else {
847 exception = PyExc_RuntimeError;
848 }
849 }
850 // Maybe update an already-set exception.
851 PyErr_SetString(exception, msg);
852 return -1;
853 }
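// With a registered exception class, the raised value is the pair built by
// Py_BuildValue("si", ...) above, i.e. (message, error_code); illustratively,
// ("shape mismatch", 3), where 3 is TF_INVALID_ARGUMENT.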
854
855 int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
856 PyObject* exception) {
857 if (status.ok()) return 0;
858 const char* msg = status.error_message().c_str();
859 if (exception == nullptr) {
860 tensorflow::mutex_lock l(exception_class_mutex);
861 if (exception_class != nullptr) {
862 tensorflow::Safe_PyObjectPtr val(Py_BuildValue("si", msg, status.code()));
863 PyErr_SetObject(exception_class, val.get());
864 return -1;
865 } else {
866 exception = PyExc_RuntimeError;
867 }
868 }
869 // Maybe update an already-set exception.
870 PyErr_SetString(exception, msg);
871 return -1;
872 }
873
874 const char* TFE_GetPythonString(PyObject* o) {
875 #if PY_MAJOR_VERSION >= 3
876 if (PyBytes_Check(o)) {
877 return PyBytes_AsString(o);
878 } else {
879 return PyUnicode_AsUTF8(o);
880 }
881 #else
882 return PyBytes_AsString(o);
883 #endif
884 }
885
886 int64_t get_uid() {
887 tensorflow::mutex_lock l(_uid_mutex);
888 return _uid++;
889 }
890
891 PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }
892
893 void TFE_DeleteContextCapsule(PyObject* context) {
894 TFE_Context* ctx =
895 reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr));
896 TFE_DeleteContext(ctx);
897 }
898
899 static tensorflow::int64 MakeInt(PyObject* integer) {
900 #if PY_MAJOR_VERSION >= 3
901 return PyLong_AsLong(integer);
902 #else
903 return PyInt_AsLong(integer);
904 #endif
905 }
906
907 static tensorflow::int64 FastTensorId(PyObject* tensor) {
908 if (EagerTensor_CheckExact(tensor)) {
909 return PyEagerTensor_ID(tensor);
910 }
911 PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
912 if (id_field == nullptr) {
913 return -1;
914 }
915 tensorflow::int64 id = MakeInt(id_field);
916 Py_DECREF(id_field);
917 return id;
918 }
919
920 static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
921 if (EagerTensor_CheckExact(tensor)) {
922 return PyEagerTensor_Dtype(tensor);
923 }
924 PyObject* dtype_field = PyObject_GetAttrString(tensor, "dtype");
925 if (dtype_field == nullptr) {
926 return tensorflow::DT_INVALID;
927 }
928 PyObject* enum_field = PyObject_GetAttrString(dtype_field, "_type_enum");
929 Py_DECREF(dtype_field);
930 if (enum_field == nullptr) {
931 return tensorflow::DT_INVALID;
932 }
933 tensorflow::int64 id = MakeInt(enum_field);
934 Py_DECREF(enum_field);
935 return static_cast<tensorflow::DataType>(id);
936 }
937
938 class PyTapeTensor {
939 public:
940 PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
941 const tensorflow::TensorShape& shape)
942 : id_(id), dtype_(dtype), shape_(shape) {}
943 PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
944 PyObject* shape)
945 : id_(id), dtype_(dtype), shape_(shape) {
946 Py_INCREF(absl::get<1>(shape_));
947 }
948 PyTapeTensor(const PyTapeTensor& other) {
949 id_ = other.id_;
950 dtype_ = other.dtype_;
951 shape_ = other.shape_;
952 if (shape_.index() == 1) {
953 Py_INCREF(absl::get<1>(shape_));
954 }
955 }
956
957 ~PyTapeTensor() {
958 if (shape_.index() == 1) {
959 Py_DECREF(absl::get<1>(shape_));
960 }
961 }
962 PyObject* GetShape() const;
963 PyObject* GetDType() const { return PyLong_FromLong(dtype_); }
964 tensorflow::int64 GetID() const { return id_; }
965
966 private:
967 tensorflow::int64 id_;
968 tensorflow::DataType dtype_;
969 absl::variant<tensorflow::TensorShape, PyObject*> shape_;
970 };
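// shape_ holds either a concrete TensorShape (index 0) or, for tensors whose
// static shape is not fully known, the tensor object itself (index 1, with
// its refcount managed above); the latter is resolved lazily through
// py_vspace->GraphShape() in GetShape() below.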
971
972 class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
973 PyTapeTensor> {
974 public:
975 explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
976 Py_INCREF(py_vspace_);
977 }
978
979 tensorflow::Status Initialize() {
980 num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
981 if (num_elements_ == nullptr) {
982 return tensorflow::errors::InvalidArgument("invalid vspace");
983 }
984 aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
985 if (aggregate_fn_ == nullptr) {
986 return tensorflow::errors::InvalidArgument("invalid vspace");
987 }
988 zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn");
989 if (zeros_fn_ == nullptr) {
990 return tensorflow::errors::InvalidArgument("invalid vspace");
991 }
992 ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn");
993 if (ones_fn_ == nullptr) {
994 return tensorflow::errors::InvalidArgument("invalid vspace");
995 }
996 graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn");
997 if (graph_shape_fn_ == nullptr) {
998 return tensorflow::errors::InvalidArgument("invalid vspace");
999 }
1000 return tensorflow::Status::OK();
1001 }
1002
1003 ~PyVSpace() override {
1004 Py_XDECREF(num_elements_);
1005 Py_XDECREF(aggregate_fn_);
1006 Py_XDECREF(zeros_fn_);
1007 Py_XDECREF(ones_fn_);
1008 Py_XDECREF(graph_shape_fn_);
1009
1010 Py_DECREF(py_vspace_);
1011 }
1012
1013 tensorflow::int64 NumElements(PyObject* tensor) const final {
1014 if (EagerTensor_CheckExact(tensor)) {
1015 return PyEagerTensor_NumElements(tensor);
1016 }
1017 PyObject* arglist =
1018 Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
1019 PyObject* result = PyEval_CallObject(num_elements_, arglist);
1020 Py_DECREF(arglist);
1021 if (result == nullptr) {
1022 // The caller detects whether a python exception has been raised.
1023 return -1;
1024 }
1025 tensorflow::int64 r = MakeInt(result);
1026 Py_DECREF(result);
1027 return r;
1028 }
1029
1030 PyObject* AggregateGradients(
1031 tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
1032 PyObject* list = PyList_New(gradient_tensors.size());
1033 for (int i = 0; i < gradient_tensors.size(); ++i) {
1034 // Note: stealing a reference to the gradient tensors.
1035 CHECK(gradient_tensors[i] != nullptr);
1036 CHECK(gradient_tensors[i] != Py_None);
1037 PyList_SET_ITEM(list, i,
1038 reinterpret_cast<PyObject*>(gradient_tensors[i]));
1039 }
1040 PyObject* arglist = Py_BuildValue("(O)", list);
1041 CHECK(arglist != nullptr);
1042 PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
1043 Py_DECREF(arglist);
1044 Py_DECREF(list);
1045 return result;
1046 }
1047
1048 void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
1049
1050 PyObject* Zeros(const PyTapeTensor& tensor) const final {
1051 if (PyErr_Occurred()) {
1052 return nullptr;
1053 }
1054 PyObject* py_shape = tensor.GetShape();
1055 if (PyErr_Occurred()) {
1056 return nullptr;
1057 }
1058 PyObject* py_dtype = tensor.GetDType();
1059 if (PyErr_Occurred()) {
1060 Py_DECREF(py_shape);
1061 return nullptr;
1062 }
1063 PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
1064 PyObject* result = PyEval_CallObject(zeros_fn_, arg_list);
1065 Py_DECREF(arg_list);
1066 Py_DECREF(py_dtype);
1067 Py_DECREF(py_shape);
1068 return reinterpret_cast<PyObject*>(result);
1069 }
1070
1071 PyObject* Ones(const PyTapeTensor& tensor) const final {
1072 if (PyErr_Occurred()) {
1073 return nullptr;
1074 }
1075 PyObject* py_shape = tensor.GetShape();
1076 PyObject* py_dtype = tensor.GetDType();
1077 PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
1078 PyObject* result = PyEval_CallObject(ones_fn_, arg_list);
1079 Py_DECREF(arg_list);
1080 Py_DECREF(py_dtype);
1081 Py_DECREF(py_shape);
1082 return result;
1083 }
1084
1085 PyObject* GraphShape(PyObject* tensor) const {
1086 PyObject* arg_list = Py_BuildValue("(O)", tensor);
1087 PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list);
1088 Py_DECREF(arg_list);
1089 return result;
1090 }
1091
1092 tensorflow::Status CallBackwardFunction(
1093 PyBackwardFunction* backward_function,
1094 tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
1095 std::vector<PyObject*>* result) const final {
1096 PyObject* grads = PyTuple_New(output_gradients.size());
1097 for (int i = 0; i < output_gradients.size(); ++i) {
1098 if (output_gradients[i] == nullptr) {
1099 Py_INCREF(Py_None);
1100 PyTuple_SET_ITEM(grads, i, Py_None);
1101 } else {
1102 PyTuple_SET_ITEM(grads, i,
1103 reinterpret_cast<PyObject*>(output_gradients[i]));
1104 }
1105 }
1106 PyObject* py_result = (*backward_function)(grads);
1107 Py_DECREF(grads);
1108 if (py_result == nullptr) {
1109 return tensorflow::errors::Internal("gradient function threw exceptions");
1110 }
1111 result->clear();
1112 PyObject* seq =
1113 PySequence_Fast(py_result, "expected a sequence of gradients");
1114 if (seq == nullptr) {
1115 return tensorflow::errors::InvalidArgument(
1116 "gradient function did not return a list");
1117 }
1118 int len = PySequence_Fast_GET_SIZE(seq);
1119 VLOG(1) << "Gradient length is " << len;
1120 result->reserve(len);
1121 for (int i = 0; i < len; ++i) {
1122 PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
1123 if (item == Py_None) {
1124 result->push_back(nullptr);
1125 } else {
1126 Py_INCREF(item);
1127 result->push_back(item);
1128 }
1129 }
1130 Py_DECREF(seq);
1131 Py_DECREF(py_result);
1132 return tensorflow::Status::OK();
1133 }
1134
1135 void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
1136
1137 private:
1138 PyObject* py_vspace_;
1139
1140 PyObject* num_elements_;
1141 PyObject* aggregate_fn_;
1142 PyObject* zeros_fn_;
1143 PyObject* ones_fn_;
1144 PyObject* graph_shape_fn_;
1145 };
1146 PyVSpace* py_vspace = nullptr;
1147
1148 PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
1149 if (py_vspace != nullptr) {
1150 delete py_vspace;
1151 }
1152
1153 py_vspace = new PyVSpace(e);
1154 auto status = py_vspace->Initialize();
1155 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
1156 delete py_vspace;
1157 return nullptr;
1158 }
1159
1160 Py_RETURN_NONE;
1161 }
1162
1163 PyObject* PyTapeTensor::GetShape() const {
1164 if (shape_.index() == 0) {
1165 auto& shape = absl::get<0>(shape_);
1166 PyObject* py_shape = PyTuple_New(shape.dims());
1167 for (int i = 0; i < shape.dims(); ++i) {
1168 PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
1169 }
1170
1171 return py_shape;
1172 }
1173
1174 return py_vspace->GraphShape(absl::get<1>(shape_));
1175 }
1176
1177 class GradientTape
1178 : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
1179 PyTapeTensor> {
1180 public:
1181 explicit GradientTape(bool persistent, bool watch_accessed_variables)
1182 : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
1183 PyTapeTensor>(persistent),
1184 watch_accessed_variables_(watch_accessed_variables) {}
1185
1186 virtual ~GradientTape() {
1187 for (const IdAndVariable& v : watched_variables_) {
1188 Py_DECREF(v.variable);
1189 }
1190 }
1191
1192 void VariableAccessed(PyObject* v) {
1193 if (watch_accessed_variables_) {
1194 WatchVariable(v);
1195 }
1196 }
1197
1198 void WatchVariable(PyObject* v) {
1199 tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
1200 if (handle == nullptr) {
1201 return;
1202 }
1203 tensorflow::int64 id = FastTensorId(handle.get());
1204
1205 if (!PyErr_Occurred()) {
1206 this->Watch(id);
1207 }
1208
1209 tensorflow::mutex_lock l(watched_variables_mu_);
1210 auto insert_result = watched_variables_.emplace(id, v);
1211
1212 if (insert_result.second) {
1213 // Only increment the reference count if we aren't already watching this
1214 // variable.
1215 Py_INCREF(v);
1216 }
1217 }
1218
1219 PyObject* GetVariablesAsPyTuple() {
1220 tensorflow::mutex_lock l(watched_variables_mu_);
1221 PyObject* result = PyTuple_New(watched_variables_.size());
1222 Py_ssize_t pos = 0;
1223 for (const IdAndVariable& id_and_variable : watched_variables_) {
1224 PyTuple_SET_ITEM(result, pos++, id_and_variable.variable);
1225 Py_INCREF(id_and_variable.variable);
1226 }
1227 return result;
1228 }
1229
1230 private:
1231 // We store an IdAndVariable in the set since the set needs to be locked
1232 // during insert, and the insert should not call back into Python (to
1233 // avoid deadlocking with the GIL).
1234 struct IdAndVariable {
1235 tensorflow::int64 id;
1236 PyObject* variable;
1237
1238 IdAndVariable(tensorflow::int64 id, PyObject* variable)
1239 : id(id), variable(variable) {}
1240 };
1241 struct CompareById {
1242 bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) const {
1243 return lhs.id < rhs.id;
1244 }
1245 };
1246
1247 bool watch_accessed_variables_;
1248 tensorflow::mutex watched_variables_mu_;
1249 std::set<IdAndVariable, CompareById> watched_variables_
1250 GUARDED_BY(watched_variables_mu_);
1251 };
1252
1253 typedef struct {
1254 PyObject_HEAD
1255 /* Type-specific fields go here. */
1256 GradientTape* tape;
1257 } TFE_Py_Tape;
1258
1259 static void TFE_Py_Tape_Delete(PyObject* tape) {
1260 delete reinterpret_cast<TFE_Py_Tape*>(tape)->tape;
1261 Py_TYPE(tape)->tp_free(tape);
1262 }
1263
1264 static PyTypeObject TFE_Py_Tape_Type = {
1265 PyVarObject_HEAD_INIT(nullptr, 0) "tfe.Tape", /* tp_name */
1266 sizeof(TFE_Py_Tape), /* tp_basicsize */
1267 0, /* tp_itemsize */
1268 &TFE_Py_Tape_Delete, /* tp_dealloc */
1269 nullptr, /* tp_print */
1270 nullptr, /* tp_getattr */
1271 nullptr, /* tp_setattr */
1272 nullptr, /* tp_reserved */
1273 nullptr, /* tp_repr */
1274 nullptr, /* tp_as_number */
1275 nullptr, /* tp_as_sequence */
1276 nullptr, /* tp_as_mapping */
1277 nullptr, /* tp_hash */
1278 nullptr, /* tp_call */
1279 nullptr, /* tp_str */
1280 nullptr, /* tp_getattro */
1281 nullptr, /* tp_setattro */
1282 nullptr, /* tp_as_buffer */
1283 Py_TPFLAGS_DEFAULT, /* tp_flags */
1284 "TFE_Py_Tape objects", /* tp_doc */
1285 };
1286
1287 // Note: in the current design no mutex is needed here because of the python
1288 // GIL, which is always held when any TFE_Py_* methods are called. We should
1289 // revisit this if/when we decide not to hold the GIL while manipulating the
1290 // tape stack.
1291 tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
1292 thread_local tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* tape_set{
1293 nullptr};
1294 if (tape_set == nullptr) {
1295 tape_set = new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>;
1296 }
1297 return tape_set;
1298 }
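// Because the set is thread_local, each Python thread records onto its own
// tapes: a tape pushed on one thread never observes ops executed on another.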
1299
1300 // A safe copy of the current tapeset. Does not get affected by other python
1301 // threads changing the set of active tapes.
1302 class SafeTapeSet {
1303 public:
1304 SafeTapeSet() : tape_set_(*GetTapeSet()) {
1305 for (auto* tape : tape_set_) {
1306 Py_INCREF(tape);
1307 }
1308 }
1309
1310 ~SafeTapeSet() {
1311 for (auto* tape : tape_set_) {
1312 Py_DECREF(tape);
1313 }
1314 }
1315
1316 tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>::const_iterator begin() {
1317 return tape_set_.begin();
1318 }
1319
1320 tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>::const_iterator end() {
1321 return tape_set_.end();
1322 }
1323
1324 private:
1325 tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*> tape_set_;
1326 };
1327
1328 bool* ThreadTapeIsStopped() {
1329 thread_local bool thread_tape_is_stopped{false};
1330 return &thread_tape_is_stopped;
1331 }
1332
1333 void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }
1334
1335 void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }
1336
1337 PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
1338 PyObject* watch_accessed_variables) {
1339 TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
1340 if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
1341 TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
1342 tape->tape = new GradientTape(persistent == Py_True,
1343 watch_accessed_variables == Py_True);
1344 Py_INCREF(tape);
1345 GetTapeSet()->insert(reinterpret_cast<TFE_Py_Tape*>(tape));
1346 return reinterpret_cast<PyObject*>(tape);
1347 }
1348
1349 void TFE_Py_TapeSetAdd(PyObject* tape) {
1350 Py_INCREF(tape);
1351 if (!GetTapeSet()->insert(reinterpret_cast<TFE_Py_Tape*>(tape)).second) {
1352 // Already exists in the tape set.
1353 Py_DECREF(tape);
1354 }
1355 }
1356
1357 PyObject* TFE_Py_TapeSetIsEmpty() {
1358 if (*ThreadTapeIsStopped() || GetTapeSet()->empty()) {
1359 Py_RETURN_TRUE;
1360 }
1361 Py_RETURN_FALSE;
1362 }
1363
1364 void TFE_Py_TapeSetRemove(PyObject* tape) {
1365 auto* stack = GetTapeSet();
1366 stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
1367 // We kept a reference to the tape in the set to ensure it wouldn't get
1368 // deleted under us; cleaning it up here.
1369 Py_DECREF(tape);
1370 }
1371
1372 static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
1373 if (list == Py_None) {
1374 return {};
1375 }
1376 PyObject* seq = PySequence_Fast(list, "expected a sequence");
1377 if (seq == nullptr) {
1378 return {};
1379 }
1380 int len = PySequence_Size(list);
1381 std::vector<tensorflow::int64> tensor_ids;
1382 tensor_ids.reserve(len);
1383 for (int i = 0; i < len; ++i) {
1384 PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
1385 #if PY_MAJOR_VERSION >= 3
1386 if (PyLong_Check(item)) {
1387 #else
1388 if (PyLong_Check(item) || PyInt_Check(item)) {
1389 #endif
1390 tensorflow::int64 id = MakeInt(item);
1391 tensor_ids.push_back(id);
1392 } else {
1393 tensor_ids.push_back(-1);
1394 }
1395 }
1396 Py_DECREF(seq);
1397 return tensor_ids;
1398 }
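// For example, the Python list [2, None, 3] becomes the vector {2, -1, 3};
// non-integer entries map to -1.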
1399
1400 PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
1401 if (tensors == Py_None) {
1402 Py_RETURN_FALSE;
1403 }
1404 if (*ThreadTapeIsStopped()) {
1405 Py_RETURN_FALSE;
1406 }
1407 auto* tape_set_ptr = GetTapeSet();
1408 if (tape_set_ptr->empty()) {
1409 Py_RETURN_FALSE;
1410 }
1411 PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
1412 if (seq == nullptr) {
1413 return nullptr;
1414 }
1415 int len = PySequence_Fast_GET_SIZE(seq);
1416 // TODO(apassos) consider not building a list and changing the API to check
1417 // each tensor individually.
1418 std::vector<tensorflow::int64> tensor_ids;
1419 std::vector<tensorflow::DataType> dtypes;
1420 tensor_ids.reserve(len);
1421 dtypes.reserve(len);
1422 for (int i = 0; i < len; ++i) {
1423 PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
1424 tensor_ids.push_back(FastTensorId(item));
1425 dtypes.push_back(FastTensorDtype(item));
1426 }
1427 Py_DECREF(seq);
1428 auto tape_set = *tape_set_ptr;
1429 for (TFE_Py_Tape* tape : tape_set) {
1430 if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
1431 Py_RETURN_TRUE;
1432 }
1433 }
1434 Py_RETURN_FALSE;
1435 }
1436
1437 void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
1438 if (*ThreadTapeIsStopped()) {
1439 return;
1440 }
1441 tensorflow::int64 tensor_id = FastTensorId(tensor);
1442 if (PyErr_Occurred()) {
1443 return;
1444 }
1445 reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
1446 }
1447
1448 bool ListContainsNone(PyObject* list) {
1449 if (list == Py_None) return true;
1450 tensorflow::Safe_PyObjectPtr seq(
1451 PySequence_Fast(list, "expected a sequence"));
1452 if (seq == nullptr) {
1453 return false;
1454 }
1455
1456 int len = PySequence_Size(list);
1457 for (int i = 0; i < len; ++i) {
1458 PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
1459 if (item == Py_None) return true;
1460 }
1461
1462 return false;
1463 }
1464
1465 static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
1466 if (EagerTensor_CheckExact(tensor)) {
1467 TFE_TensorHandle* t = EagerTensor_Handle(tensor);
1468 tensorflow::int64 id = PyEagerTensor_ID(tensor);
1469 tensorflow::TensorShape tensor_shape;
1470 const tensorflow::Status status = t->handle->Shape(&tensor_shape);
1471
1472 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
1473 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
1474 tensorflow::TensorShape({}));
1475 } else {
1476 return PyTapeTensor(id, t->handle->dtype, tensor_shape);
1477 }
1478 }
1479 tensorflow::int64 id = FastTensorId(tensor);
1480 if (PyErr_Occurred()) {
1481 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
1482 tensorflow::TensorShape({}));
1483 }
1484 PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
1485 PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
1486 Py_DECREF(dtype_object);
1487 tensorflow::DataType dtype =
1488 static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
1489 Py_DECREF(dtype_enum);
1490 if (PyErr_Occurred()) {
1491 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
1492 tensorflow::TensorShape({}));
1493 }
1494 static char _shape_tuple[] = "_shape_tuple";
1495 PyObject* shape_tuple = PyObject_CallMethod(tensor, _shape_tuple, nullptr);
1496 if (PyErr_Occurred()) {
1497 return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
1498 tensorflow::TensorShape({}));
1499 }
1500
1501 if (ListContainsNone(shape_tuple)) {
1502 return PyTapeTensor(id, dtype, tensor);
1503 }
1504
1505 auto l = MakeIntList(shape_tuple);
1506 Py_DECREF(shape_tuple);
1507 // Replace -1, which represents accidental Nones that can occur in graph
1508 // mode and can cause errors in shape construction, with 0.
1509 for (auto& c : l) {
1510 if (c < 0) {
1511 c = 0;
1512 }
1513 }
1514 tensorflow::TensorShape shape(l);
1515 return PyTapeTensor(id, dtype, shape);
1516 }
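// Two outcomes, illustratively: a tensor with a fully known shape tuple such
// as (2, 3) yields a PyTapeTensor holding a concrete TensorShape, while a
// shape tuple containing None, e.g. (None, 3), keeps the tensor object so
// the shape can be recovered lazily via graph_shape_fn.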
1517
1518 std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
1519 PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
1520 if (seq == nullptr) {
1521 return {};
1522 }
1523 int len = PySequence_Fast_GET_SIZE(seq);
1524 std::vector<tensorflow::int64> list;
1525 list.reserve(len);
1526 for (int i = 0; i < len; ++i) {
1527 PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i);
1528 list.push_back(FastTensorId(tensor));
1529 if (PyErr_Occurred()) {
1530 Py_DECREF(seq);
1531 return list;
1532 }
1533 }
1534 Py_DECREF(seq);
1535 return list;
1536 }
1537
1538 void TFE_Py_TapeVariableAccessed(PyObject* variable) {
1539 if (*ThreadTapeIsStopped()) {
1540 return;
1541 }
1542 for (TFE_Py_Tape* tape : SafeTapeSet()) {
1543 tape->tape->VariableAccessed(variable);
1544 }
1545 }
1546
1547 void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) {
1548 if (*ThreadTapeIsStopped()) {
1549 return;
1550 }
1551 reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable);
1552 }
1553
1554 PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
1555 return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
1556 }
1557
1558 namespace {
1559 std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
1560 PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
1561 if (seq == nullptr) {
1562 return {};
1563 }
1564 int len = PySequence_Fast_GET_SIZE(seq);
1565 std::vector<tensorflow::DataType> list;
1566 list.reserve(len);
1567 for (int i = 0; i < len; ++i) {
1568 PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i);
1569 list.push_back(FastTensorDtype(tensor));
1570 }
1571 Py_DECREF(seq);
1572 return list;
1573 }
1574
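// Records an op on every active tape. `backward_function_getter` lazily builds
// the PyBackwardFunction for a tape that needs it, and
// `backward_function_killer` releases it (together with any Python references
// it captured) once the corresponding tape entry is deleted.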
1575 void TapeSetRecordOperation(
1576 PyObject* op_type, PyObject* output_tensors,
1577 const std::vector<tensorflow::int64>& input_ids,
1578 const std::vector<tensorflow::DataType>& input_dtypes,
1579 const std::function<PyBackwardFunction*()>& backward_function_getter,
1580 const std::function<void(PyBackwardFunction*)>& backward_function_killer) {
1581 std::vector<PyTapeTensor> output_info;
  PyObject* seq = PySequence_Fast(output_tensors,
                                  "expected a sequence of output tensors");
  if (seq == nullptr) return;  // PySequence_Fast already set the error.
  int len = PySequence_Fast_GET_SIZE(seq);
  output_info.reserve(len);
1587 for (int i = 0; i < len; ++i) {
1588 output_info.push_back(
1589 TapeTensorFromTensor(PySequence_Fast_GET_ITEM(seq, i)));
1590 if (PyErr_Occurred() != nullptr) {
1591 Py_DECREF(seq);
1592 return;
1593 }
1594 }
1595 Py_DECREF(seq);
1596 string op_type_str;
1597 if (PyBytes_Check(op_type)) {
1598 op_type_str = PyBytes_AsString(op_type);
1599 } else if (PyUnicode_Check(op_type)) {
1600 #if PY_MAJOR_VERSION >= 3
1601 op_type_str = PyUnicode_AsUTF8(op_type);
1602 #else
1603 PyObject* py_str = PyUnicode_AsUTF8String(op_type);
1604 if (py_str == nullptr) return;
1605 op_type_str = PyBytes_AS_STRING(py_str);
1606 Py_DECREF(py_str);
1607 #endif
1608 } else {
1609 PyErr_SetString(PyExc_RuntimeError, "op_type should be a string.");
1610 return;
1611 }
1612
1613 for (TFE_Py_Tape* tape : SafeTapeSet()) {
1614 tape->tape->RecordOperation(op_type_str, output_info, input_ids,
1615 input_dtypes, backward_function_getter,
1616 backward_function_killer);
1617 }
1618 }
1619 } // namespace
1620
1621 void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
1622 PyObject* input_tensors,
1623 PyObject* backward_function) {
1624 if (GetTapeSet()->empty() || *ThreadTapeIsStopped()) {
1625 return;
1626 }
1627 std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
1628 if (PyErr_Occurred()) return;
1629
1630 std::vector<tensorflow::DataType> input_dtypes =
1631 MakeTensorDtypeList(input_tensors);
1632 if (PyErr_Occurred()) return;
1633
1634 TapeSetRecordOperation(
1635 op_type, output_tensors, input_ids, input_dtypes,
1636 [backward_function]() {
1637 Py_INCREF(backward_function);
1638 PyBackwardFunction* function =
1639 new PyBackwardFunction([backward_function](PyObject* out_grads) {
1640 return PyObject_CallObject(backward_function, out_grads);
1641 });
1642 return function;
1643 },
1644 [backward_function](PyBackwardFunction* py_backward_function) {
1645 Py_DECREF(backward_function);
1646 delete py_backward_function;
1647 });
1648 }
1649
1650 void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
1651 for (TFE_Py_Tape* tape : SafeTapeSet()) {
1652 tape->tape->DeleteTrace(tensor_id);
1653 }
1654 }
1655
1656 std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
1657 PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
1658 if (seq == nullptr) {
1659 return {};
1660 }
1661 int len = PySequence_Fast_GET_SIZE(seq);
1662 std::vector<PyObject*> list;
1663 list.reserve(len);
1664 for (int i = 0; i < len; ++i) {
1665 list.push_back(PySequence_Fast_GET_ITEM(seq, i));
1666 }
1667 Py_DECREF(seq);
1668 return list;
1669 }
1670
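// Computes the gradients of `target` with respect to `sources` using the given
// tape. When `unconnected_gradients` is "zero", a source the target does not
// depend on yields a zeros tensor shaped like that source; otherwise it yields
// None.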
1671 PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
1672 PyObject* sources, PyObject* output_gradients,
1673 PyObject* unconnected_gradients,
1674 TF_Status* status) {
1675 TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
1676 if (!tape_obj->tape->IsPersistent()) {
1677 auto* tape_set = GetTapeSet();
1678 if (tape_set->find(tape_obj) != tape_set->end()) {
1679 PyErr_SetString(PyExc_RuntimeError,
1680 "gradient() cannot be invoked within the "
1681 "GradientTape context (i.e., while operations are being "
1682 "recorded). Either move the call to gradient() to be "
1683 "outside the 'with tf.GradientTape' block, or "
1684 "use a persistent tape: "
1685 "'with tf.GradientTape(persistent=true)'");
1686 return nullptr;
1687 }
1688 }
1689
1690 std::vector<tensorflow::int64> target_vec = MakeTensorIDList(target);
1691 if (PyErr_Occurred()) {
1692 return nullptr;
1693 }
1694 std::vector<tensorflow::int64> sources_vec = MakeTensorIDList(sources);
1695 if (PyErr_Occurred()) {
1696 return nullptr;
1697 }
1698 tensorflow::gtl::FlatSet<tensorflow::int64> sources_set(sources_vec.begin(),
1699 sources_vec.end());
1700
1701 tensorflow::Safe_PyObjectPtr seq =
1702 tensorflow::make_safe(PySequence_Fast(target, "expected a sequence"));
1703 int len = PySequence_Fast_GET_SIZE(seq.get());
1704 tensorflow::gtl::FlatMap<tensorflow::int64, PyTapeTensor>
1705 source_tensors_that_are_targets;
1706 for (int i = 0; i < len; ++i) {
1707 tensorflow::int64 target_id = target_vec[i];
1708 if (sources_set.find(target_id) != sources_set.end()) {
1709 auto tensor = PySequence_Fast_GET_ITEM(seq.get(), i);
1710 source_tensors_that_are_targets.insert(
1711 std::make_pair(target_id, TapeTensorFromTensor(tensor)));
1712 }
1713 if (PyErr_Occurred()) {
1714 return nullptr;
1715 }
1716 }
1717 if (PyErr_Occurred()) {
1718 return nullptr;
1719 }
1720
1721 std::vector<PyObject*> outgrad_vec;
1722 if (output_gradients != Py_None) {
1723 outgrad_vec = MakeTensorList(output_gradients);
1724 if (PyErr_Occurred()) {
1725 return nullptr;
1726 }
1727 for (PyObject* tensor : outgrad_vec) {
1728 // Calling the backward function will eat a reference to the tensors in
1729 // outgrad_vec, so we need to increase their reference count.
1730 Py_INCREF(tensor);
1731 }
1732 }
1733 std::vector<PyObject*> result;
1734 status->status = tape_obj->tape->ComputeGradient(
1735 *py_vspace, target_vec, sources_vec, source_tensors_that_are_targets,
1736 outgrad_vec, &result);
1737 if (!status->status.ok()) {
1738 if (PyErr_Occurred()) {
1739 // Do not propagate the erroneous status as that would swallow the
1740 // exception which caused the problem.
1741 status->status = tensorflow::Status::OK();
1742 }
1743 return nullptr;
1744 }
1745
1746 bool unconnected_gradients_zero =
1747 strcmp(TFE_GetPythonString(unconnected_gradients), "zero") == 0;
1748 std::vector<PyObject*> sources_obj;
1749 if (unconnected_gradients_zero) {
1750 sources_obj = MakeTensorList(sources);
1751 }
1752
1753 if (!result.empty()) {
1754 PyObject* py_result = PyList_New(result.size());
1755 tensorflow::gtl::FlatSet<PyObject*> seen_results(result.size());
1756 for (int i = 0; i < result.size(); ++i) {
1757 if (result[i] == nullptr) {
1758 if (unconnected_gradients_zero) {
1759 // generate a zeros tensor in the shape of sources[i]
1760 tensorflow::DataType dtype = FastTensorDtype(sources_obj[i]);
1761 PyTapeTensor tensor =
1762 PyTapeTensor(sources_vec[i], dtype, sources_obj[i]);
1763 result[i] = py_vspace->Zeros(tensor);
1764 } else {
1765 Py_INCREF(Py_None);
1766 result[i] = Py_None;
1767 }
1768 } else if (seen_results.find(result[i]) != seen_results.end()) {
1769 Py_INCREF(result[i]);
1770 }
1771 seen_results.insert(result[i]);
1772 PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]));
1773 }
1774 return py_result;
1775 }
1776 return PyList_New(0);
1777 }
1778
1779 namespace {
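// The first five items of the args tuple passed to TFE_Py_FastPathExecute_C
// are fixed: (ctx, device_name, op_name, name, callbacks). Op inputs, and then
// any non-inferred attr name/value pairs, follow from this index onwards.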
1780 static const int kFastPathExecuteInputStartIndex = 5;
1781
1782 PyObject* GetPythonObjectFromString(const char* s) {
1783 #if PY_MAJOR_VERSION >= 3
1784 return PyUnicode_FromString(s);
1785 #else
1786 return PyBytes_FromString(s);
1787 #endif
1788 }
1789
1790 PyObject* GetPythonObjectFromInt(int num) {
1791 #if PY_MAJOR_VERSION >= 3
1792 return PyLong_FromLong(num);
1793 #else
1794 return PyInt_FromLong(num);
1795 #endif
1796 }
1797
1798 bool CheckResourceVariable(PyObject* item) {
1799 return PyObject_TypeCheck(item, resource_variable_type);
1800 }
1801
1802 bool IsNumberType(PyObject* item) {
1803 #if PY_MAJOR_VERSION >= 3
1804 return PyFloat_Check(item) || PyLong_Check(item);
1805 #else
1806 return PyFloat_Check(item) || PyInt_Check(item) || PyLong_Check(item);
1807 #endif
1808 }
1809
1810 bool CheckOneInput(PyObject* item) {
1811 if (EagerTensor_CheckExact(item) || CheckResourceVariable(item) ||
1812 PyArray_Check(item) || IsNumberType(item)) {
1813 return true;
1814 }
1815
1816 // Sequences are not properly handled. Sequences with purely python numeric
1817 // types work, but sequences with mixes of EagerTensors and python numeric
1818 // types don't work.
1819 // TODO(nareshmodi): fix
1820 return false;
1821 }
1822
1823 bool CheckInputsOk(PyObject* seq, int start_index,
1824 const tensorflow::OpDef& op_def) {
1825 for (int i = 0; i < op_def.input_arg_size(); i++) {
1826 PyObject* item = PyTuple_GET_ITEM(seq, i + start_index);
1827 if (!op_def.input_arg(i).number_attr().empty() ||
1828 !op_def.input_arg(i).type_list_attr().empty()) {
1829 // This item should be a seq input.
1830 if (!PySequence_Check(item)) {
1831 VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
1832 << "\", Input \"" << op_def.input_arg(i).name()
1833 << "\" since we expected a sequence, but got "
1834 << item->ob_type->tp_name;
1835 return false;
1836 }
1837 for (Py_ssize_t j = 0; j < PySequence_Fast_GET_SIZE(item); j++) {
1838 PyObject* inner_item = PySequence_Fast_GET_ITEM(item, j);
1839 if (!CheckOneInput(inner_item)) {
1840 VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
1841 << "\", Input \"" << op_def.input_arg(i).name()
1842 << "\", Index " << j
1843 << " since we expected an EagerTensor/ResourceVariable, "
1844 "but got "
1845 << inner_item->ob_type->tp_name;
1846 return false;
1847 }
1848 }
1849 } else if (!CheckOneInput(item)) {
1850 VLOG(1)
1851 << "Falling back to slow path for Op \"" << op_def.name()
1852 << "\", Input \"" << op_def.input_arg(i).name()
1853 << "\" since we expected an EagerTensor/ResourceVariable, but got "
1854 << item->ob_type->tp_name;
1855 return false;
1856 }
1857 }
1858
1859 return true;
1860 }
1861
1862 PyObject* MaybeGetDType(PyObject* item) {
1863 if (EagerTensor_CheckExact(item)) {
1864 tensorflow::Safe_PyObjectPtr py_dtype(
1865 PyObject_GetAttrString(item, "dtype"));
1866 return PyObject_GetAttrString(py_dtype.get(), "_type_enum");
1867 }
1868
1869 if (CheckResourceVariable(item)) {
1870 tensorflow::Safe_PyObjectPtr py_dtype(
1871 PyObject_GetAttrString(item, "_dtype"));
1872 return PyObject_GetAttrString(py_dtype.get(), "_type_enum");
1873 }
1874
1875 return nullptr;
1876 }
1877
1878 PyObject* MaybeGetDTypeForAttr(const string& attr,
1879 FastPathOpExecInfo* op_exec_info) {
1880 auto cached_it = op_exec_info->cached_dtypes.find(attr);
1881 if (cached_it != op_exec_info->cached_dtypes.end()) {
1882 return GetPythonObjectFromInt(cached_it->second);
1883 }
1884
1885 auto it = op_exec_info->attr_to_inputs_map->find(attr);
1886 if (it == op_exec_info->attr_to_inputs_map->end()) {
    // The attr has no corresponding inputs - this should never happen.
1888 Py_RETURN_NONE;
1889 }
1890
1891 for (const auto& input_info : it->second) {
1892 PyObject* item = PyTuple_GET_ITEM(
1893 op_exec_info->args, kFastPathExecuteInputStartIndex + input_info.i);
1894 if (input_info.is_list) {
1895 for (int i = 0; i < PySequence_Fast_GET_SIZE(item); i++) {
1896 auto* dtype = MaybeGetDType(PySequence_Fast_GET_ITEM(item, i));
1897 if (dtype != nullptr) return dtype;
1898 }
1899 } else {
1900 auto* dtype = MaybeGetDType(item);
1901 if (dtype != nullptr) return dtype;
1902 }
1903 }
1904
1905 Py_RETURN_NONE;
1906 }
1907
1908 // TODO(agarwal): use an automatic mechanism for handling None arguments to
1909 // gradient functions.
1910
// Looks up `op_name` in a static table of ops whose gradient functions do not
// need all of the op's outputs. Returns true on a hit, pointing *output at a
// pair whose first value indicates whether *all* outputs are unused; if that
// value is false, the second value is the set of output indices that are
// unused.
1914 bool OpGradientDoesntRequireOutputIndices(
1915 const string& op_name,
1916 std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) {
1917 static tensorflow::gtl::FlatMap<
1918 string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m =
1919 new tensorflow::gtl::FlatMap<
1920 string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({
1921 // Ops that don't require any outputs.
1922 {"Identity", {true, {}}},
1923 {"MatMul", {true, {}}},
1924 {"Conv2DBackpropInput", {true, {}}},
1925 {"Conv2DBackpropFilter", {true, {}}},
1926 {"Conv3D", {true, {}}},
1927 {"Conv3DBackpropInputV2", {true, {}}},
1928 {"AvgPool3D", {true, {}}},
1929 {"AvgPool3DGrad", {true, {}}},
1930 {"MaxPool3D", {false, {}}},
1931 {"MaxPool3DGrad", {true, {}}},
1932 {"MaxPool3DGradGrad", {true, {}}},
1933 {"BiasAdd", {true, {}}},
1934 {"BiasAddV1", {true, {}}},
1935 {"BiasAddGrad", {true, {}}},
1936 {"Softplus", {true, {}}},
1937 {"SoftplusGrad", {true, {}}},
1938 {"Softsign", {true, {}}},
1939 {"ReluGrad", {true, {}}},
1940 {"LeakyRelu", {true, {}}},
1941 {"LeakyReluGrad", {true, {}}},
1942 {"Conv2D", {true, {}}},
1943 {"DepthwiseConv2dNative", {true, {}}},
1944 {"Dilation2D", {true, {}}},
1945 {"AvgPool", {true, {}}},
1946 {"AvgPoolGrad", {true, {}}},
1947 {"BatchNormWithGlobalNormalization", {true, {}}},
1948 {"L2Loss", {true, {}}},
1949 {"Sum", {true, {}}},
1950 {"Prod", {true, {}}},
1951 {"SegmentSum", {true, {}}},
1952 {"SegmentMean", {true, {}}},
1953 {"SparseSegmentSum", {true, {}}},
1954 {"SparseSegmentMean", {true, {}}},
1955 {"SparseSegmentSqrtN", {true, {}}},
1956 {"UnsortedSegmentSum", {true, {}}},
1957 {"UnsortedSegmentMax", {true, {}}},
1958 {"Abs", {true, {}}},
1959 {"Neg", {true, {}}},
1960 {"ReciprocalGrad", {true, {}}},
1961 {"Square", {true, {}}},
1962 {"Expm1", {true, {}}},
1963 {"Log", {true, {}}},
1964 {"Log1p", {true, {}}},
1965 {"TanhGrad", {true, {}}},
1966 {"SigmoidGrad", {true, {}}},
1967 {"Sign", {true, {}}},
1968 {"Sin", {true, {}}},
1969 {"Cos", {true, {}}},
1970 {"Tan", {true, {}}},
1971 {"Add", {true, {}}},
1972 {"Sub", {true, {}}},
1973 {"Mul", {true, {}}},
1974 {"Div", {true, {}}},
1975 {"RealDiv", {true, {}}},
1976 {"Maximum", {true, {}}},
1977 {"Minimum", {true, {}}},
1978 {"SquaredDifference", {true, {}}},
1979 {"Select", {true, {}}},
1980 {"SparseMatMul", {true, {}}},
1981 {"BatchMatMul", {true, {}}},
1982 {"Complex", {true, {}}},
1983 {"Real", {true, {}}},
1984 {"Imag", {true, {}}},
1985 {"Angle", {true, {}}},
1986 {"Conj", {true, {}}},
1987 {"Cast", {true, {}}},
1988 {"Cross", {true, {}}},
1989 {"Cumsum", {true, {}}},
1990 {"Cumprod", {true, {}}},
1991 {"ReadVariableOp", {true, {}}},
1992 {"VarHandleOp", {true, {}}},
1993 {"Shape", {true, {}}},
1994 {"StridedSlice", {true, {}}},
1995 {"Fill", {true, {}}},
1996
1997 // Ops that don't require a subset of outputs.
1998 {"FusedBatchNorm", {false, {0, 1, 2}}},
1999 });
2000
2001 auto it = m->find(op_name);
2002
2003 if (it == m->end()) return false;
2004
2005 *output = &it->second;
2006 return true;
2007 }
2008
// Looks up `op_name` in a static table of ops whose gradient functions do not
// need all of the op's inputs. Returns true on a hit, pointing *output at a
// pair whose first value indicates whether *all* inputs are unused; if that
// value is false, the second value is the set of input indices that are
// unused.
2012 bool OpGradientDoesntRequireInputIndices(
2013 const string& op_name,
2014 std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) {
2015 static tensorflow::gtl::FlatMap<
2016 string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m =
2017 new tensorflow::gtl::FlatMap<
2018 string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({
2019 // Ops that don't require any inputs.
2020 {"Identity", {true, {}}},
2021 {"Softmax", {true, {}}},
2022 {"LogSoftmax", {true, {}}},
2023 {"BiasAdd", {true, {}}},
2024 {"Relu", {true, {}}},
2025 {"Relu6", {true, {}}},
2026 {"Elu", {true, {}}},
2027 {"Selu", {true, {}}},
2028 {"SparseSoftmaxCrossEntropyWithLogits", {true, {}}},
2029 {"Neg", {true, {}}},
2030 {"Inv", {true, {}}},
2031 {"Reciprocal", {true, {}}},
2032 {"Sqrt", {true, {}}},
2033 {"Exp", {true, {}}},
2034 {"Tanh", {true, {}}},
2035 {"Sigmoid", {true, {}}},
2036 {"Real", {true, {}}},
2037 {"Imag", {true, {}}},
2038 {"Conj", {true, {}}},
2039 {"ReadVariableOp", {true, {}}},
2040 {"VarHandleOp", {true, {}}},
2041 {"Shape", {true, {}}},
2042 {"Fill", {true, {}}},
2043
2044 // Ops that don't require a subset of inputs.
2045 {"FusedBatchNorm", {false, {2}}},
2046 });
2047
2048 auto it = m->find(op_name);
2049
2050 if (it == m->end()) return false;
2051
2052 *output = &it->second;
2053 return true;
2054 }
2055
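// Copies `seq` into a new tuple, replacing the items at `indices` with Py_None
// (a new reference each time - not a C++ nullptr, despite the name).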
2056 PyObject* CopySequenceSettingIndicesToNull(
2057 PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) {
2058 tensorflow::Safe_PyObjectPtr fast_seq(
2059 PySequence_Fast(seq, "unable to allocate"));
2060 PyObject* result = PyTuple_New(PySequence_Fast_GET_SIZE(fast_seq.get()));
2061 for (int i = 0; i < PySequence_Fast_GET_SIZE(fast_seq.get()); i++) {
2062 PyObject* item;
2063 if (indices.find(i) != indices.end()) {
2064 item = Py_None;
2065 } else {
2066 item = PySequence_Fast_GET_ITEM(fast_seq.get(), i);
2067 }
2068 Py_INCREF(item);
2069 PyTuple_SET_ITEM(result, i, item);
2070 }
2071 return result;
2072 }
2073
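// Records the op on all active tapes, first pruning the inputs/outputs that
// the registered gradient function is known not to need (see the two tables
// above). Returns Py_None on success and nullptr on error.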
2074 PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
2075 PyObject* results, PyObject* name) {
2076 std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
2077 if (PyErr_Occurred()) return nullptr;
2078 std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
2079 if (PyErr_Occurred()) return nullptr;
2080
2081 bool should_record = false;
2082 for (TFE_Py_Tape* tape : SafeTapeSet()) {
2083 if (tape->tape->ShouldRecord(input_ids, input_dtypes)) {
2084 should_record = true;
2085 break;
2086 }
2087 }
2088 if (!should_record) Py_RETURN_NONE;
2089
2090 string c_op_name = TFE_GetPythonString(op_name);
2091
2092 PyObject* op_outputs;
2093 bool op_outputs_tuple_created = false;
2094 std::pair<bool, tensorflow::gtl::FlatSet<int>>* outputs_not_required;
2095
2096 if (OpGradientDoesntRequireOutputIndices(c_op_name, &outputs_not_required)) {
2097 if (outputs_not_required->first) {
2098 op_outputs = Py_None;
2099 } else {
2100 op_outputs_tuple_created = true;
2101 op_outputs = CopySequenceSettingIndicesToNull(
2102 results, outputs_not_required->second);
2103 }
2104 } else {
2105 op_outputs = results;
2106 }
2107
2108 PyObject* op_inputs;
2109 bool op_inputs_tuple_created = false;
2110 std::pair<bool, tensorflow::gtl::FlatSet<int>>* inputs_not_required;
2111
2112 if (OpGradientDoesntRequireInputIndices(c_op_name, &inputs_not_required)) {
2113 if (inputs_not_required->first) {
2114 op_inputs = Py_None;
2115 } else {
2116 op_inputs_tuple_created = true;
2117 op_inputs =
2118 CopySequenceSettingIndicesToNull(inputs, inputs_not_required->second);
2119 }
2120 } else {
2121 op_inputs = inputs;
2122 }
2123
2124 PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));
2125
2126 TapeSetRecordOperation(
2127 op_name, results, input_ids, input_dtypes,
2128 [op_name, attrs, num_inputs, op_inputs, op_outputs]() {
2129 Py_INCREF(op_name);
2130 Py_INCREF(attrs);
2131 Py_INCREF(num_inputs);
2132 Py_INCREF(op_inputs);
2133 Py_INCREF(op_outputs);
2134 PyBackwardFunction* function =
2135 new PyBackwardFunction([op_name, attrs, num_inputs, op_inputs,
2136 op_outputs](PyObject* output_grads) {
2137 if (PyErr_Occurred()) {
2138 return static_cast<PyObject*>(nullptr);
2139 }
2140 tensorflow::Safe_PyObjectPtr callback_args(
2141 Py_BuildValue("OOOOOO", op_name, attrs, num_inputs, op_inputs,
2142 op_outputs, output_grads));
2143
2144 tensorflow::Safe_PyObjectPtr result(
2145 PyObject_CallObject(gradient_function, callback_args.get()));
2146
2147 if (PyErr_Occurred()) return static_cast<PyObject*>(nullptr);
2148
2149 return tensorflow::swig::Flatten(result.get());
2150 });
2151 return function;
2152 },
2153 [op_name, attrs, num_inputs, op_inputs,
2154 op_outputs](PyBackwardFunction* backward_function) {
2155 Py_DECREF(op_name);
2156 Py_DECREF(attrs);
2157 Py_DECREF(num_inputs);
2158 Py_DECREF(op_inputs);
2159 Py_DECREF(op_outputs);
2160
2161 delete backward_function;
2162 });
2163
2164 Py_DECREF(num_inputs);
2165 if (op_outputs_tuple_created) Py_DECREF(op_outputs);
2166 if (op_inputs_tuple_created) Py_DECREF(op_inputs);
2167
2168 Py_RETURN_NONE;
2169 }
2170
2171 void MaybeNotifyVariableAccessed(PyObject* input) {
2172 DCHECK(CheckResourceVariable(input));
2173 DCHECK(PyObject_HasAttrString(input, "_trainable"));
2174
2175 tensorflow::Safe_PyObjectPtr trainable(
2176 PyObject_GetAttrString(input, "_trainable"));
2177 if (trainable.get() == Py_False) return;
2178 TFE_Py_TapeVariableAccessed(input);
2179 }
2180
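// Casts *handle to `desired_dtype` when that dtype is valid (>= 0) and differs
// from the input dtype, then copies the result onto the op's device unless it
// is an int32 tensor, which is left where it is.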
2181 bool CastTensor(const FastPathOpExecInfo& op_exec_info,
2182 const TF_DataType& desired_dtype,
2183 tensorflow::Safe_TFE_TensorHandlePtr* handle,
2184 TF_Status* status) {
2185 TF_DataType input_dtype = TFE_TensorHandleDataType(handle->get());
2186 TF_DataType output_dtype = input_dtype;
2187
2188 if (desired_dtype >= 0 && desired_dtype != input_dtype) {
2189 *handle = tensorflow::make_safe(
2190 tensorflow::EagerCast(op_exec_info.ctx, handle->get(), input_dtype,
2191 static_cast<TF_DataType>(desired_dtype), status));
2192 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
2193 return false;
2194 }
2195 output_dtype = desired_dtype;
2196 }
2197
2198 if (output_dtype != TF_INT32) {
2199 // Note that this is a shallow copy and will share the underlying buffer
2200 // if copying to the same device.
2201 *handle = tensorflow::make_safe(TFE_TensorHandleCopyToDevice(
2202 handle->get(), op_exec_info.ctx, op_exec_info.device_name, status));
2203 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
2204 return false;
2205 }
2206 }
2207 return true;
2208 }
2209
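// Reads the current value of the ResourceVariable `input` by executing a
// ReadVariableOp on the parent op's context and device, recording a gradient
// if a tape is active. On success, *output holds the resulting EagerTensor.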
2210 bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
2211 PyObject* input, tensorflow::Safe_PyObjectPtr* output,
2212 TF_Status* status) {
2213 MaybeNotifyVariableAccessed(input);
2214
2215 TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status);
2216 auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); });
2217 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
2218
2219 // Set dtype
2220 DCHECK(PyObject_HasAttrString(input, "_dtype"));
2221 tensorflow::Safe_PyObjectPtr dtype(PyObject_GetAttrString(input, "_dtype"));
2222 int value;
2223 if (!ParseTypeValue("_dtype", dtype.get(), status, &value)) {
2224 return false;
2225 }
2226 TFE_OpSetAttrType(op, "dtype", static_cast<TF_DataType>(value));
2227
2228 TFE_OpSetDevice(op, parent_op_exec_info.device_name, status);
2229 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
2230
2231 // Get handle
2232 tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(input, "_handle"));
2233 if (!EagerTensor_CheckExact(handle.get())) return false;
2234 TFE_OpAddInput(op, EagerTensor_Handle(handle.get()), status);
2235 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
2236
2237 int num_retvals = 1;
2238 TFE_TensorHandle* output_handle;
2239 TFE_Execute(op, &output_handle, &num_retvals, status);
2240 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
2241
2242 // Always create the py object (and correctly DECREF it) from the returned
2243 // value, else the data will leak.
2244 output->reset(EagerTensorFromHandle(output_handle));
2245
2246 // TODO(nareshmodi): Should we run post exec callbacks here?
2247 if (parent_op_exec_info.run_gradient_callback) {
2248 tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(1));
2249 PyTuple_SET_ITEM(inputs.get(), 0, handle.release());
2250
2251 tensorflow::Safe_PyObjectPtr outputs(PyTuple_New(1));
2252 Py_INCREF(output->get()); // stay alive after since tuple steals.
2253 PyTuple_SET_ITEM(outputs.get(), 0, output->get());
2254
2255 tensorflow::Safe_PyObjectPtr op_string(
2256 GetPythonObjectFromString("ReadVariableOp"));
2257 if (!RecordGradient(op_string.get(), inputs.get(), Py_None, outputs.get(),
2258 Py_None)) {
2259 return false;
2260 }
2261 }
2262
2263 return true;
2264 }
2265
// Supports 3 cases at the moment:
//  i) input is an EagerTensor.
//  ii) input is a ResourceVariable - in this case, it is read into an
//      EagerTensor via ReadVariableOp.
//  iii) input is an arbitrary python list/tuple (note, this handling doesn't
//       support packing).
2272 //
2273 // NOTE: dtype_hint_getter must *always* return a PyObject that can be
2274 // decref'd. So if no hint is found, Py_RETURN_NONE (which correctly
2275 // increfs Py_None).
2276 //
2277 // NOTE: This function sets a python error directly, and returns false.
2278 // TF_Status is only passed since we don't want to have to reallocate it.
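// A minimal illustrative call, mirroring the type_list_attr path in
// TFE_Py_FastPathExecute_C below (the getter supplies no dtype hint and the
// setter ignores the resulting dtype):
//
//   tensorflow::Safe_PyObjectPtr out;
//   if (!ConvertToTensor(op_exec_info, py_input, &out,
//                        []() { Py_RETURN_NONE; },
//                        [](const TF_DataType& dtype) {}, status)) {
//     return nullptr;
//   }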
2279 bool ConvertToTensor(
2280 const FastPathOpExecInfo& op_exec_info, PyObject* input,
2281 tensorflow::Safe_PyObjectPtr* output_handle,
2282 // This gets a hint for this particular input.
2283 const std::function<PyObject*()>& dtype_hint_getter,
2284 // This sets the dtype after conversion is complete.
2285 const std::function<void(const TF_DataType& dtype)>& dtype_setter,
2286 TF_Status* status) {
2287 if (EagerTensor_CheckExact(input)) {
2288 Py_INCREF(input);
2289 output_handle->reset(input);
2290 return true;
2291 } else if (CheckResourceVariable(input)) {
2292 return ReadVariableOp(op_exec_info, input, output_handle, status);
2293 }
2294
2295 // The hint comes from a supposedly similarly typed tensor.
2296 tensorflow::Safe_PyObjectPtr dtype_hint(dtype_hint_getter());
2297 if (PyErr_Occurred()) {
2298 return false;
2299 }
2300
2301 tensorflow::Safe_TFE_TensorHandlePtr handle =
2302 tensorflow::make_safe(static_cast<TFE_TensorHandle*>(
2303 tensorflow::ConvertToEagerTensor(input, dtype_hint.get())));
2304 if (handle == nullptr) {
2305 return MaybeRaiseExceptionFromTFStatus(status, nullptr);
2306 }
2307
2308 int desired_dtype = -1;
2309 if (dtype_hint.get() != Py_None) {
2310 if (!ParseTypeValue("", dtype_hint.get(), status, &desired_dtype)) {
2311 PyErr_SetString(PyExc_TypeError,
2312 tensorflow::strings::StrCat(
2313 "Expecting a DataType value for dtype. Got ",
2314 Py_TYPE(dtype_hint.get())->tp_name)
2315 .c_str());
2316 return false;
2317 }
2318 }
2319
2320 // Maybe cast to the desired type. This is intended to match python
2321 // convert_to_tensor behavior.
2322 TF_DataType output_dtype = TFE_TensorHandleDataType(handle.get());
2323 if (desired_dtype >= 0 && desired_dtype != output_dtype) {
2324 if (tensorflow::IsCompatible(desired_dtype, output_dtype)) {
2325 if (!CastTensor(op_exec_info, static_cast<TF_DataType>(desired_dtype),
2326 &handle, status)) {
2327 return false;
2328 }
2329 output_dtype = TFE_TensorHandleDataType(handle.get());
2330 } else {
2331 tensorflow::Safe_PyObjectPtr input_str(PyObject_Str(input));
2332 PyErr_SetString(
2333 PyExc_TypeError,
2334 tensorflow::strings::StrCat(
2335 "Cannot convert provided value to EagerTensor. Provided value: ",
2336 TFE_GetPythonString(input_str.get()), " Requested dtype: ",
2337 tensorflow::DataTypeString(
2338 static_cast<tensorflow::DataType>(desired_dtype)))
2339 .c_str());
2340 return false;
2341 }
2342 }
2343
2344 output_handle->reset(EagerTensorFromHandle(handle.release()));
2345 dtype_setter(output_dtype);
2346
2347 return true;
2348 }
2349
2350 // Adds input and type attr to the op, and to the list of flattened
2351 // inputs/attrs.
2352 bool AddInputToOp(FastPathOpExecInfo* op_exec_info, PyObject* input,
2353 const bool add_type_attr,
2354 const tensorflow::OpDef::ArgDef& input_arg,
2355 std::vector<tensorflow::Safe_PyObjectPtr>* flattened_attrs,
2356 std::vector<tensorflow::Safe_PyObjectPtr>* flattened_inputs,
2357 TFE_Op* op, TF_Status* status) {
  // py_eager_tensor's ownership is transferred to flattened_inputs when that
  // list is provided; otherwise the reference is dropped (and the object
  // destroyed) when py_eager_tensor goes out of scope at the end of this
  // function.
2361 tensorflow::Safe_PyObjectPtr py_eager_tensor = nullptr;
2362
2363 if (!ConvertToTensor(
2364 *op_exec_info, input, &py_eager_tensor,
2365 [&]() {
2366 if (input_arg.type() != tensorflow::DataType::DT_INVALID) {
2367 return GetPythonObjectFromInt(input_arg.type());
2368 }
2369 return MaybeGetDTypeForAttr(input_arg.type_attr(), op_exec_info);
2370 },
2371 [&](const TF_DataType dtype) {
2372 op_exec_info->cached_dtypes[input_arg.type_attr()] =
2373 static_cast<tensorflow::DataType>(dtype);
2374 },
2375 status)) {
2376 return false;
2377 }
2378
2379 TFE_TensorHandle* input_handle = EagerTensor_Handle(py_eager_tensor.get());
2380
2381 if (add_type_attr && !input_arg.type_attr().empty()) {
2382 auto dtype = TFE_TensorHandleDataType(input_handle);
2383 TFE_OpSetAttrType(op, input_arg.type_attr().data(), dtype);
2384 if (flattened_attrs != nullptr) {
2385 flattened_attrs->emplace_back(
2386 GetPythonObjectFromString(input_arg.type_attr().data()));
2387 flattened_attrs->emplace_back(PyLong_FromLong(dtype));
2388 }
2389 }
2390
2391 if (flattened_inputs != nullptr) {
2392 flattened_inputs->emplace_back(std::move(py_eager_tensor));
2393 }
2394
2395 TFE_OpAddInput(op, input_handle, status);
2396 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
2397 return false;
2398 }
2399
2400 return true;
2401 }
2402
2403 const tensorflow::OpDef* GetOpDef(PyObject* py_op_name) {
2404 const char* op_name = TFE_GetPythonString(py_op_name);
2405 if (op_name == nullptr) {
2406 PyErr_SetString(PyExc_TypeError,
2407 Printf("expected a string for op_name, got %s instead",
2408 py_op_name->ob_type->tp_name)
2409 .c_str());
2410 return nullptr;
2411 }
2412
2413 const tensorflow::OpRegistrationData* op_reg_data = nullptr;
2414 const tensorflow::Status lookup_status =
2415 tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data);
2416 if (MaybeRaiseExceptionFromStatus(lookup_status, nullptr)) {
2417 return nullptr;
2418 }
2419 return &op_reg_data->op_def;
2420 }
2421
2422 const char* GetDeviceName(PyObject* py_device_name) {
2423 if (py_device_name != Py_None) {
2424 return TFE_GetPythonString(py_device_name);
2425 }
2426 return nullptr;
2427 }
2428
2429 bool RaiseIfNotPySequence(PyObject* seq, const string& attr_name) {
2430 if (!PySequence_Check(seq)) {
2431 PyErr_SetString(PyExc_TypeError,
2432 Printf("expected a sequence for attr %s, got %s instead",
2433 attr_name.data(), seq->ob_type->tp_name)
2434 .data());
2435
2436 return false;
2437 }
2438 return true;
2439 }
2440
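// Runs the gradient and/or post-execution callbacks for a fast-path op.
// `flattened_inputs` and `flattened_attrs` hold the values inferred while
// building the op; non-inferred attrs are read back out of `args`.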
2441 bool RunCallbacks(
2442 const FastPathOpExecInfo& op_exec_info, PyObject* args,
2443 const std::vector<tensorflow::Safe_PyObjectPtr>* const flattened_inputs,
2444 const std::vector<tensorflow::Safe_PyObjectPtr>* const flattened_attrs,
2445 PyObject* flattened_result) {
2446 if (!op_exec_info.run_callbacks) return true;
2447
2448 tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(flattened_inputs->size()));
2449 for (int i = 0; i < flattened_inputs->size(); i++) {
2450 PyObject* input = (*flattened_inputs)[i].get();
2451 Py_INCREF(input);
2452 PyTuple_SET_ITEM(inputs.get(), i, input);
2453 }
2454
2455 int num_non_inferred_attrs = PyTuple_GET_SIZE(args) -
2456 op_exec_info.op_def->input_arg_size() -
2457 kFastPathExecuteInputStartIndex;
2458 int num_attrs = flattened_attrs->size() + num_non_inferred_attrs;
2459 tensorflow::Safe_PyObjectPtr attrs(PyTuple_New(num_attrs));
2460
2461 for (int i = 0; i < num_non_inferred_attrs; i++) {
2462 auto* attr =
2463 PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex +
2464 op_exec_info.op_def->input_arg_size() + i);
2465 Py_INCREF(attr);
2466 PyTuple_SET_ITEM(attrs.get(), i, attr);
2467 }
2468 for (int i = num_non_inferred_attrs; i < num_attrs; i++) {
2469 PyObject* attr_or_name =
2470 flattened_attrs->at(i - num_non_inferred_attrs).get();
2471 Py_INCREF(attr_or_name);
2472 PyTuple_SET_ITEM(attrs.get(), i, attr_or_name);
2473 }
2474
2475 if (op_exec_info.run_gradient_callback) {
2476 if (!RecordGradient(op_exec_info.op_name, inputs.get(), attrs.get(),
2477 flattened_result, op_exec_info.name)) {
2478 return false;
2479 }
2480 }
2481
2482 if (op_exec_info.run_post_exec_callbacks) {
2483 tensorflow::Safe_PyObjectPtr callback_args(
2484 Py_BuildValue("OOOOO", op_exec_info.op_name, inputs.get(), attrs.get(),
2485 flattened_result, op_exec_info.name));
2486 for (Py_ssize_t i = 0; i < PyList_Size(op_exec_info.callbacks); i++) {
2487 PyObject* callback_fn = PyList_GET_ITEM(op_exec_info.callbacks, i);
2488 if (!PyCallable_Check(callback_fn)) {
2489 PyErr_SetString(
2490 PyExc_TypeError,
2491 Printf("expected a function for "
2492 "post execution callback in index %ld, got %s instead",
2493 i, callback_fn->ob_type->tp_name)
2494 .c_str());
2495 return false;
2496 }
2497 PyObject* callback_result =
2498 PyObject_CallObject(callback_fn, callback_args.get());
2499 if (!callback_result) {
2500 return false;
2501 }
2502 Py_DECREF(callback_result);
2503 }
2504 }
2505
2506 return true;
2507 }
2508
2509 } // namespace
2510
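// Implements the fast path of eager op execution. `args` is laid out as
// (ctx, device_name, op_name, name, callbacks, <op inputs>...,
//  <non-inferred attr name/value pairs>...); see
// kFastPathExecuteInputStartIndex above.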
2511 PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
2512 Py_ssize_t args_size = PyTuple_GET_SIZE(args);
2513 if (args_size < kFastPathExecuteInputStartIndex) {
2514 PyErr_SetString(
2515 PyExc_ValueError,
2516 Printf("There must be at least %d items in the input tuple.",
2517 kFastPathExecuteInputStartIndex)
2518 .c_str());
2519 return nullptr;
2520 }
2521
2522 FastPathOpExecInfo op_exec_info;
2523
2524 op_exec_info.ctx = reinterpret_cast<TFE_Context*>(
2525 PyCapsule_GetPointer(PyTuple_GET_ITEM(args, 0), nullptr));
2526 op_exec_info.args = args;
2527
2528 if (op_exec_info.ctx == nullptr) {
    // The context hasn't been initialized yet; it will be initialized in the
    // slow path.
2530 RaiseFallbackException(
2531 "This function does not handle the case of the path where "
2532 "all inputs are not already EagerTensors.");
2533 return nullptr;
2534 }
2535
2536 op_exec_info.device_name = GetDeviceName(PyTuple_GET_ITEM(args, 1));
2537 op_exec_info.op_name = PyTuple_GET_ITEM(args, 2);
2538 op_exec_info.op_def = GetOpDef(op_exec_info.op_name);
2539 if (op_exec_info.op_def == nullptr) return nullptr;
2540 op_exec_info.name = PyTuple_GET_ITEM(args, 3);
2541 op_exec_info.callbacks = PyTuple_GET_ITEM(args, 4);
2542
2543 const tensorflow::OpDef* op_def = op_exec_info.op_def;
2544
2545 // TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks
2546 // (similar to benchmark_tf_gradient_function_*). Also consider using an
2547 // InlinedVector for flattened_attrs and flattened_inputs if the benchmarks
2548 // point out problems with heap allocs.
2549 op_exec_info.run_gradient_callback =
2550 !*ThreadTapeIsStopped() && !GetTapeSet()->empty();
2551 op_exec_info.run_post_exec_callbacks =
2552 op_exec_info.callbacks != Py_None &&
2553 PyList_Size(op_exec_info.callbacks) > 0;
2554 op_exec_info.run_callbacks = op_exec_info.run_gradient_callback ||
2555 op_exec_info.run_post_exec_callbacks;
2556
2557 if (args_size < kFastPathExecuteInputStartIndex + op_def->input_arg_size()) {
2558 PyErr_SetString(
2559 PyExc_ValueError,
2560 Printf("Tuple size smaller than intended. Expected to be at least %d, "
2561 "was %ld",
2562 kFastPathExecuteInputStartIndex + op_def->input_arg_size(),
2563 args_size)
2564 .c_str());
2565 return nullptr;
2566 }
2567
2568 if (!CheckInputsOk(args, kFastPathExecuteInputStartIndex, *op_def)) {
2569 RaiseFallbackException(
2570 "This function does not handle the case of the path where "
2571 "all inputs are not already EagerTensors.");
2572 return nullptr;
2573 }
2574
2575 op_exec_info.attr_to_inputs_map = GetAttrToInputsMap(*op_def);
2576
2577 TF_Status* status = TF_NewStatus();
2578 TFE_Op* op = TFE_NewOp(op_exec_info.ctx, op_def->name().c_str(), status);
2579 auto cleaner = tensorflow::gtl::MakeCleanup([status, op] {
2580 TF_DeleteStatus(status);
2581 TFE_DeleteOp(op);
2582 });
2583 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
2584 return nullptr;
2585 }
2586
2587 // Mapping of attr name to size - used to calculate the number of values
2588 // to be expected by the TFE_Execute run.
2589 tensorflow::gtl::FlatMap<string, tensorflow::int64> attr_list_sizes;
2590
2591 // Set non-inferred attrs, including setting defaults if the attr is passed in
2592 // as None.
2593 for (int i = kFastPathExecuteInputStartIndex + op_def->input_arg_size();
2594 i < args_size; i += 2) {
2595 PyObject* py_attr_name = PyTuple_GET_ITEM(args, i);
2596 const tensorflow::StringPiece attr_name(TFE_GetPythonString(py_attr_name));
2597 PyObject* py_attr_value = PyTuple_GET_ITEM(args, i + 1);
2598
2599 // Not creating an index since most of the time there are not more than a
2600 // few attrs.
2601 // TODO(nareshmodi): Maybe include the index as part of the
2602 // OpRegistrationData.
2603 for (const auto& attr : op_def->attr()) {
2604 if (attr_name == attr.name()) {
2605 SetOpAttrWithDefaults(op_exec_info.ctx, op, attr, attr_name.data(),
2606 py_attr_value, &attr_list_sizes, status);
2607
2608 if (TF_GetCode(status) != TF_OK) {
2609 VLOG(1) << "Falling back to slow path for Op \"" << op_def->name()
2610 << "\" since we are unable to set the value for attr \""
2611 << attr.name() << "\" due to: " << TF_Message(status);
2612 RaiseFallbackException(TF_Message(status));
2613 return nullptr;
2614 }
2615
2616 break;
2617 }
2618 }
2619 }
2620
2621 TFE_OpSetDevice(op, op_exec_info.device_name, status);
2622 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
2623 return nullptr;
2624 }
2625
2626 // Flat attrs and inputs as required by the record_gradient call. The attrs
2627 // here only contain inferred attrs (non-inferred attrs are added directly
2628 // from the input args).
2629 // All items in flattened_attrs and flattened_inputs contain
2630 // Safe_PyObjectPtr - any time something steals a reference to this, it must
2631 // INCREF.
2632 // TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work
2633 // directly.
2634 std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_attrs =
2635 nullptr;
2636 std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_inputs =
2637 nullptr;
2638
2639 // TODO(nareshmodi): Encapsulate callbacks information into a struct.
2640 if (op_exec_info.run_callbacks) {
2641 flattened_attrs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
2642 flattened_inputs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
2643 }
2644
2645 // Add inferred attrs and inputs.
2646 // The following code might set duplicate type attrs. This will result in
2647 // the CacheKey for the generated AttrBuilder possibly differing from
2648 // those where the type attrs are correctly set. Inconsistent CacheKeys
2649 // for ops means that there might be unnecessarily duplicated kernels.
2650 // TODO(nareshmodi): Fix this.
2651 for (int i = 0; i < op_def->input_arg_size(); i++) {
2652 const auto& input_arg = op_def->input_arg(i);
2653
2654 PyObject* input =
2655 PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex + i);
2656 if (!input_arg.number_attr().empty()) {
2657 // The item is a homogeneous list.
2658 if (!RaiseIfNotPySequence(input, input_arg.number_attr())) return nullptr;
2659 tensorflow::Safe_PyObjectPtr fast_input(
2660 PySequence_Fast(input, "Could not parse sequence."));
2661 if (fast_input.get() == nullptr) {
2662 return nullptr;
2663 }
2664 Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
2665
2666 TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len);
2667 if (op_exec_info.run_callbacks) {
2668 flattened_attrs->emplace_back(
2669 GetPythonObjectFromString(input_arg.number_attr().data()));
2670 flattened_attrs->emplace_back(PyLong_FromLong(len));
2671 }
2672 attr_list_sizes[input_arg.number_attr()] = len;
2673
2674 if (len > 0) {
2675 // First item adds the type attr.
2676 if (!AddInputToOp(&op_exec_info,
2677 PySequence_Fast_GET_ITEM(fast_input.get(), 0), true,
2678 input_arg, flattened_attrs.get(),
2679 flattened_inputs.get(), op, status)) {
2680 return nullptr;
2681 }
2682
2683 for (Py_ssize_t j = 1; j < len; j++) {
2684 // Since the list is homogeneous, we don't need to re-add the attr.
2685 if (!AddInputToOp(&op_exec_info,
2686 PySequence_Fast_GET_ITEM(fast_input.get(), j),
2687 false, input_arg, nullptr /* flattened_attrs */,
2688 flattened_inputs.get(), op, status)) {
2689 return nullptr;
2690 }
2691 }
2692 }
2693 } else if (!input_arg.type_list_attr().empty()) {
2694 // The item is a heterogeneous list.
2695 if (!RaiseIfNotPySequence(input, input_arg.type_list_attr())) {
2696 return nullptr;
2697 }
2698 const string& attr_name = input_arg.type_list_attr();
2699 Py_ssize_t len = PySequence_Fast_GET_SIZE(input);
2700 tensorflow::gtl::InlinedVector<TF_DataType, 4> attr_value(len);
2701 PyObject* py_attr_value = nullptr;
2702 if (op_exec_info.run_callbacks) {
2703 py_attr_value = PyTuple_New(len);
2704 }
2705 for (Py_ssize_t j = 0; j < len; j++) {
2706 PyObject* py_input = PySequence_Fast_GET_ITEM(input, j);
2707 tensorflow::Safe_PyObjectPtr py_eager_tensor;
2708 if (!ConvertToTensor(
2709 op_exec_info, py_input, &py_eager_tensor,
2710 []() { Py_RETURN_NONE; }, [](const TF_DataType& dtype) {},
2711 status)) {
2712 return nullptr;
2713 }
2714
2715 TFE_TensorHandle* input_handle =
2716 EagerTensor_Handle(py_eager_tensor.get());
2717
2718 attr_value[j] = TFE_TensorHandleDataType(input_handle);
2719
2720 TFE_OpAddInput(op, input_handle, status);
2721 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
2722 return nullptr;
2723 }
2724
2725 if (op_exec_info.run_callbacks) {
2726 flattened_inputs->emplace_back(std::move(py_eager_tensor));
2727
2728 PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j]));
2729 }
2730 }
2731 if (op_exec_info.run_callbacks) {
2732 flattened_attrs->emplace_back(
2733 GetPythonObjectFromString(attr_name.data()));
2734 flattened_attrs->emplace_back(py_attr_value);
2735 }
2736 TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(),
2737 attr_value.size());
2738 attr_list_sizes[attr_name] = len;
2739 } else {
2740 // The item is a single item.
2741 if (!AddInputToOp(&op_exec_info, input, true, input_arg,
2742 flattened_attrs.get(), flattened_inputs.get(), op,
2743 status)) {
2744 return nullptr;
2745 }
2746 }
2747 }
2748
2749 int num_retvals = 0;
2750 for (int i = 0; i < op_def->output_arg_size(); i++) {
2751 const auto& output_arg = op_def->output_arg(i);
2752 int delta = 1;
2753 if (!output_arg.number_attr().empty()) {
2754 delta = attr_list_sizes[output_arg.number_attr()];
2755 } else if (!output_arg.type_list_attr().empty()) {
2756 delta = attr_list_sizes[output_arg.type_list_attr()];
2757 }
2758 if (delta < 0) {
2759 RaiseFallbackException(
2760 "Attributes suggest that the size of an output list is less than 0");
2761 return nullptr;
2762 }
2763 num_retvals += delta;
2764 }
2765
2766 tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);
2767
2768 Py_BEGIN_ALLOW_THREADS;
2769 TFE_Execute(op, retvals.data(), &num_retvals, status);
2770 Py_END_ALLOW_THREADS;
2771
2772 if (TF_GetCode(status) != TF_OK) {
2773 // Augment the status with the op_name for easier debugging similar to
2774 // TFE_Py_Execute.
2775 TF_SetStatus(status, TF_GetCode(status),
2776 tensorflow::strings::StrCat(
2777 TF_Message(status),
2778 " [Op:", TFE_GetPythonString(op_exec_info.op_name), "]")
2779 .c_str());
2780
2781 MaybeRaiseExceptionFromTFStatus(status, nullptr);
2782 return nullptr;
2783 }
2784
2785 tensorflow::Safe_PyObjectPtr flat_result(PyList_New(num_retvals));
2786 for (int i = 0; i < num_retvals; ++i) {
2787 PyList_SET_ITEM(flat_result.get(), i, EagerTensorFromHandle(retvals[i]));
2788 }
2789
2790 if (!RunCallbacks(op_exec_info, args, flattened_inputs.get(),
2791 flattened_attrs.get(), flat_result.get())) {
2792 return nullptr;
2793 }
2794
2795 // Unflatten results.
2796 if (op_def->output_arg_size() == 0) {
2797 Py_RETURN_NONE;
2798 }
2799
2800 if (op_def->output_arg_size() == 1) {
2801 if (!op_def->output_arg(0).number_attr().empty() ||
2802 !op_def->output_arg(0).type_list_attr().empty()) {
2803 return flat_result.release();
2804 } else {
2805 auto* result = PyList_GET_ITEM(flat_result.get(), 0);
2806 Py_INCREF(result);
2807 return result;
2808 }
2809 }
2810
  // Unflatten the results into one entry per output_arg; the Python layer
  // packs this list into a namedtuple.
2812 PyObject* result = PyList_New(op_def->output_arg_size());
2813 int flat_result_index = 0;
2814 for (int i = 0; i < op_def->output_arg_size(); i++) {
2815 if (!op_def->output_arg(i).number_attr().empty()) {
2816 int list_length = attr_list_sizes[op_def->output_arg(i).number_attr()];
2817 PyObject* inner_list = PyList_New(list_length);
2818 for (int j = 0; j < list_length; j++) {
2819 PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
2820 Py_INCREF(obj);
2821 PyList_SET_ITEM(inner_list, j, obj);
2822 }
2823 PyList_SET_ITEM(result, i, inner_list);
2824 } else if (!op_def->output_arg(i).type_list_attr().empty()) {
2825 int list_length = attr_list_sizes[op_def->output_arg(i).type_list_attr()];
2826 PyObject* inner_list = PyList_New(list_length);
2827 for (int j = 0; j < list_length; j++) {
2828 PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
2829 Py_INCREF(obj);
2830 PyList_SET_ITEM(inner_list, j, obj);
2831 }
2832 PyList_SET_ITEM(result, i, inner_list);
2833 } else {
2834 PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
2835 Py_INCREF(obj);
2836 PyList_SET_ITEM(result, i, obj);
2837 }
2838 }
2839 return result;
2840 }
2841
2842 PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
2843 PyObject* attrs, PyObject* results,
2844 PyObject* name) {
2845 if (*ThreadTapeIsStopped() || GetTapeSet()->empty()) {
2846 Py_RETURN_NONE;
2847 }
2848
2849 return RecordGradient(op_name, inputs, attrs, results, name);
2850 }
2851
2852 namespace {
2853 const char kTensor[] = "T";
2854 const char kIndexedSlices[] = "I";
2855 const char kList[] = "L";
2856 const char kListEnd[] = "l";
2857 const char kTuple[] = "U";
2858 const char kTupleEnd[] = "u";
2859 const char kDict[] = "D";
2860 const char kRaw[] = "R";
2861 const char kShape[] = "s";
2862 const char kShapeDelim[] = "-";
2863 const char kDType[] = "d";
2864 const char kNone[] = "n";
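// As an illustration of the encoding, a float32 EagerTensor of shape (2, 3)
// encodes as "Td1s2-3-" (DT_FLOAT == 1), and a list containing only that
// tensor encodes as "LTd1s2-3-l".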
2865
2866 struct EncodeResult {
2867 string str;
2868 std::vector<PyObject*> objects;
2869
2870 PyObject* ToPyTuple() {
2871 PyObject* result = PyTuple_New(2);
2872
2873 PyTuple_SET_ITEM(result, 0, GetPythonObjectFromString(str.c_str()));
2874
2875 if (objects.empty()) {
2876 Py_INCREF(Py_None);
2877 PyTuple_SET_ITEM(result, 1, Py_None);
2878 } else {
2879 PyObject* objects_tuple = PyTuple_New(objects.size());
2880
2881 for (int i = 0; i < objects.size(); i++) {
2882 PyTuple_SET_ITEM(objects_tuple, i, objects[i]);
2883 }
2884
2885 PyTuple_SET_ITEM(result, 1, objects_tuple);
2886 }
2887
2888 return result;
2889 }
2890 };
2891
2892 tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg,
2893 bool include_tensor_ranks_only,
2894 EncodeResult* result) {
2895 if (EagerTensor_CheckExact(arg)) {
2896 TFE_TensorHandle* t = EagerTensor_Handle(arg);
2897 tensorflow::TensorShape tensor_shape;
2898 TF_RETURN_IF_ERROR(t->handle->Shape(&tensor_shape));
2899
2900 absl::StrAppend(&result->str, kDType, t->handle->dtype);
2901
2902 absl::StrAppend(&result->str, kShape);
2903 if (include_tensor_ranks_only) {
2904 absl::StrAppend(&result->str, tensor_shape.dim_sizes().size());
2905 } else {
2906 for (tensorflow::int64 dim_size : tensor_shape.dim_sizes()) {
2907 absl::StrAppend(&result->str, dim_size, kShapeDelim);
2908 }
2909 }
2910 return tensorflow::Status::OK();
2911 }
2912
2913 tensorflow::Safe_PyObjectPtr dtype_object(
2914 PyObject_GetAttrString(arg, "dtype"));
2915
2916 if (dtype_object == nullptr) {
2917 return tensorflow::errors::InvalidArgument(
2918 "ops.Tensor object doesn't have dtype() attr.");
2919 }
2920
2921 tensorflow::Safe_PyObjectPtr dtype_enum(
2922 PyObject_GetAttrString(dtype_object.get(), "_type_enum"));
2923
2924 if (dtype_enum == nullptr) {
2925 return tensorflow::errors::InvalidArgument(
2926 "ops.Tensor's dtype object doesn't have _type_enum() attr.");
2927 }
2928
2929 tensorflow::DataType dtype =
2930 static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get()));
2931
2932 absl::StrAppend(&result->str, kDType, dtype);
2933
2934 static char _shape_tuple[] = "_shape_tuple";
2935 tensorflow::Safe_PyObjectPtr shape_tuple(
2936 PyObject_CallMethod(arg, _shape_tuple, nullptr));
2937
2938 if (shape_tuple == nullptr) {
2939 return tensorflow::errors::InvalidArgument(
2940 "ops.Tensor object doesn't have _shape_tuple() method.");
2941 }
2942
2943 if (shape_tuple.get() == Py_None) {
2944 // Unknown shape, encode that directly.
2945 absl::StrAppend(&result->str, kNone);
2946 return tensorflow::Status::OK();
2947 }
2948
2949 absl::StrAppend(&result->str, kShape);
2950 tensorflow::Safe_PyObjectPtr shape_seq(PySequence_Fast(
2951 shape_tuple.get(), "shape_tuple didn't return a sequence"));
2952
2953 int len = PySequence_Fast_GET_SIZE(shape_seq.get());
2954
2955 if (include_tensor_ranks_only) {
2956 absl::StrAppend(&result->str, len);
2957 } else {
2958 for (int i = 0; i < len; ++i) {
2959 PyObject* item = PySequence_Fast_GET_ITEM(shape_seq.get(), i);
2960 if (item == Py_None) {
2961 absl::StrAppend(&result->str, kNone);
2962 } else {
2963 absl::StrAppend(&result->str, MakeInt(item));
2964 }
2965 }
2966 }
2967 return tensorflow::Status::OK();
2968 }
2969
2970 tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
2971 bool include_tensor_ranks_only,
2972 EncodeResult* result);
2973
// Encodes a Python list/tuple, bracketing the encoded elements between the
// `type` and `end_type` markers.
2975 tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type,
2976 const char* end_type,
2977 bool include_tensor_ranks_only,
2978 EncodeResult* result) {
2979 tensorflow::Safe_PyObjectPtr arg_seq(
2980 PySequence_Fast(arg, "unable to create seq from list/tuple"));
2981
2982 absl::StrAppend(&result->str, type);
2983 int len = PySequence_Fast_GET_SIZE(arg_seq.get());
2984 for (int i = 0; i < len; ++i) {
2985 PyObject* item = PySequence_Fast_GET_ITEM(arg_seq.get(), i);
2986 if (item == Py_None) {
2987 absl::StrAppend(&result->str, kNone);
2988 } else {
2989 TF_RETURN_IF_ERROR(
2990 TFE_Py_EncodeArgHelper(item, include_tensor_ranks_only, result));
2991 }
2992 }
2993 absl::StrAppend(&result->str, end_type);
2994
2995 return tensorflow::Status::OK();
2996 }
2997
2998 tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
2999 bool include_tensor_ranks_only,
3000 EncodeResult* result) {
3001 if (tensorflow::swig::IsTensor(arg)) {
3002 absl::StrAppend(&result->str, kTensor);
3003 TF_RETURN_IF_ERROR(
3004 TFE_Py_EncodeTensor(arg, include_tensor_ranks_only, result));
3005 } else if (tensorflow::swig::IsIndexedSlices(arg)) {
3006 absl::StrAppend(&result->str, kIndexedSlices);
3007 tensorflow::Safe_PyObjectPtr values(PyObject_GetAttrString(arg, "values"));
3008 if (values == nullptr) {
3009 PyErr_Clear();
3010 return tensorflow::errors::InvalidArgument(
3011 "IndexedSlices does not have a values attr");
3012 }
3013 TF_RETURN_IF_ERROR(
3014 TFE_Py_EncodeTensor(values.get(), include_tensor_ranks_only, result));
3015
3016 tensorflow::Safe_PyObjectPtr indices(
3017 PyObject_GetAttrString(arg, "indices"));
3018 if (indices == nullptr) {
3019 PyErr_Clear();
3020 return tensorflow::errors::InvalidArgument(
3021 "IndexedSlices does not have a indices attr");
3022 }
3023 TF_RETURN_IF_ERROR(
3024 TFE_Py_EncodeTensor(indices.get(), include_tensor_ranks_only, result));
3025
3026 tensorflow::Safe_PyObjectPtr dense_shape(
3027 PyObject_GetAttrString(arg, "dense_shape"));
3028 if (dense_shape == nullptr) {
3029 PyErr_Clear();
3030 return tensorflow::errors::InvalidArgument(
3031 "IndexedSlices does not have a dense_shape attr");
3032 }
3033 if (dense_shape.get() != Py_None) {
3034 TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(
3035 dense_shape.get(), include_tensor_ranks_only, result));
3036 }
3037 } else if (PyList_Check(arg)) {
3038 TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(
3039 arg, kList, kListEnd, include_tensor_ranks_only, result));
3040 } else if (PyTuple_Check(arg)) {
3041 TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(
3042 arg, kTuple, kTupleEnd, include_tensor_ranks_only, result));
3043 } else if (PyDict_Check(arg)) {
3044 tensorflow::Safe_PyObjectPtr keys(PyDict_Keys(arg));
3045 if (PyList_Sort(keys.get()) == -1) {
3046 return tensorflow::errors::Internal("Unable to sort keys");
3047 }
3048
3049 absl::StrAppend(&result->str, kDict);
3050 int len = PyList_Size(keys.get());
3051
3052 for (int i = 0; i < len; i++) {
3053 PyObject* key = PyList_GetItem(keys.get(), i);
3054 TF_RETURN_IF_ERROR(
3055 TFE_Py_EncodeArgHelper(key, include_tensor_ranks_only, result));
3056 PyObject* value = PyDict_GetItem(arg, key);
3057 TF_RETURN_IF_ERROR(
3058 TFE_Py_EncodeArgHelper(value, include_tensor_ranks_only, result));
3059 }
3060 } else {
3061 PyObject* object = PyWeakref_NewRef(arg, nullptr);
3062
3063 if (object == nullptr) {
3064 PyErr_Clear();
3065
3066 object = arg;
3067 Py_INCREF(object);
3068 }
3069
3070 absl::StrAppend(&result->str, kRaw);
3071 result->objects.push_back(object);
3072 }
3073
3074 return tensorflow::Status::OK();
3075 }
3076
3077 } // namespace
3078
3079 // `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes
3080 // are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes
3081 // are used for both performance reasons, as much TensorFlow code specializes
3082 // on known shapes to produce slimmer graphs, and correctness, as some
3083 // high-level APIs require shapes to be fully-known.
3084 //
3085 // `include_tensor_ranks_only` allows caching on arguments excluding shape info,
3086 // so that a slow path using relaxed shape can rely on a cache key that excludes
3087 // shapes.
3088 //
3089 // TODO(nareshmodi): Add support for sparse tensors.
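// For example, with include_tensor_ranks_only set, a float32 tensor of shape
// (2, 3) contributes "s2" (its rank) to the key instead of "s2-3-".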
3090 PyObject* TFE_Py_EncodeArg(PyObject* arg, bool include_tensor_ranks_only) {
3091 EncodeResult result;
3092 const auto status =
3093 TFE_Py_EncodeArgHelper(arg, include_tensor_ranks_only, &result);
3094 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
3095 return nullptr;
3096 }
3097
3098 return result.ToPyTuple();
3099 }
3100
// Prints incoming messages directly to Python's stdout using Python's C API.
// This is necessary in Jupyter notebooks and Colabs, where messages written to
// the C stdout don't reach the notebook cell outputs but calls to Python's
// stdout do.
3105 void PrintToPythonStdout(const char* msg) {
3106 if (Py_IsInitialized()) {
3107 PyGILState_STATE py_threadstate;
3108 py_threadstate = PyGILState_Ensure();
3109
3110 string string_msg = msg;
3111 // PySys_WriteStdout truncates strings over 1000 bytes, so
3112 // we write the message in chunks small enough to not be truncated.
3113 int CHUNK_SIZE = 900;
3114 auto len = string_msg.length();
3115 for (int i = 0; i < len; i += CHUNK_SIZE) {
3116 PySys_WriteStdout("%s", string_msg.substr(i, CHUNK_SIZE).c_str());
3117 }
3118 PySys_WriteStdout("\n");
3119
3120 PyGILState_Release(py_threadstate);
3121 }
3122 }
3123
3124 // Register PrintToPythonStdout as a log listener, to allow
3125 // printing in colabs and jupyter notebooks to work.
3126 void TFE_Py_EnableInteractivePythonLogging() {
3127 static bool enabled_interactive_logging = false;
3128 if (!enabled_interactive_logging) {
3129 enabled_interactive_logging = true;
3130 TF_RegisterLogListener(PrintToPythonStdout);
3131 }
3132 }
3133