• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7   http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 using System;
16 using System.Runtime.InteropServices;
17 
18 using TFL_Interpreter = System.IntPtr;
19 using TFL_InterpreterOptions = System.IntPtr;
20 using TFL_Model = System.IntPtr;
21 using TFL_Tensor = System.IntPtr;
22 
namespace TensorFlowLite
{
  /// <summary>
  /// Simple C# bindings for the experimental TensorFlowLite C API.
  /// </summary>
  /// <remarks>
  /// An instance owns a native model, a native interpreter built from that model,
  /// and a pinned handle on the model buffer passed to the constructor. Dispose it
  /// (for example with a <c>using</c> statement) to release them deterministically;
  /// the finalizer is only a safety net. This class performs no internal
  /// synchronization.
  /// </remarks>
  public class Interpreter : IDisposable
  {
    private const string TensorFlowLibrary = "tensorflowlite_c";

    private TFL_Model model;
    private TFL_Interpreter interpreter;

    // Pins the caller's model buffer. The TFLite C API does not copy the model
    // buffer (the caller must keep it valid for the model's lifetime), so the
    // handle is only freed after the native model has been deleted.
    private GCHandle modelDataHandle;

    /// <summary>
    /// Creates a native model from a serialized TensorFlowLite model and a native
    /// interpreter for it, using default interpreter options.
    /// </summary>
    /// <param name="modelData">
    /// The serialized model bytes. The array stays pinned until this instance is
    /// disposed or finalized.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="modelData"/> is null.</exception>
    /// <exception cref="Exception">The native model or interpreter could not be created.</exception>
    public Interpreter(byte[] modelData) {
      if (modelData == null) {
        throw new ArgumentNullException("modelData");
      }

      modelDataHandle = GCHandle.Alloc(modelData, GCHandleType.Pinned);
      model = TFL_NewModel(
          modelDataHandle.AddrOfPinnedObject(), new UIntPtr((uint)modelData.Length));
      if (model == IntPtr.Zero) {
        // Unpin the buffer now rather than waiting for the finalizer.
        ReleaseNativeResources();
        throw new Exception("Failed to create TensorFlowLite Model");
      }

      interpreter = TFL_NewInterpreter(model, /*options=*/IntPtr.Zero);
      if (interpreter == IntPtr.Zero) {
        // Do not leak the model (and its pinned buffer) when only the interpreter failed.
        ReleaseNativeResources();
        throw new Exception("Failed to create TensorFlowLite Interpreter");
      }
    }

    /// <summary>
    /// Safety net for instances that were never disposed: releases the native
    /// interpreter, the native model and the pinned model buffer.
    /// </summary>
    ~Interpreter() {
      ReleaseNativeResources();
    }

    /// <summary>
    /// Releases the native interpreter, the native model and the pinned model
    /// buffer. Calling it more than once is harmless.
    /// </summary>
    public void Dispose() {
      ReleaseNativeResources();
      // Everything the finalizer would release is already gone.
      GC.SuppressFinalize(this);
    }

    /// <summary>Invokes the native interpreter on its current input tensors.</summary>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    /// <exception cref="Exception">The native call reported a non-zero status.</exception>
    public void Invoke() {
      ThrowIfDisposed();
      ThrowIfError(TFL_InterpreterInvoke(interpreter));
    }

    /// <summary>Returns the number of input tensors reported by the native interpreter.</summary>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    public int GetInputTensorCount() {
      ThrowIfDisposed();
      return TFL_InterpreterGetInputTensorCount(interpreter);
    }

