/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
using System;
using System.Runtime.InteropServices;

using TfLiteInterpreter = System.IntPtr;
using TfLiteInterpreterOptions = System.IntPtr;
using TfLiteModel = System.IntPtr;
using TfLiteTensor = System.IntPtr;

namespace TensorFlowLite
{
  /// <summary>
  /// Simple C# bindings for the experimental TensorFlowLite C API.
  /// </summary>
  /// <remarks>
  /// An instance owns a native model and interpreter built from a managed model
  /// buffer. Call <see cref="Dispose()"/> when finished to delete the native
  /// objects and unpin the model buffer; the finalizer performs the same cleanup
  /// as a fallback if Dispose is never called.
  /// </remarks>
  public class Interpreter : IDisposable
  {
    private const string TensorFlowLibrary = "tensorflowlite_c";

    private TfLiteModel model;
    private TfLiteInterpreter interpreter;

    // Keeps the caller's model bytes pinned for as long as the native model
    // exists. The TensorFlow Lite C API requires the buffer handed to
    // TfLiteModelCreate to outlive the model, so the pin must not be released
    // right after construction. Freed in Dispose(bool).
    private GCHandle modelDataHandle;

    /// <summary>
    /// Creates a native model from <paramref name="modelData"/> and an
    /// interpreter for it (with default, i.e. null, interpreter options).
    /// </summary>
    /// <param name="modelData">Serialized model bytes. The array is pinned until
    /// this instance is disposed.</param>
    /// <exception cref="ArgumentNullException"><paramref name="modelData"/> is null.</exception>
    /// <exception cref="InvalidOperationException">The native model or interpreter
    /// could not be created.</exception>
    public Interpreter(byte[] modelData) {
      if (modelData == null) {
        throw new ArgumentNullException("modelData");
      }

      modelDataHandle = GCHandle.Alloc(modelData, GCHandleType.Pinned);
      try {
        IntPtr modelDataPtr = modelDataHandle.AddrOfPinnedObject();
        model = TfLiteModelCreate(modelDataPtr, new UIntPtr((uint)modelData.Length));
        if (model == IntPtr.Zero) {
          throw new InvalidOperationException("Failed to create TensorFlowLite Model");
        }
        interpreter = TfLiteInterpreterCreate(model, /*optional_options=*/IntPtr.Zero);
        if (interpreter == IntPtr.Zero) {
          throw new InvalidOperationException("Failed to create TensorFlowLite Interpreter");
        }
      } catch {
        // Release whatever was acquired before the failure (the pin and, if it
        // was created, the native model) instead of leaving it to the finalizer.
        Dispose();
        throw;
      }
    }

    /// <summary>Finalizer fallback: releases native objects and the model pin.</summary>
    ~Interpreter() {
      Dispose(false);
    }

    /// <summary>
    /// Deletes the native interpreter and model and unpins the model buffer.
    /// Safe to call more than once.
    /// </summary>
    public void Dispose() {
      Dispose(true);
      GC.SuppressFinalize(this);
    }

    /// <summary>
    /// Releases the native interpreter, the native model and the model-buffer pin.
    /// </summary>
    /// <param name="disposing">True when called from <see cref="Dispose()"/>,
    /// false when called from the finalizer. Everything released here is a native
    /// pointer or a GC handle, so the same cleanup is valid on both paths.</param>
    protected virtual void Dispose(bool disposing) {
      // The interpreter is created from the model, so delete it first.
      if (interpreter != IntPtr.Zero) {
        TfLiteInterpreterDelete(interpreter);
        interpreter = IntPtr.Zero;
      }
      if (model != IntPtr.Zero) {
        TfLiteModelDelete(model);
        model = IntPtr.Zero;
      }
      // Unpin only after the native model is gone, since it may still
      // reference the buffer until then.
      if (modelDataHandle.IsAllocated) {
        modelDataHandle.Free();
      }
    }

    /// <summary>Invokes the interpreter (TfLiteInterpreterInvoke).</summary>
    /// <exception cref="ObjectDisposedException">The instance has been disposed.</exception>
    /// <exception cref="InvalidOperationException">The native call returned a non-zero status.</exception>
    public void Invoke() {
      ThrowIfDisposed();
      ThrowIfError(TfLiteInterpreterInvoke(interpreter));
    }

