• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 using System;
16 using System.Runtime.InteropServices;
17 
18 using TfLiteInterpreter = System.IntPtr;
19 using TfLiteInterpreterOptions = System.IntPtr;
20 using TfLiteModel = System.IntPtr;
21 using TfLiteTensor = System.IntPtr;
22 
namespace TensorFlowLite
{
  /// <summary>
  /// Simple C# bindings for the experimental TensorFlowLite C API.
  /// </summary>
  /// <remarks>
  /// An instance owns three resources: a native TfLiteModel, a native
  /// TfLiteInterpreter created from that model, and a pinned GC handle on the
  /// managed model bytes. All three are released by <see cref="Dispose"/>;
  /// the finalizer is only a safety net for when Dispose is never called.
  /// </remarks>
  public class Interpreter : IDisposable
  {
    private const string TensorFlowLibrary = "tensorflowlite_c";

    private TfLiteModel model;
    private TfLiteInterpreter interpreter;

    // Pins the managed model bytes. TfLiteModelCreate does not copy the buffer
    // it is given, so the array must stay pinned until the native model and
    // interpreter have both been deleted (see ReleaseNativeResources).
    private GCHandle modelDataHandle;

    /// <summary>
    /// Creates a native model and interpreter from a serialized TensorFlowLite model.
    /// </summary>
    /// <param name="modelData">The serialized model. The array is not copied; it stays
    /// pinned until this interpreter is disposed, so do not modify it meanwhile.</param>
    /// <exception cref="ArgumentNullException"><paramref name="modelData"/> is null.</exception>
    /// <exception cref="InvalidOperationException">The native model or interpreter
    /// could not be created.</exception>
    public Interpreter(byte[] modelData) {
      if (modelData == null) {
        throw new ArgumentNullException("modelData");
      }

      modelDataHandle = GCHandle.Alloc(modelData, GCHandleType.Pinned);
      IntPtr modelDataPtr = modelDataHandle.AddrOfPinnedObject();
      model = TfLiteModelCreate(modelDataPtr, new UIntPtr((uint)modelData.Length));
      if (model == IntPtr.Zero) {
        // Release the pin now rather than waiting for the finalizer.
        ReleaseNativeResources();
        throw new InvalidOperationException("Failed to create TensorFlowLite Model");
      }

      interpreter = TfLiteInterpreterCreate(model, /*options=*/IntPtr.Zero);
      if (interpreter == IntPtr.Zero) {
        // Delete the already-created model and unpin the buffer immediately.
        ReleaseNativeResources();
        throw new InvalidOperationException("Failed to create TensorFlowLite Interpreter");
      }
    }

    /// <summary>
    /// Safety net: releases native resources if <see cref="Dispose"/> was never called.
    /// </summary>
    ~Interpreter() {
      ReleaseNativeResources();
    }

    /// <summary>
    /// Deletes the native interpreter and model and unpins the model bytes.
    /// Calling it more than once is harmless.
    /// </summary>
    public void Dispose() {
      ReleaseNativeResources();
      // Everything the finalizer would do has been done.
      GC.SuppressFinalize(this);
    }

    /// <summary>
    /// Runs the model by calling TfLiteInterpreterInvoke.
    /// </summary>
    /// <exception cref="ObjectDisposedException">The interpreter has been disposed.</exception>
    /// <exception cref="InvalidOperationException">The native call returned a non-zero status.</exception>
    public void Invoke() {
      ThrowIfDisposed();
      ThrowIfError(TfLiteInterpreterInvoke(interpreter));
    }

    /// <summary>
    /// Returns the number of input tensors reported by the native interpreter.
    /// </summary>
    /// <exception cref="ObjectDisposedException">The interpreter has been disposed.</exception>
    public int GetInputTensorCount() {
      ThrowIfDisposed();
      return TfLiteInterpreterGetInputTensorCount(interpreter);
    }

