1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "pybind11/functional.h"
17 #include "pybind11/pybind11.h"
18 #include "pybind11/pytypes.h"
19 #include "pybind11/stl.h"
20 #include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"
21 #include "tensorflow/python/lib/core/pybind11_lib.h"
22
23 namespace py = pybind11;
24 using tflite::interpreter_wrapper::InterpreterWrapper;
25
PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper,m)26 PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
27 m.doc() = R"pbdoc(
28 _pywrap_tensorflow_interpreter_wrapper
29 -----
30 )pbdoc";
31
32 // pybind11 suggests to convert factory functions into constructors, but
33 // when bytes are provided the wrapper will be confused which
34 // constructor to call.
35 m.def("CreateWrapperFromFile",
36 [](const std::string& model_path, int op_resolver_id,
37 const std::vector<std::string>& registerers,
38 bool preserve_all_tensors) {
39 std::string error;
40 auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromFile(
41 model_path.c_str(), op_resolver_id, registerers, &error,
42 preserve_all_tensors);
43 if (!wrapper) {
44 throw std::invalid_argument(error);
45 }
46 return wrapper;
47 });
48 m.def(
49 "CreateWrapperFromFile",
50 [](const std::string& model_path, int op_resolver_id,
51 const std::vector<std::string>& registerers_by_name,
52 const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
53 bool preserve_all_tensors) {
54 std::string error;
55 auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromFile(
56 model_path.c_str(), op_resolver_id, registerers_by_name,
57 registerers_by_func, &error, preserve_all_tensors);
58 if (!wrapper) {
59 throw std::invalid_argument(error);
60 }
61 return wrapper;
62 });
63 m.def("CreateWrapperFromBuffer",
64 [](const py::bytes& data, int op_resolver_id,
65 const std::vector<std::string>& registerers,
66 bool preserve_all_tensors) {
67 std::string error;
68 auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromBuffer(
69 data.ptr(), op_resolver_id, registerers, &error,
70 preserve_all_tensors);
71 if (!wrapper) {
72 throw std::invalid_argument(error);
73 }
74 return wrapper;
75 });
76 m.def(
77 "CreateWrapperFromBuffer",
78 [](const py::bytes& data, int op_resolver_id,
79 const std::vector<std::string>& registerers_by_name,
80 const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
81 bool preserve_all_tensors) {
82 std::string error;
83 auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromBuffer(
84 data.ptr(), op_resolver_id, registerers_by_name,
85 registerers_by_func, &error, preserve_all_tensors);
86 if (!wrapper) {
87 throw std::invalid_argument(error);
88 }
89 return wrapper;
90 });
91 py::class_<InterpreterWrapper>(m, "InterpreterWrapper")
92 .def(
93 "AllocateTensors",
94 [](InterpreterWrapper& self, int subgraph_index) {
95 return tensorflow::PyoOrThrow(self.AllocateTensors(subgraph_index));
96 },
97 py::arg("subgraph_index") = 0)
98 .def(
99 "Invoke",
100 [](InterpreterWrapper& self, int subgraph_index) {
101 return tensorflow::PyoOrThrow(self.Invoke(subgraph_index));
102 },
103 py::arg("subgraph_index") = 0)
104 .def("InputIndices",
105 [](const InterpreterWrapper& self) {
106 return tensorflow::PyoOrThrow(self.InputIndices());
107 })
108 .def("OutputIndices",
109 [](InterpreterWrapper& self) {
110 return tensorflow::PyoOrThrow(self.OutputIndices());
111 })
112 .def(
113 "ResizeInputTensor",
114 [](InterpreterWrapper& self, int i, py::handle& value, bool strict,
115 int subgraph_index) {
116 return tensorflow::PyoOrThrow(
117 self.ResizeInputTensor(i, value.ptr(), strict, subgraph_index));
118 },
119 py::arg("i"), py::arg("value"), py::arg("strict"),
120 py::arg("subgraph_index") = 0)
121 .def("NumTensors", &InterpreterWrapper::NumTensors)
122 .def("TensorName", &InterpreterWrapper::TensorName)
123 .def("TensorType",
124 [](const InterpreterWrapper& self, int i) {
125 return tensorflow::PyoOrThrow(self.TensorType(i));
126 })
127 .def("TensorSize",
128 [](const InterpreterWrapper& self, int i) {
129 return tensorflow::PyoOrThrow(self.TensorSize(i));
130 })
131 .def("TensorSizeSignature",
132 [](const InterpreterWrapper& self, int i) {
133 return tensorflow::PyoOrThrow(self.TensorSizeSignature(i));
134 })
135 .def("TensorSparsityParameters",
136 [](const InterpreterWrapper& self, int i) {
137 return tensorflow::PyoOrThrow(self.TensorSparsityParameters(i));
138 })
139 .def(
140 "TensorQuantization",
141 [](const InterpreterWrapper& self, int i) {
142 return tensorflow::PyoOrThrow(self.TensorQuantization(i));
143 },
144 R"pbdoc(
145 Deprecated in favor of TensorQuantizationParameters.
146 )pbdoc")
147 .def(
148 "TensorQuantizationParameters",
149 [](InterpreterWrapper& self, int i) {
150 return tensorflow::PyoOrThrow(self.TensorQuantizationParameters(i));
151 })
152 .def(
153 "SetTensor",
154 [](InterpreterWrapper& self, int i, py::handle& value,
155 int subgraph_index) {
156 return tensorflow::PyoOrThrow(
157 self.SetTensor(i, value.ptr(), subgraph_index));
158 },
159 py::arg("i"), py::arg("value"), py::arg("subgraph_index") = 0)
160 .def(
161 "GetTensor",
162 [](const InterpreterWrapper& self, int tensor_index,
163 int subgraph_index) {
164 return tensorflow::PyoOrThrow(
165 self.GetTensor(tensor_index, subgraph_index));
166 },
167 py::arg("tensor_index"), py::arg("subgraph_index") = 0)
168 .def("GetSubgraphIndexFromSignature",
169 [](InterpreterWrapper& self, const char* signature_key) {
170 return tensorflow::PyoOrThrow(
171 self.GetSubgraphIndexFromSignature(signature_key));
172 })
173 .def("GetSignatureDefs",
174 [](InterpreterWrapper& self) {
175 return tensorflow::PyoOrThrow(self.GetSignatureDefs());
176 })
177 .def("ResetVariableTensors",
178 [](InterpreterWrapper& self) {
179 return tensorflow::PyoOrThrow(self.ResetVariableTensors());
180 })
181 .def("NumNodes", &InterpreterWrapper::NumNodes)
182 .def("NodeName", &InterpreterWrapper::NodeName)
183 .def("NodeInputs",
184 [](const InterpreterWrapper& self, int i) {
185 return tensorflow::PyoOrThrow(self.NodeInputs(i));
186 })
187 .def("NodeOutputs",
188 [](const InterpreterWrapper& self, int i) {
189 return tensorflow::PyoOrThrow(self.NodeOutputs(i));
190 })
191 .def(
192 "tensor",
193 [](InterpreterWrapper& self, py::handle& base_object,
194 int tensor_index, int subgraph_index) {
195 return tensorflow::PyoOrThrow(
196 self.tensor(base_object.ptr(), tensor_index, subgraph_index));
197 },
198 R"pbdoc(
199 Returns a reference to tensor index as a numpy array from subgraph.
200 The base_object should be the interpreter object providing the
201 memory.
202 )pbdoc",
203 py::arg("base_object"), py::arg("tensor_index"),
204 py::arg("subgraph_index") = 0)
205 .def(
206 "ModifyGraphWithDelegate",
207 // Address of the delegate is passed as an argument.
208 [](InterpreterWrapper& self, uintptr_t delegate_ptr) {
209 return tensorflow::PyoOrThrow(self.ModifyGraphWithDelegate(
210 reinterpret_cast<TfLiteDelegate*>(delegate_ptr)));
211 },
212 R"pbdoc(
213 Adds a delegate to the interpreter.
214 )pbdoc")
215 .def(
216 "SetNumThreads",
217 [](InterpreterWrapper& self, int num_threads) {
218 return tensorflow::PyoOrThrow(self.SetNumThreads(num_threads));
219 },
220 R"pbdoc(
221 ask the interpreter to set the number of threads to use.
222 )pbdoc")
223 .def("interpreter", [](InterpreterWrapper& self) {
224 return reinterpret_cast<intptr_t>(self.interpreter());
225 });
226 }
227