• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "pybind11/functional.h"
17 #include "pybind11/pybind11.h"
18 #include "pybind11/pytypes.h"
19 #include "pybind11/stl.h"
20 #include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"
21 #include "tensorflow/python/lib/core/pybind11_lib.h"
22 
23 namespace py = pybind11;
24 using tflite::interpreter_wrapper::InterpreterWrapper;
25 
PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper,m)26 PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
27   m.doc() = R"pbdoc(
28     _pywrap_tensorflow_interpreter_wrapper
29     -----
30   )pbdoc";
31 
32   // pybind11 suggests to convert factory functions into constructors, but
33   // when bytes are provided the wrapper will be confused which
34   // constructor to call.
35   m.def("CreateWrapperFromFile",
36         [](const std::string& model_path, int op_resolver_id,
37            const std::vector<std::string>& registerers,
38            bool preserve_all_tensors) {
39           std::string error;
40           auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromFile(
41               model_path.c_str(), op_resolver_id, registerers, &error,
42               preserve_all_tensors);
43           if (!wrapper) {
44             throw std::invalid_argument(error);
45           }
46           return wrapper;
47         });
48   m.def(
49       "CreateWrapperFromFile",
50       [](const std::string& model_path, int op_resolver_id,
51          const std::vector<std::string>& registerers_by_name,
52          const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
53          bool preserve_all_tensors) {
54         std::string error;
55         auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromFile(
56             model_path.c_str(), op_resolver_id, registerers_by_name,
57             registerers_by_func, &error, preserve_all_tensors);
58         if (!wrapper) {
59           throw std::invalid_argument(error);
60         }
61         return wrapper;
62       });
63   m.def("CreateWrapperFromBuffer",
64         [](const py::bytes& data, int op_resolver_id,
65            const std::vector<std::string>& registerers,
66            bool preserve_all_tensors) {
67           std::string error;
68           auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromBuffer(
69               data.ptr(), op_resolver_id, registerers, &error,
70               preserve_all_tensors);
71           if (!wrapper) {
72             throw std::invalid_argument(error);
73           }
74           return wrapper;
75         });
76   m.def(
77       "CreateWrapperFromBuffer",
78       [](const py::bytes& data, int op_resolver_id,
79          const std::vector<std::string>& registerers_by_name,
80          const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
81          bool preserve_all_tensors) {
82         std::string error;
83         auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromBuffer(
84             data.ptr(), op_resolver_id, registerers_by_name,
85             registerers_by_func, &error, preserve_all_tensors);
86         if (!wrapper) {
87           throw std::invalid_argument(error);
88         }
89         return wrapper;
90       });
91   py::class_<InterpreterWrapper>(m, "InterpreterWrapper")
92       .def(
93           "AllocateTensors",
94           [](InterpreterWrapper& self, int subgraph_index) {
95             return tensorflow::PyoOrThrow(self.AllocateTensors(subgraph_index));
96           },
97           py::arg("subgraph_index") = 0)
98       .def(
99           "Invoke",
100           [](InterpreterWrapper& self, int subgraph_index) {
101             return tensorflow::PyoOrThrow(self.Invoke(subgraph_index));
102           },
103           py::arg("subgraph_index") = 0)
104       .def("InputIndices",
105            [](const InterpreterWrapper& self) {
106              return tensorflow::PyoOrThrow(self.InputIndices());
107            })
108       .def("OutputIndices",
109            [](InterpreterWrapper& self) {
110              return tensorflow::PyoOrThrow(self.OutputIndices());
111            })
112       .def(
113           "ResizeInputTensor",
114           [](InterpreterWrapper& self, int i, py::handle& value, bool strict,
115              int subgraph_index) {
116             return tensorflow::PyoOrThrow(
117                 self.ResizeInputTensor(i, value.ptr(), strict, subgraph_index));
118           },
119           py::arg("i"), py::arg("value"), py::arg("strict"),
120           py::arg("subgraph_index") = 0)
121       .def("NumTensors", &InterpreterWrapper::NumTensors)
122       .def("TensorName", &InterpreterWrapper::TensorName)
123       .def("TensorType",
124            [](const InterpreterWrapper& self, int i) {
125              return tensorflow::PyoOrThrow(self.TensorType(i));
126            })
127       .def("TensorSize",
128            [](const InterpreterWrapper& self, int i) {
129              return tensorflow::PyoOrThrow(self.TensorSize(i));
130            })
131       .def("TensorSizeSignature",
132            [](const InterpreterWrapper& self, int i) {
133              return tensorflow::PyoOrThrow(self.TensorSizeSignature(i));
134            })
135       .def("TensorSparsityParameters",
136            [](const InterpreterWrapper& self, int i) {
137              return tensorflow::PyoOrThrow(self.TensorSparsityParameters(i));
138            })
139       .def(
140           "TensorQuantization",
141           [](const InterpreterWrapper& self, int i) {
142             return tensorflow::PyoOrThrow(self.TensorQuantization(i));
143           },
144           R"pbdoc(
145             Deprecated in favor of TensorQuantizationParameters.
146           )pbdoc")
147       .def(
148           "TensorQuantizationParameters",
149           [](InterpreterWrapper& self, int i) {
150             return tensorflow::PyoOrThrow(self.TensorQuantizationParameters(i));
151           })
152       .def(
153           "SetTensor",
154           [](InterpreterWrapper& self, int i, py::handle& value,
155              int subgraph_index) {
156             return tensorflow::PyoOrThrow(
157                 self.SetTensor(i, value.ptr(), subgraph_index));
158           },
159           py::arg("i"), py::arg("value"), py::arg("subgraph_index") = 0)
160       .def(
161           "GetTensor",
162           [](const InterpreterWrapper& self, int tensor_index,
163              int subgraph_index) {
164             return tensorflow::PyoOrThrow(
165                 self.GetTensor(tensor_index, subgraph_index));
166           },
167           py::arg("tensor_index"), py::arg("subgraph_index") = 0)
168       .def("GetSubgraphIndexFromSignature",
169            [](InterpreterWrapper& self, const char* signature_key) {
170              return tensorflow::PyoOrThrow(
171                  self.GetSubgraphIndexFromSignature(signature_key));
172            })
173       .def("GetSignatureDefs",
174            [](InterpreterWrapper& self) {
175              return tensorflow::PyoOrThrow(self.GetSignatureDefs());
176            })
177       .def("ResetVariableTensors",
178            [](InterpreterWrapper& self) {
179              return tensorflow::PyoOrThrow(self.ResetVariableTensors());
180            })
181       .def("NumNodes", &InterpreterWrapper::NumNodes)
182       .def("NodeName", &InterpreterWrapper::NodeName)
183       .def("NodeInputs",
184            [](const InterpreterWrapper& self, int i) {
185              return tensorflow::PyoOrThrow(self.NodeInputs(i));
186            })
187       .def("NodeOutputs",
188            [](const InterpreterWrapper& self, int i) {
189              return tensorflow::PyoOrThrow(self.NodeOutputs(i));
190            })
191       .def(
192           "tensor",
193           [](InterpreterWrapper& self, py::handle& base_object,
194              int tensor_index, int subgraph_index) {
195             return tensorflow::PyoOrThrow(
196                 self.tensor(base_object.ptr(), tensor_index, subgraph_index));
197           },
198           R"pbdoc(
199             Returns a reference to tensor index as a numpy array from subgraph.
200             The base_object should be the interpreter object providing the
201             memory.
202           )pbdoc",
203           py::arg("base_object"), py::arg("tensor_index"),
204           py::arg("subgraph_index") = 0)
205       .def(
206           "ModifyGraphWithDelegate",
207           // Address of the delegate is passed as an argument.
208           [](InterpreterWrapper& self, uintptr_t delegate_ptr) {
209             return tensorflow::PyoOrThrow(self.ModifyGraphWithDelegate(
210                 reinterpret_cast<TfLiteDelegate*>(delegate_ptr)));
211           },
212           R"pbdoc(
213             Adds a delegate to the interpreter.
214           )pbdoc")
215       .def(
216           "SetNumThreads",
217           [](InterpreterWrapper& self, int num_threads) {
218             return tensorflow::PyoOrThrow(self.SetNumThreads(num_threads));
219           },
220           R"pbdoc(
221              ask the interpreter to set the number of threads to use.
222           )pbdoc")
223       .def("interpreter", [](InterpreterWrapper& self) {
224         return reinterpret_cast<intptr_t>(self.interpreter());
225       });
226 }
227