1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
17 #define TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
18
19 #include <Python.h>
20
21 #include "pybind11/pybind11.h"
22 #include "tensorflow/c/tf_status.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/protobuf/error_codes.pb.h"
25 #include "tensorflow/python/lib/core/py_exception_registry.h"
26
27 namespace tensorflow {
28
29 namespace internal {
30
CodeToPyExc(const int code)31 inline PyObject* CodeToPyExc(const int code) {
32 switch (code) {
33 case error::Code::INVALID_ARGUMENT:
34 return PyExc_ValueError;
35 case error::Code::OUT_OF_RANGE:
36 return PyExc_IndexError;
37 case error::Code::UNIMPLEMENTED:
38 return PyExc_NotImplementedError;
39 default:
40 return PyExc_RuntimeError;
41 }
42 }
43
StatusToPyExc(const Status & status)44 inline PyObject* StatusToPyExc(const Status& status) {
45 return CodeToPyExc(status.code());
46 }
47
TFStatusToPyExc(const TF_Status * status)48 inline PyObject* TFStatusToPyExc(const TF_Status* status) {
49 return CodeToPyExc(TF_GetCode(status));
50 }
51
StatusPayloadToDict(const Status & status)52 inline pybind11::dict StatusPayloadToDict(const Status& status) {
53 pybind11::dict dict;
54 const auto& payloads = status.GetAllPayloads();
55 for (auto& pair : payloads) {
56 dict[pair.first.c_str()] = pair.second.c_str();
57 }
58 return dict;
59 }
60
61 } // namespace internal
62
MaybeRaiseFromStatus(const Status & status)63 inline void MaybeRaiseFromStatus(const Status& status) {
64 if (!status.ok()) {
65 PyErr_SetString(internal::StatusToPyExc(status),
66 status.error_message().c_str());
67 throw pybind11::error_already_set();
68 }
69 }
70
SetRegisteredErrFromStatus(const tensorflow::Status & status)71 inline void SetRegisteredErrFromStatus(const tensorflow::Status& status) {
72 PyErr_SetObject(PyExceptionRegistry::Lookup(status.code()),
73 pybind11::make_tuple(pybind11::none(), pybind11::none(),
74 status.error_message(),
75 internal::StatusPayloadToDict(status))
76 .ptr());
77 }
78
MaybeRaiseRegisteredFromStatus(const tensorflow::Status & status)79 inline void MaybeRaiseRegisteredFromStatus(const tensorflow::Status& status) {
80 if (!status.ok()) {
81 SetRegisteredErrFromStatus(status);
82 throw pybind11::error_already_set();
83 }
84 }
85
MaybeRaiseRegisteredFromStatusWithGIL(const tensorflow::Status & status)86 inline void MaybeRaiseRegisteredFromStatusWithGIL(
87 const tensorflow::Status& status) {
88 if (!status.ok()) {
89 // Acquire GIL for throwing exception.
90 pybind11::gil_scoped_acquire acquire;
91 SetRegisteredErrFromStatus(status);
92 throw pybind11::error_already_set();
93 }
94 }
95
MaybeRaiseFromTFStatus(TF_Status * status)96 inline void MaybeRaiseFromTFStatus(TF_Status* status) {
97 TF_Code code = TF_GetCode(status);
98 if (code != TF_OK) {
99 PyErr_SetString(internal::TFStatusToPyExc(status), TF_Message(status));
100 throw pybind11::error_already_set();
101 }
102 }
103
MaybeRaiseRegisteredFromTFStatus(TF_Status * status)104 inline void MaybeRaiseRegisteredFromTFStatus(TF_Status* status) {
105 TF_Code code = TF_GetCode(status);
106 if (code != TF_OK) {
107 PyErr_SetObject(PyExceptionRegistry::Lookup(code),
108 pybind11::make_tuple(pybind11::none(), pybind11::none(),
109 TF_Message(status))
110 .ptr());
111 throw pybind11::error_already_set();
112 }
113 }
114
MaybeRaiseRegisteredFromTFStatusWithGIL(TF_Status * status)115 inline void MaybeRaiseRegisteredFromTFStatusWithGIL(TF_Status* status) {
116 TF_Code code = TF_GetCode(status);
117 if (code != TF_OK) {
118 // Acquire GIL for throwing exception.
119 pybind11::gil_scoped_acquire acquire;
120
121 PyErr_SetObject(PyExceptionRegistry::Lookup(code),
122 pybind11::make_tuple(pybind11::none(), pybind11::none(),
123 TF_Message(status))
124 .ptr());
125 throw pybind11::error_already_set();
126 }
127 }
128
129 } // namespace tensorflow
130
131 namespace pybind11 {
132 namespace detail {
133
134 // Raise an exception if a given status is not OK, otherwise return None.
135 //
136 // The correspondence between status codes and exception classes is given
137 // by PyExceptionRegistry. Note that the registry should be initialized
138 // in order to be used, see PyExceptionRegistry::Init.
139 template <>
140 struct type_caster<tensorflow::Status> {
141 public:
142 PYBIND11_TYPE_CASTER(tensorflow::Status, _("Status"));
143 static handle cast(tensorflow::Status status, return_value_policy, handle) {
144 tensorflow::MaybeRaiseFromStatus(status);
145 return none().inc_ref();
146 }
147 };
148
149 } // namespace detail
150 } // namespace pybind11
151
152 #endif // TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
153