• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
17 #define TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
18 
19 #include <Python.h>
20 
21 #include "pybind11/pybind11.h"
22 #include "tensorflow/c/tf_status.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/protobuf/error_codes.pb.h"
25 #include "tensorflow/python/lib/core/py_exception_registry.h"
26 
27 namespace tensorflow {
28 
29 namespace internal {
30 
CodeToPyExc(const int code)31 inline PyObject* CodeToPyExc(const int code) {
32   switch (code) {
33     case error::Code::INVALID_ARGUMENT:
34       return PyExc_ValueError;
35     case error::Code::OUT_OF_RANGE:
36       return PyExc_IndexError;
37     case error::Code::UNIMPLEMENTED:
38       return PyExc_NotImplementedError;
39     default:
40       return PyExc_RuntimeError;
41   }
42 }
43 
StatusToPyExc(const Status & status)44 inline PyObject* StatusToPyExc(const Status& status) {
45   return CodeToPyExc(status.code());
46 }
47 
TFStatusToPyExc(const TF_Status * status)48 inline PyObject* TFStatusToPyExc(const TF_Status* status) {
49   return CodeToPyExc(TF_GetCode(status));
50 }
51 
StatusPayloadToDict(const Status & status)52 inline pybind11::dict StatusPayloadToDict(const Status& status) {
53   pybind11::dict dict;
54   const auto& payloads = status.GetAllPayloads();
55   for (auto& pair : payloads) {
56     dict[pair.first.c_str()] = pair.second.c_str();
57   }
58   return dict;
59 }
60 
61 }  // namespace internal
62 
MaybeRaiseFromStatus(const Status & status)63 inline void MaybeRaiseFromStatus(const Status& status) {
64   if (!status.ok()) {
65     PyErr_SetString(internal::StatusToPyExc(status),
66                     status.error_message().c_str());
67     throw pybind11::error_already_set();
68   }
69 }
70 
SetRegisteredErrFromStatus(const tensorflow::Status & status)71 inline void SetRegisteredErrFromStatus(const tensorflow::Status& status) {
72   PyErr_SetObject(PyExceptionRegistry::Lookup(status.code()),
73                   pybind11::make_tuple(pybind11::none(), pybind11::none(),
74                                        status.error_message(),
75                                        internal::StatusPayloadToDict(status))
76                       .ptr());
77 }
78 
MaybeRaiseRegisteredFromStatus(const tensorflow::Status & status)79 inline void MaybeRaiseRegisteredFromStatus(const tensorflow::Status& status) {
80   if (!status.ok()) {
81     SetRegisteredErrFromStatus(status);
82     throw pybind11::error_already_set();
83   }
84 }
85 
MaybeRaiseRegisteredFromStatusWithGIL(const tensorflow::Status & status)86 inline void MaybeRaiseRegisteredFromStatusWithGIL(
87     const tensorflow::Status& status) {
88   if (!status.ok()) {
89     // Acquire GIL for throwing exception.
90     pybind11::gil_scoped_acquire acquire;
91     SetRegisteredErrFromStatus(status);
92     throw pybind11::error_already_set();
93   }
94 }
95 
MaybeRaiseFromTFStatus(TF_Status * status)96 inline void MaybeRaiseFromTFStatus(TF_Status* status) {
97   TF_Code code = TF_GetCode(status);
98   if (code != TF_OK) {
99     PyErr_SetString(internal::TFStatusToPyExc(status), TF_Message(status));
100     throw pybind11::error_already_set();
101   }
102 }
103 
MaybeRaiseRegisteredFromTFStatus(TF_Status * status)104 inline void MaybeRaiseRegisteredFromTFStatus(TF_Status* status) {
105   TF_Code code = TF_GetCode(status);
106   if (code != TF_OK) {
107     PyErr_SetObject(PyExceptionRegistry::Lookup(code),
108                     pybind11::make_tuple(pybind11::none(), pybind11::none(),
109                                          TF_Message(status))
110                         .ptr());
111     throw pybind11::error_already_set();
112   }
113 }
114 
MaybeRaiseRegisteredFromTFStatusWithGIL(TF_Status * status)115 inline void MaybeRaiseRegisteredFromTFStatusWithGIL(TF_Status* status) {
116   TF_Code code = TF_GetCode(status);
117   if (code != TF_OK) {
118     // Acquire GIL for throwing exception.
119     pybind11::gil_scoped_acquire acquire;
120 
121     PyErr_SetObject(PyExceptionRegistry::Lookup(code),
122                     pybind11::make_tuple(pybind11::none(), pybind11::none(),
123                                          TF_Message(status))
124                         .ptr());
125     throw pybind11::error_already_set();
126   }
127 }
128 
129 }  // namespace tensorflow
130 
131 namespace pybind11 {
132 namespace detail {
133 
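// Raise an exception if a given status is not OK, otherwise return None.
//
// The caster delegates to tensorflow::MaybeRaiseFromStatus, so the exception
// class is one of the built-in Python exceptions selected by
// tensorflow::internal::CodeToPyExc (ValueError, IndexError,
// NotImplementedError or RuntimeError). PyExceptionRegistry is not consulted
// here; call tensorflow::MaybeRaiseRegisteredFromStatus explicitly when the
// registry-based exception classes are wanted.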
139 template <>
140 struct type_caster<tensorflow::Status> {
141  public:
142   PYBIND11_TYPE_CASTER(tensorflow::Status, _("Status"));
143   static handle cast(tensorflow::Status status, return_value_policy, handle) {
144     tensorflow::MaybeRaiseFromStatus(status);
145     return none().inc_ref();
146   }
147 };
148 
149 }  // namespace detail
150 }  // namespace pybind11
151 
152 #endif  // TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
153