    /// <summary>
    /// Copies the contents of <paramref name="inputTensorData"/> into the input
    /// tensor at <paramref name="inputTensorIndex"/>.
    /// </summary>
    /// <param name="inputTensorIndex">Zero-based index, less than <see cref="GetInputTensorCount"/>.</param>
    /// <param name="inputTensorData">
    /// Array of a primitive element type; its <see cref="Buffer.ByteLength"/> is
    /// passed to the native copy as the data size.
    /// </param>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="inputTensorData"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="inputTensorIndex"/> is out of range.</exception>
    /// <exception cref="Exception">The native copy reported a non-zero status.</exception>
    public void SetInputTensorData(int inputTensorIndex, Array inputTensorData) {
      ThrowIfDisposed();
      if (inputTensorData == null) {
        throw new ArgumentNullException("inputTensorData");
      }
      CheckTensorIndex(inputTensorIndex, GetInputTensorCount(), "inputTensorIndex");

      int byteLength = Buffer.ByteLength(inputTensorData);
      TFL_Tensor tensor = TFL_InterpreterGetInputTensor(interpreter, inputTensorIndex);
      // The native side copies out of the buffer, so the array only has to stay
      // pinned for the duration of this call.
      GCHandle tensorDataHandle = GCHandle.Alloc(inputTensorData, GCHandleType.Pinned);
      try {
        ThrowIfError(TFL_TensorCopyFromBuffer(
            tensor, tensorDataHandle.AddrOfPinnedObject(), new UIntPtr((uint)byteLength)));
      } finally {
        tensorDataHandle.Free();
      }
    }

    /// <summary>
    /// Resizes the input tensor at <paramref name="inputTensorIndex"/> to
    /// <paramref name="inputTensorShape"/>.
    /// </summary>
    /// <param name="inputTensorIndex">Zero-based index, less than <see cref="GetInputTensorCount"/>.</param>
    /// <param name="inputTensorShape">The new dimensions; its length is passed as the rank.</param>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="inputTensorShape"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="inputTensorIndex"/> is out of range.</exception>
    /// <exception cref="Exception">The native call reported a non-zero status.</exception>
    public void ResizeInputTensor(int inputTensorIndex, int[] inputTensorShape) {
      ThrowIfDisposed();
      if (inputTensorShape == null) {
        throw new ArgumentNullException("inputTensorShape");
      }
      CheckTensorIndex(inputTensorIndex, GetInputTensorCount(), "inputTensorIndex");
      ThrowIfError(TFL_InterpreterResizeInputTensor(
          interpreter, inputTensorIndex, inputTensorShape, inputTensorShape.Length));
    }

    /// <summary>Asks the native interpreter to allocate its tensors.</summary>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    /// <exception cref="Exception">The native call reported a non-zero status.</exception>
    public void AllocateTensors() {
      ThrowIfDisposed();
      ThrowIfError(TFL_InterpreterAllocateTensors(interpreter));
    }

