/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 // Disallow Numpy 1.7 deprecated symbols.
17 #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
18 
19 #include "numpy/arrayobject.h"
20 #include "numpy/ufuncobject.h"
21 #include "pybind11/chrono.h"
22 #include "pybind11/complex.h"
23 #include "pybind11/functional.h"
24 #include "pybind11/pybind11.h"
25 #include "pybind11/stl.h"
26 #include "tensorflow/c/checkpoint_reader.h"
27 #include "tensorflow/c/tf_status.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/python/lib/core/ndarray_tensor.h"
31 #include "tensorflow/python/lib/core/py_exception_registry.h"
32 #include "tensorflow/python/lib/core/pybind11_lib.h"
33 #include "tensorflow/python/lib/core/pybind11_status.h"
34 #include "tensorflow/python/lib/core/safe_ptr.h"
35 
36 namespace py = pybind11;
37 
// TODO(amitpatankar): Move the custom type casters to a separate, common
// header-only library.
40 
41 namespace pybind11 {
42 namespace detail {
43 
44 /* This is a custom type caster for the TensorShape object. For more
45  * documentation please refer to this link:
46  * https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html#custom-type-casters
47  * The PyCheckpointReader methods sometimes return the `TensorShape` object
48  * and the `DataType` object as outputs. This custom type caster helps Python
49  * handle it's conversion from C++ to Python. Since we do not accept these
50  * classes as arguments from Python, it is not necessary to define the `load`
51  * function to cast the object from Python to a C++ object.
52  */
53 
54 template <>
55 struct type_caster<tensorflow::TensorShape> {
56  public:
57   PYBIND11_TYPE_CASTER(tensorflow::TensorShape, _("tensorflow::TensorShape"));
58 
castpybind11::detail::type_caster59   static handle cast(const tensorflow::TensorShape& src,
60                      return_value_policy unused_policy, handle unused_handle) {
61     // TODO(amitpatankar): Simplify handling TensorShape as output later.
62     size_t dims = src.dims();
63     tensorflow::Safe_PyObjectPtr value(PyList_New(dims));
64     for (size_t i = 0; i < dims; ++i) {
65 #if PY_MAJOR_VERSION >= 3
66       tensorflow::Safe_PyObjectPtr dim_value(
67           tensorflow::make_safe(PyLong_FromLong(src.dim_size(i))));
68 #else
69       tensorflow::Safe_PyObjectPtr dim_value(
70           tensorflow::make_safe(PyInt_FromLong(src.dim_size(i))));
71 #endif
72       PyList_SET_ITEM(value.get(), i, dim_value.release());
73     }
74 
75     return value.release();
76   }
77 };
78 
79 template <>
80 struct type_caster<tensorflow::DataType> {
81  public:
82   PYBIND11_TYPE_CASTER(tensorflow::DataType, _("tensorflow::DataType"));
83 
castpybind11::detail::type_caster84   static handle cast(const tensorflow::DataType& src,
85                      return_value_policy unused_policy, handle unused_handle) {
86 #if PY_MAJOR_VERSION >= 3
87     tensorflow::Safe_PyObjectPtr value(
88         tensorflow::make_safe(PyLong_FromLong(src)));
89 #else
90     tensorflow::Safe_PyObjectPtr value(
91         tensorflow::make_safe(PyInt_FromLong(src)));
92 #endif
93     return value.release();
94   }
95 };
96 
97 }  // namespace detail
98 }  // namespace pybind11
99 
100 namespace tensorflow {
101 
CheckpointReader_GetTensor(tensorflow::checkpoint::CheckpointReader * reader,const string & name)102 static py::object CheckpointReader_GetTensor(
103     tensorflow::checkpoint::CheckpointReader* reader, const string& name) {
104   Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
105   PyObject* py_obj = Py_None;
106   std::unique_ptr<tensorflow::Tensor> tensor;
107   reader->GetTensor(name, &tensor, status.get());
108 
109   // Error handling if unable to get Tensor.
110   tensorflow::MaybeRaiseFromTFStatus(status.get());
111 
112   tensorflow::MaybeRaiseFromStatus(
113       tensorflow::TensorToNdarray(*tensor, &py_obj));
114 
115   return tensorflow::PyoOrThrow(
116       PyArray_Return(reinterpret_cast<PyArrayObject*>(py_obj)));
117 }
118 
119 }  // namespace tensorflow
120 
// Defines the `_pywrap_checkpoint_reader` extension module, which exposes
// tensorflow::checkpoint::CheckpointReader to Python as `CheckpointReader`.
PYBIND11_MODULE(_pywrap_checkpoint_reader, m) {
  // Load the NumPy C API; CheckpointReader_GetTensor calls PyArray_Return.
  // The `import_array1()` macro would merely `return` from this init function
  // with an ImportError pending, which CPython then reports as a SystemError.
  // Throwing instead lets pybind11 turn NumPy's own import failure into the
  // module's ImportError.
  if (_import_array() < 0) {
    throw py::error_already_set();
  }

  py::class_<tensorflow::checkpoint::CheckpointReader> checkpoint_reader_class(
      m, "CheckpointReader");
  checkpoint_reader_class
      // CheckpointReader(filename): opens the checkpoint at `filename`. An
      // error reported through the TF_Status is passed to
      // MaybeRaiseFromTFStatus before the reader is handed to Python.
      .def(py::init([](const std::string& filename) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        // pybind11 supports smart pointers and takes ownership of the
        // returned std::unique_ptr, freeing the reader when it is done.
        // https://pybind11.readthedocs.io/en/master/advanced/smart_ptrs.html#std-unique-ptr
        auto checkpoint =
            std::make_unique<tensorflow::checkpoint::CheckpointReader>(
                filename, status.get());
        tensorflow::MaybeRaiseFromTFStatus(status.get());
        return checkpoint;
      }))
      // Returned as `bytes`, so the text is not UTF-8 decoded by pybind11.
      .def("debug_string",
           [](tensorflow::checkpoint::CheckpointReader& self) {
             return py::bytes(self.DebugString());
           })
      // These results go through the TensorShape / DataType type casters
      // defined in pybind11::detail above.
      .def("get_variable_to_shape_map",
           &tensorflow::checkpoint::CheckpointReader::GetVariableToShapeMap)
      .def("_GetVariableToDataTypeMap",
           &tensorflow::checkpoint::CheckpointReader::GetVariableToDataTypeMap)
      .def("_HasTensor", &tensorflow::checkpoint::CheckpointReader::HasTensor)
      // Static: the Python caller passes the reader instance explicitly.
      .def_static("CheckpointReader_GetTensor",
                  &tensorflow::CheckpointReader_GetTensor);
}
151