    /// <summary>
    /// Copies the raw bytes of <paramref name="inputTensorData"/> into an input tensor.
    /// </summary>
    /// <param name="inputTensorIndex">Input index in [0, <see cref="GetInputTensorCount"/>).</param>
    /// <param name="inputTensorData">Array of a primitive element type; its total size in
    /// bytes (<see cref="Buffer.ByteLength"/>) is the number of bytes copied.</param>
    /// <exception cref="ObjectDisposedException">The interpreter has been disposed.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="inputTensorData"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="inputTensorIndex"/> is out of range.</exception>
    /// <exception cref="InvalidOperationException">The native copy returned a non-zero status.</exception>
    public void SetInputTensorData(int inputTensorIndex, Array inputTensorData) {
      ThrowIfDisposed();
      if (inputTensorData == null) {
        throw new ArgumentNullException("inputTensorData");
      }
      ValidateTensorIndex(
          inputTensorIndex, TfLiteInterpreterGetInputTensorCount(interpreter), "inputTensorIndex");

      // Computed before pinning so a non-primitive array fails without allocating a handle.
      int byteLength = Buffer.ByteLength(inputTensorData);
      TfLiteTensor tensor = TfLiteInterpreterGetInputTensor(interpreter, inputTensorIndex);

      // Pin only for the duration of the copy and always release the handle.
      GCHandle tensorDataHandle = GCHandle.Alloc(inputTensorData, GCHandleType.Pinned);
      try {
        ThrowIfError(TfLiteTensorCopyFromBuffer(
            tensor, tensorDataHandle.AddrOfPinnedObject(), new UIntPtr((uint)byteLength)));
      } finally {
        tensorDataHandle.Free();
      }
    }

    /// <summary>
    /// Changes the dimensions of an input tensor via TfLiteInterpreterResizeInputTensor.
    /// </summary>
    /// <param name="inputTensorIndex">Input index in [0, <see cref="GetInputTensorCount"/>).</param>
    /// <param name="inputTensorShape">The new dimensions, one entry per axis.</param>
    /// <exception cref="ObjectDisposedException">The interpreter has been disposed.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="inputTensorShape"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="inputTensorIndex"/> is out of range.</exception>
    /// <exception cref="InvalidOperationException">The native call returned a non-zero status.</exception>
    public void ResizeInputTensor(int inputTensorIndex, int[] inputTensorShape) {
      ThrowIfDisposed();
      if (inputTensorShape == null) {
        throw new ArgumentNullException("inputTensorShape");
      }
      ValidateTensorIndex(
          inputTensorIndex, TfLiteInterpreterGetInputTensorCount(interpreter), "inputTensorIndex");
      ThrowIfError(TfLiteInterpreterResizeInputTensor(
          interpreter, inputTensorIndex, inputTensorShape, inputTensorShape.Length));
    }

    /// <summary>
    /// Calls TfLiteInterpreterAllocateTensors.
    /// </summary>
    /// <exception cref="ObjectDisposedException">The interpreter has been disposed.</exception>
    /// <exception cref="InvalidOperationException">The native call returned a non-zero status.</exception>
    public void AllocateTensors() {
      ThrowIfDisposed();
      ThrowIfError(TfLiteInterpreterAllocateTensors(interpreter));
    }

    /// <summary>
    /// Returns the number of output tensors reported by the native interpreter.
    /// </summary>
    /// <exception cref="ObjectDisposedException">The interpreter has been disposed.</exception>
    public int GetOutputTensorCount() {
      ThrowIfDisposed();
      return TfLiteInterpreterGetOutputTensorCount(interpreter);
    }

    /// <summary>
    /// Copies the contents of an output tensor into <paramref name="outputTensorData"/>.
    /// </summary>
    /// <param name="outputTensorIndex">Output index in [0, <see cref="GetOutputTensorCount"/>).</param>
    /// <param name="outputTensorData">Destination array of a primitive element type; its total
    /// size in bytes (<see cref="Buffer.ByteLength"/>) is passed to the native copy.</param>
    /// <exception cref="ObjectDisposedException">The interpreter has been disposed.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="outputTensorData"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="outputTensorIndex"/> is out of range.</exception>
    /// <exception cref="InvalidOperationException">The native copy returned a non-zero status.</exception>
    public void GetOutputTensorData(int outputTensorIndex, Array outputTensorData) {
      ThrowIfDisposed();
      if (outputTensorData == null) {
        throw new ArgumentNullException("outputTensorData");
      }
      ValidateTensorIndex(
          outputTensorIndex, TfLiteInterpreterGetOutputTensorCount(interpreter), "outputTensorIndex");

      int byteLength = Buffer.ByteLength(outputTensorData);
      TfLiteTensor tensor = TfLiteInterpreterGetOutputTensor(interpreter, outputTensorIndex);

      GCHandle tensorDataHandle = GCHandle.Alloc(outputTensorData, GCHandleType.Pinned);
      try {
        ThrowIfError(TfLiteTensorCopyToBuffer(
            tensor, tensorDataHandle.AddrOfPinnedObject(), new UIntPtr((uint)byteLength)));
      } finally {
        tensorDataHandle.Free();
      }
    }