    /// <summary>Returns the number of output tensors reported by the native interpreter.</summary>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    public int GetOutputTensorCount() {
      ThrowIfDisposed();
      return TFL_InterpreterGetOutputTensorCount(interpreter);
    }

    /// <summary>
    /// Copies the output tensor at <paramref name="outputTensorIndex"/> into
    /// <paramref name="outputTensorData"/>.
    /// </summary>
    /// <param name="outputTensorIndex">Zero-based index, less than <see cref="GetOutputTensorCount"/>.</param>
    /// <param name="outputTensorData">
    /// Destination array of a primitive element type; its
    /// <see cref="Buffer.ByteLength"/> is passed to the native copy as the buffer size.
    /// </param>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="outputTensorData"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="outputTensorIndex"/> is out of range.</exception>
    /// <exception cref="Exception">The native copy reported a non-zero status.</exception>
    public void GetOutputTensorData(int outputTensorIndex, Array outputTensorData) {
      ThrowIfDisposed();
      if (outputTensorData == null) {
        throw new ArgumentNullException("outputTensorData");
      }
      CheckTensorIndex(outputTensorIndex, GetOutputTensorCount(), "outputTensorIndex");

      int byteLength = Buffer.ByteLength(outputTensorData);
      TFL_Tensor tensor = TFL_InterpreterGetOutputTensor(interpreter, outputTensorIndex);
      // Pinned only while the native side writes into the array.
      GCHandle tensorDataHandle = GCHandle.Alloc(outputTensorData, GCHandleType.Pinned);
      try {
        ThrowIfError(TFL_TensorCopyToBuffer(
            tensor, tensorDataHandle.AddrOfPinnedObject(), new UIntPtr((uint)byteLength)));
      } finally {
        tensorDataHandle.Free();
      }
    }

    // Deletes the native objects in dependency order (the interpreter, then the
    // model it was created from, then the pinned buffer the model was created
    // from) and clears each field, so repeated calls are no-ops.
    private void ReleaseNativeResources() {
      if (interpreter != IntPtr.Zero) {
        TFL_DeleteInterpreter(interpreter);
        interpreter = IntPtr.Zero;
      }
      if (model != IntPtr.Zero) {
        TFL_DeleteModel(model);
        model = IntPtr.Zero;
      }
      if (modelDataHandle.IsAllocated) {
        modelDataHandle.Free();
      }
    }

    // The interpreter handle is zeroed on disposal; guard every native call so a
    // disposed instance throws instead of handing IntPtr.Zero to the C API.
    private void ThrowIfDisposed() {
      if (interpreter == IntPtr.Zero) {
        throw new ObjectDisposedException(GetType().FullName);
      }
    }

    // Rejects indices outside [0, count) on the managed side so an invalid index
    // never reaches native code. NOTE(review): the C API is not known to
    // bounds-check these indices itself.
    private static void CheckTensorIndex(int index, int count, string paramName) {
      if (index < 0 || index >= count) {
        throw new ArgumentOutOfRangeException(
            paramName, index, "Tensor index must be in [0, " + count + ").");
      }
    }

    // Native calls return a status code where 0 means success.
    private static void ThrowIfError(int resultCode) {
      if (resultCode != 0) {
        throw new Exception("TensorFlowLite operation failed.");
      }
    }

    #region Externs

    // The C API is plain C (cdecl) and declares its buffer sizes as size_t, which
    // is marshaled as the pointer-sized UIntPtr.

    [DllImport (TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern TFL_Model TFL_NewModel(IntPtr model_data, UIntPtr model_size);

    [DllImport (TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern void TFL_DeleteModel(TFL_Model model);

    [DllImport (TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern TFL_Interpreter TFL_NewInterpreter(
        TFL_Model model,
        TFL_InterpreterOptions optional_options);

    [DllImport (TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern void TFL_DeleteInterpreter(TFL_Interpreter interpreter);

    [DllImport (TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern int TFL_InterpreterGetInputTensorCount(
        TFL_Interpreter interpreter);

    [DllImport (TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern TFL_Tensor TFL_InterpreterGetInputTensor(
        TFL_Interpreter interpreter,
        int input_index);

    [DllImport (TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern int TFL_InterpreterResizeInputTensor(
        TFL_Interpreter interpreter,
        int input_index,
        int[] input_dims,
        int input_dims_size);

    [DllImport (TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern int TFL_InterpreterAllocateTensors(
        TFL_Interpreter interpreter);

    [DllImport (TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern int TFL_InterpreterInvoke(TFL_Interpreter interpreter);

    [DllImport (TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern int TFL_InterpreterGetOutputTensorCount(
        TFL_Interpreter interpreter);

    [DllImport (TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern TFL_Tensor TFL_InterpreterGetOutputTensor(
        TFL_Interpreter interpreter,
        int output_index);

    [DllImport (TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern int TFL_TensorCopyFromBuffer(
        TFL_Tensor tensor,
        IntPtr input_data,
        UIntPtr input_data_size);

    [DllImport (TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern int TFL_TensorCopyToBuffer(
        TFL_Tensor tensor,
        IntPtr output_data,
        UIntPtr output_data_size);

    #endregion
  }
}
159