    /// <summary>Returns the number of input tensors reported by the interpreter.</summary>
    /// <exception cref="ObjectDisposedException">The instance has been disposed.</exception>
    public int GetInputTensorCount() {
      ThrowIfDisposed();
      return TfLiteInterpreterGetInputTensorCount(interpreter);
    }

    /// <summary>
    /// Copies the raw bytes of <paramref name="inputTensorData"/> into the input
    /// tensor at <paramref name="inputTensorIndex"/>.
    /// </summary>
    /// <param name="inputTensorIndex">Index in [0, GetInputTensorCount()).</param>
    /// <param name="inputTensorData">Array of primitive elements; its total byte
    /// length (Buffer.ByteLength) is passed as the copy size.</param>
    /// <exception cref="ArgumentNullException"><paramref name="inputTensorData"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">The index is out of range.</exception>
    /// <exception cref="ObjectDisposedException">The instance has been disposed.</exception>
    /// <exception cref="InvalidOperationException">The native copy returned a non-zero
    /// status (per the C API docs, e.g. when the byte sizes do not match).</exception>
    public void SetInputTensorData(int inputTensorIndex, Array inputTensorData) {
      if (inputTensorData == null) {
        throw new ArgumentNullException("inputTensorData");
      }
      CheckIndex(inputTensorIndex, GetInputTensorCount(), "inputTensorIndex");
      TfLiteTensor tensor = TfLiteInterpreterGetInputTensor(interpreter, inputTensorIndex);
      if (tensor == IntPtr.Zero) {
        throw new InvalidOperationException("Failed to get TensorFlowLite input tensor");
      }

      // The pin only has to span the copy call; always release it afterwards.
      GCHandle tensorDataHandle = GCHandle.Alloc(inputTensorData, GCHandleType.Pinned);
      try {
        ThrowIfError(TfLiteTensorCopyFromBuffer(
            tensor,
            tensorDataHandle.AddrOfPinnedObject(),
            new UIntPtr((uint)Buffer.ByteLength(inputTensorData))));
      } finally {
        tensorDataHandle.Free();
      }
    }

    /// <summary>
    /// Resizes the input tensor at <paramref name="inputTensorIndex"/> to
    /// <paramref name="inputTensorShape"/> (TfLiteInterpreterResizeInputTensor).
    /// Per the C API docs, AllocateTensors should be called afterwards.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="inputTensorShape"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">The index is out of range.</exception>
    /// <exception cref="ObjectDisposedException">The instance has been disposed.</exception>
    /// <exception cref="InvalidOperationException">The native call returned a non-zero status.</exception>
    public void ResizeInputTensor(int inputTensorIndex, int[] inputTensorShape) {
      if (inputTensorShape == null) {
        throw new ArgumentNullException("inputTensorShape");
      }
      CheckIndex(inputTensorIndex, GetInputTensorCount(), "inputTensorIndex");
      ThrowIfError(TfLiteInterpreterResizeInputTensor(
          interpreter, inputTensorIndex, inputTensorShape, inputTensorShape.Length));
    }

    /// <summary>Allocates tensor buffers (TfLiteInterpreterAllocateTensors).</summary>
    /// <exception cref="ObjectDisposedException">The instance has been disposed.</exception>
    /// <exception cref="InvalidOperationException">The native call returned a non-zero status.</exception>
    public void AllocateTensors() {
      ThrowIfDisposed();
      ThrowIfError(TfLiteInterpreterAllocateTensors(interpreter));
    }

    /// <summary>Returns the number of output tensors reported by the interpreter.</summary>
    /// <exception cref="ObjectDisposedException">The instance has been disposed.</exception>
    public int GetOutputTensorCount() {
      ThrowIfDisposed();
      return TfLiteInterpreterGetOutputTensorCount(interpreter);
    }

