1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Disallow Numpy 1.7 deprecated symbols.
17 #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
18
19 #include "numpy/arrayobject.h"
20 #include "numpy/ufuncobject.h"
21 #include "pybind11/chrono.h"
22 #include "pybind11/complex.h"
23 #include "pybind11/functional.h"
24 #include "pybind11/pybind11.h"
25 #include "pybind11/stl.h"
26 #include "tensorflow/c/checkpoint_reader.h"
27 #include "tensorflow/c/tf_status.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/python/lib/core/ndarray_tensor.h"
31 #include "tensorflow/python/lib/core/py_exception_registry.h"
32 #include "tensorflow/python/lib/core/pybind11_lib.h"
33 #include "tensorflow/python/lib/core/pybind11_status.h"
34 #include "tensorflow/python/lib/core/safe_ptr.h"
35
36 namespace py = pybind11;
37
38 // TODO(amitpatankar): Move the custom type casters to separate common header
39 // only libraries.
40
41 namespace pybind11 {
42 namespace detail {
43
/* These are custom type casters for the TensorShape and DataType objects. For
 * more documentation please refer to this link:
 * https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html#custom-type-casters
 * Some PyCheckpointReader methods return `TensorShape` and `DataType` objects
 * as outputs. These custom type casters handle their conversion from C++ to
 * Python. Since we do not accept these classes as arguments from Python, it is
 * not necessary to define the `load` function to cast the object from Python
 * to a C++ object.
 */
53
54 template <>
55 struct type_caster<tensorflow::TensorShape> {
56 public:
57 PYBIND11_TYPE_CASTER(tensorflow::TensorShape, _("tensorflow::TensorShape"));
58
castpybind11::detail::type_caster59 static handle cast(const tensorflow::TensorShape& src,
60 return_value_policy unused_policy, handle unused_handle) {
61 // TODO(amitpatankar): Simplify handling TensorShape as output later.
62 size_t dims = src.dims();
63 tensorflow::Safe_PyObjectPtr value(PyList_New(dims));
64 for (size_t i = 0; i < dims; ++i) {
65 #if PY_MAJOR_VERSION >= 3
66 tensorflow::Safe_PyObjectPtr dim_value(
67 tensorflow::make_safe(PyLong_FromLong(src.dim_size(i))));
68 #else
69 tensorflow::Safe_PyObjectPtr dim_value(
70 tensorflow::make_safe(PyInt_FromLong(src.dim_size(i))));
71 #endif
72 PyList_SET_ITEM(value.get(), i, dim_value.release());
73 }
74
75 return value.release();
76 }
77 };
78
79 template <>
80 struct type_caster<tensorflow::DataType> {
81 public:
82 PYBIND11_TYPE_CASTER(tensorflow::DataType, _("tensorflow::DataType"));
83
castpybind11::detail::type_caster84 static handle cast(const tensorflow::DataType& src,
85 return_value_policy unused_policy, handle unused_handle) {
86 #if PY_MAJOR_VERSION >= 3
87 tensorflow::Safe_PyObjectPtr value(
88 tensorflow::make_safe(PyLong_FromLong(src)));
89 #else
90 tensorflow::Safe_PyObjectPtr value(
91 tensorflow::make_safe(PyInt_FromLong(src)));
92 #endif
93 return value.release();
94 }
95 };
96
97 } // namespace detail
98 } // namespace pybind11
99
100 namespace tensorflow {
101
CheckpointReader_GetTensor(tensorflow::checkpoint::CheckpointReader * reader,const string & name)102 static py::object CheckpointReader_GetTensor(
103 tensorflow::checkpoint::CheckpointReader* reader, const string& name) {
104 Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
105 PyObject* py_obj = Py_None;
106 std::unique_ptr<tensorflow::Tensor> tensor;
107 reader->GetTensor(name, &tensor, status.get());
108
109 // Error handling if unable to get Tensor.
110 tensorflow::MaybeRaiseFromTFStatus(status.get());
111
112 tensorflow::MaybeRaiseFromStatus(
113 tensorflow::TensorToNdarray(*tensor, &py_obj));
114
115 return tensorflow::PyoOrThrow(
116 PyArray_Return(reinterpret_cast<PyArrayObject*>(py_obj)));
117 }
118
119 } // namespace tensorflow
120
PYBIND11_MODULE(_pywrap_checkpoint_reader, m) {
  // Initialization code to use numpy types in the type casters and in
  // CheckpointReader_GetTensor (which calls PyArray_Return). On failure,
  // import_array1 sets an ImportError and returns its argument. Returning
  // from module init with that error still pending would hand Python a module
  // object with an error set, so convert the failure into an exception that
  // pybind11 reports as the ImportError.
  const auto import_numpy = []() -> bool {
    import_array1(false);
    return true;
  };
  if (!import_numpy()) {
    throw py::error_already_set();
  }

  py::class_<tensorflow::checkpoint::CheckpointReader> checkpoint_reader_class(
      m, "CheckpointReader");
  checkpoint_reader_class
      .def(py::init([](const std::string& filename) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        // pybind11 supports smart pointers and takes ownership of the
        // returned std::unique_ptr, freeing the reader when it is no longer
        // referenced.
        // https://pybind11.readthedocs.io/en/master/advanced/smart_ptrs.html#std-unique-ptr
        auto checkpoint =
            std::make_unique<tensorflow::checkpoint::CheckpointReader>(
                filename, status.get());
        // Raises if the checkpoint could not be opened; `checkpoint` is
        // destroyed during unwinding.
        tensorflow::MaybeRaiseFromTFStatus(status.get());
        return checkpoint;
      }))
      .def("debug_string",
           [](tensorflow::checkpoint::CheckpointReader& self) {
             return py::bytes(self.DebugString());
           })
      .def("get_variable_to_shape_map",
           &tensorflow::checkpoint::CheckpointReader::GetVariableToShapeMap)
      .def("_GetVariableToDataTypeMap",
           &tensorflow::checkpoint::CheckpointReader::GetVariableToDataTypeMap)
      .def("_HasTensor", &tensorflow::checkpoint::CheckpointReader::HasTensor)
      .def_static("CheckpointReader_GetTensor",
                  &tensorflow::CheckpointReader_GetTensor);
}
151