1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
17 #define TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
18
19 #include <Python.h>
20
21 #include "pybind11/pybind11.h"
22 #include "tensorflow/c/tf_status.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/protobuf/error_codes.pb.h"
25 #include "tensorflow/python/lib/core/py_exception_registry.h"
26
27 namespace tensorflow {
28
29 namespace internal {
30
CodeToPyExc(const int code)31 inline PyObject* CodeToPyExc(const int code) {
32 switch (code) {
33 case error::Code::INVALID_ARGUMENT:
34 return PyExc_ValueError;
35 case error::Code::OUT_OF_RANGE:
36 return PyExc_IndexError;
37 case error::Code::UNIMPLEMENTED:
38 return PyExc_NotImplementedError;
39 default:
40 return PyExc_RuntimeError;
41 }
42 }
43
StatusToPyExc(const Status & status)44 inline PyObject* StatusToPyExc(const Status& status) {
45 return CodeToPyExc(status.code());
46 }
47
TFStatusToPyExc(const TF_Status * status)48 inline PyObject* TFStatusToPyExc(const TF_Status* status) {
49 return CodeToPyExc(TF_GetCode(status));
50 }
51
52 } // namespace internal
53
MaybeRaiseFromStatus(const Status & status)54 inline void MaybeRaiseFromStatus(const Status& status) {
55 if (!status.ok()) {
56 PyErr_SetString(internal::StatusToPyExc(status),
57 status.error_message().c_str());
58 throw pybind11::error_already_set();
59 }
60 }
61
MaybeRaiseRegisteredFromStatus(const tensorflow::Status & status)62 inline void MaybeRaiseRegisteredFromStatus(const tensorflow::Status& status) {
63 if (!status.ok()) {
64 PyErr_SetObject(PyExceptionRegistry::Lookup(status.code()),
65 pybind11::make_tuple(pybind11::none(), pybind11::none(),
66 status.error_message())
67 .ptr());
68 throw pybind11::error_already_set();
69 }
70 }
71
MaybeRaiseRegisteredFromStatusWithGIL(const tensorflow::Status & status)72 inline void MaybeRaiseRegisteredFromStatusWithGIL(
73 const tensorflow::Status& status) {
74 if (!status.ok()) {
75 // Acquire GIL for throwing exception.
76 pybind11::gil_scoped_acquire acquire;
77
78 PyErr_SetObject(PyExceptionRegistry::Lookup(status.code()),
79 pybind11::make_tuple(pybind11::none(), pybind11::none(),
80 status.error_message())
81 .ptr());
82 throw pybind11::error_already_set();
83 }
84 }
85
MaybeRaiseFromTFStatus(TF_Status * status)86 inline void MaybeRaiseFromTFStatus(TF_Status* status) {
87 TF_Code code = TF_GetCode(status);
88 if (code != TF_OK) {
89 PyErr_SetString(internal::TFStatusToPyExc(status), TF_Message(status));
90 throw pybind11::error_already_set();
91 }
92 }
93
MaybeRaiseRegisteredFromTFStatus(TF_Status * status)94 inline void MaybeRaiseRegisteredFromTFStatus(TF_Status* status) {
95 TF_Code code = TF_GetCode(status);
96 if (code != TF_OK) {
97 PyErr_SetObject(PyExceptionRegistry::Lookup(code),
98 pybind11::make_tuple(pybind11::none(), pybind11::none(),
99 TF_Message(status))
100 .ptr());
101 throw pybind11::error_already_set();
102 }
103 }
104
MaybeRaiseRegisteredFromTFStatusWithGIL(TF_Status * status)105 inline void MaybeRaiseRegisteredFromTFStatusWithGIL(TF_Status* status) {
106 TF_Code code = TF_GetCode(status);
107 if (code != TF_OK) {
108 // Acquire GIL for throwing exception.
109 pybind11::gil_scoped_acquire acquire;
110
111 PyErr_SetObject(PyExceptionRegistry::Lookup(code),
112 pybind11::make_tuple(pybind11::none(), pybind11::none(),
113 TF_Message(status))
114 .ptr());
115 throw pybind11::error_already_set();
116 }
117 }
118
119 } // namespace tensorflow
120
121 namespace pybind11 {
122 namespace detail {
123
124 // Raise an exception if a given status is not OK, otherwise return None.
125 //
126 // The correspondence between status codes and exception classes is given
127 // by PyExceptionRegistry. Note that the registry should be initialized
128 // in order to be used, see PyExceptionRegistry::Init.
129 template <>
130 struct type_caster<tensorflow::Status> {
131 public:
132 PYBIND11_TYPE_CASTER(tensorflow::Status, _("Status"));
133 static handle cast(tensorflow::Status status, return_value_policy, handle) {
134 tensorflow::MaybeRaiseFromStatus(status);
135 return none().inc_ref();
136 }
137 };
138
139 } // namespace detail
140 } // namespace pybind11
141
142 #endif // TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
143