    // Deletes the native objects and unpins the model bytes; idempotent because
    // every field is reset after it is released.
    private void ReleaseNativeResources() {
      if (interpreter != IntPtr.Zero) {
        TfLiteInterpreterDelete(interpreter);
        interpreter = IntPtr.Zero;
      }
      if (model != IntPtr.Zero) {
        TfLiteModelDelete(model);
        model = IntPtr.Zero;
      }
      // The buffer backs the native model, so it is unpinned last.
      if (modelDataHandle.IsAllocated) {
        modelDataHandle.Free();
      }
    }

    // Guards every public entry point so a disposed instance never passes a
    // null native handle into the C API.
    private void ThrowIfDisposed() {
      if (interpreter == IntPtr.Zero) {
        throw new ObjectDisposedException(GetType().FullName);
      }
    }

    // Rejects out-of-range tensor indices in managed code rather than relying on
    // the native side to handle them.
    private static void ValidateTensorIndex(int index, int count, string paramName) {
      if (index < 0 || index >= count) {
        throw new ArgumentOutOfRangeException(
            paramName, index, "Tensor index must be in the range [0, " + count + ").");
      }
    }

    // Converts a non-zero TfLiteStatus into an exception carrying the status code.
    private static void ThrowIfError(int resultCode) {
      if (resultCode != 0) {
        throw new InvalidOperationException(
            "TensorFlowLite operation failed (status " + resultCode + ").");
      }
    }

    #region Externs

    // C `size_t` parameters are marshaled as UIntPtr, which is pointer-sized like size_t.

    [DllImport (TensorFlowLibrary)]
    private static extern TfLiteModel TfLiteModelCreate(IntPtr model_data, UIntPtr model_size);

    [DllImport (TensorFlowLibrary)]
    private static extern void TfLiteModelDelete(TfLiteModel model);

    [DllImport (TensorFlowLibrary)]
    private static extern TfLiteInterpreter TfLiteInterpreterCreate(
        TfLiteModel model,
        TfLiteInterpreterOptions optional_options);

    [DllImport (TensorFlowLibrary)]
    private static extern void TfLiteInterpreterDelete(TfLiteInterpreter interpreter);

    [DllImport (TensorFlowLibrary)]
    private static extern int TfLiteInterpreterGetInputTensorCount(
        TfLiteInterpreter interpreter);

    [DllImport (TensorFlowLibrary)]
    private static extern TfLiteTensor TfLiteInterpreterGetInputTensor(
        TfLiteInterpreter interpreter,
        int input_index);

    [DllImport (TensorFlowLibrary)]
    private static extern int TfLiteInterpreterResizeInputTensor(
        TfLiteInterpreter interpreter,
        int input_index,
        int[] input_dims,
        int input_dims_size);

    [DllImport (TensorFlowLibrary)]
    private static extern int TfLiteInterpreterAllocateTensors(
        TfLiteInterpreter interpreter);

    [DllImport (TensorFlowLibrary)]
    private static extern int TfLiteInterpreterInvoke(TfLiteInterpreter interpreter);

    [DllImport (TensorFlowLibrary)]
    private static extern int TfLiteInterpreterGetOutputTensorCount(
        TfLiteInterpreter interpreter);

    [DllImport (TensorFlowLibrary)]
    private static extern TfLiteTensor TfLiteInterpreterGetOutputTensor(
        TfLiteInterpreter interpreter,
        int output_index);

    [DllImport (TensorFlowLibrary)]
    private static extern int TfLiteTensorCopyFromBuffer(
        TfLiteTensor tensor,
        IntPtr input_data,
        UIntPtr input_data_size);

    [DllImport (TensorFlowLibrary)]
    private static extern int TfLiteTensorCopyToBuffer(
        TfLiteTensor tensor,
        IntPtr output_data,
        UIntPtr output_data_size);

    #endregion
  }
}
159