    /// <summary>
    /// Copies the raw bytes of the output tensor at <paramref name="outputTensorIndex"/>
    /// into <paramref name="outputTensorData"/>.
    /// </summary>
    /// <param name="outputTensorIndex">Index in [0, GetOutputTensorCount()).</param>
    /// <param name="outputTensorData">Destination array of primitive elements; its
    /// total byte length (Buffer.ByteLength) is passed as the copy size.</param>
    /// <exception cref="ArgumentNullException"><paramref name="outputTensorData"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">The index is out of range.</exception>
    /// <exception cref="ObjectDisposedException">The instance has been disposed.</exception>
    /// <exception cref="InvalidOperationException">The native copy returned a non-zero
    /// status (per the C API docs, e.g. when the byte sizes do not match).</exception>
    public void GetOutputTensorData(int outputTensorIndex, Array outputTensorData) {
      if (outputTensorData == null) {
        throw new ArgumentNullException("outputTensorData");
      }
      CheckIndex(outputTensorIndex, GetOutputTensorCount(), "outputTensorIndex");
      TfLiteTensor tensor = TfLiteInterpreterGetOutputTensor(interpreter, outputTensorIndex);
      if (tensor == IntPtr.Zero) {
        throw new InvalidOperationException("Failed to get TensorFlowLite output tensor");
      }

      // The pin only has to span the copy call; always release it afterwards.
      GCHandle tensorDataHandle = GCHandle.Alloc(outputTensorData, GCHandleType.Pinned);
      try {
        ThrowIfError(TfLiteTensorCopyToBuffer(
            tensor,
            tensorDataHandle.AddrOfPinnedObject(),
            new UIntPtr((uint)Buffer.ByteLength(outputTensorData))));
      } finally {
        tensorDataHandle.Free();
      }
    }

    // Throws if Dispose has run (or construction failed), so native functions
    // are never called with a null interpreter pointer.
    private void ThrowIfDisposed() {
      if (interpreter == IntPtr.Zero) {
        throw new ObjectDisposedException(GetType().FullName);
      }
    }

    // Validates a tensor index on the managed side so a bad index surfaces as an
    // ArgumentOutOfRangeException instead of being passed to native code.
    private static void CheckIndex(int index, int count, string paramName) {
      if (index < 0 || index >= count) {
        throw new ArgumentOutOfRangeException(paramName);
      }
    }

    // Maps a non-zero native status code to a managed exception.
    private static void ThrowIfError(int resultCode) {
      if (resultCode != 0) {
        throw new InvalidOperationException("TensorFlowLite operation failed.");
      }
    }

    #region Externs

    // Signatures follow the TensorFlow Lite C API: size_t parameters are
    // marshalled as UIntPtr, void functions return void, and the exported
    // functions use the C (cdecl) calling convention.

    [DllImport(TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern TfLiteModel TfLiteModelCreate(IntPtr model_data, UIntPtr model_size);

    [DllImport(TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern void TfLiteModelDelete(TfLiteModel model);

    [DllImport(TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern TfLiteInterpreter TfLiteInterpreterCreate(
        TfLiteModel model,
        TfLiteInterpreterOptions optional_options);

    [DllImport(TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern void TfLiteInterpreterDelete(TfLiteInterpreter interpreter);

    [DllImport(TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern int TfLiteInterpreterGetInputTensorCount(
        TfLiteInterpreter interpreter);

    [DllImport(TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern TfLiteTensor TfLiteInterpreterGetInputTensor(
        TfLiteInterpreter interpreter,
        int input_index);

    [DllImport(TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern int TfLiteInterpreterResizeInputTensor(
        TfLiteInterpreter interpreter,
        int input_index,
        int[] input_dims,
        int input_dims_size);

    [DllImport(TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern int TfLiteInterpreterAllocateTensors(
        TfLiteInterpreter interpreter);

    [DllImport(TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern int TfLiteInterpreterInvoke(TfLiteInterpreter interpreter);

    [DllImport(TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern int TfLiteInterpreterGetOutputTensorCount(
        TfLiteInterpreter interpreter);

    [DllImport(TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern TfLiteTensor TfLiteInterpreterGetOutputTensor(
        TfLiteInterpreter interpreter,
        int output_index);

    [DllImport(TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern int TfLiteTensorCopyFromBuffer(
        TfLiteTensor tensor,
        IntPtr input_data,
        UIntPtr input_data_size);

    [DllImport(TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    private static extern int TfLiteTensorCopyToBuffer(
        TfLiteTensor tensor,
        IntPtr output_data,
        UIntPtr output_data_size);

    #endregion
  }
}