/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
using System;
using System.Runtime.InteropServices;

using TFL_Interpreter = System.IntPtr;
using TFL_InterpreterOptions = System.IntPtr;
using TFL_Model = System.IntPtr;
using TFL_Tensor = System.IntPtr;

namespace TensorFlowLite
{
  /// <summary>
  /// Simple C# bindings for the experimental TensorFlowLite C API.
  /// </summary>
  /// <remarks>
  /// Wraps a native model handle and a native interpreter handle. Both are
  /// released by <see cref="Dispose"/> (or, as a fallback, by the finalizer).
  /// </remarks>
  public class Interpreter : IDisposable
  {
    private const string TensorFlowLibrary = "tensorflowlite_c";

    private TFL_Model model;
    private TFL_Interpreter interpreter;

    // Pin on the model bytes. The native model is created from the raw
    // pointer, so the buffer is kept pinned until the model is deleted.
    // NOTE(review): assumes the native side may keep reading the buffer
    // (i.e. does not copy it) - confirm against the C API documentation.
    private GCHandle modelDataHandle;

    /// <summary>
    /// Creates a native model from <paramref name="modelData"/> and an
    /// interpreter for it using default options.
    /// </summary>
    /// <param name="modelData">Serialized model bytes.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="modelData"/> is null.
    /// </exception>
    /// <exception cref="Exception">
    /// The native model or interpreter could not be created.
    /// </exception>
    public Interpreter(byte[] modelData) {
      if (modelData == null) throw new ArgumentNullException("modelData");

      modelDataHandle = GCHandle.Alloc(modelData, GCHandleType.Pinned);
      try {
        IntPtr modelDataPtr = modelDataHandle.AddrOfPinnedObject();
        model = TFL_NewModel(modelDataPtr, modelData.Length);
        if (model == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Model");
        interpreter = TFL_NewInterpreter(model, /*options=*/IntPtr.Zero);
        if (interpreter == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Interpreter");
      } catch {
        // Do not leave a half-built object holding native memory or a pin
        // until the finalizer eventually runs.
        ReleaseResources();
        GC.SuppressFinalize(this);
        throw;
      }
    }

    ~Interpreter() {
      ReleaseResources();
    }

    /// <summary>
    /// Deletes the native interpreter and model and unpins the model bytes.
    /// Safe to call more than once.
    /// </summary>
    public void Dispose() {
      ReleaseResources();
      // Everything is already released; the finalizer has nothing left to do.
      GC.SuppressFinalize(this);
    }

    /// <summary>Runs the interpreter.</summary>
    /// <exception cref="Exception">The native call reported an error.</exception>
    public void Invoke() {
      ThrowIfDisposed();
      ThrowIfError(TFL_InterpreterInvoke(interpreter));
    }

    /// <summary>Returns the input tensor count reported by the native interpreter.</summary>
    public int GetInputTensorCount() {
      ThrowIfDisposed();
      return TFL_InterpreterGetInputTensorCount(interpreter);
    }

    /// <summary>
    /// Copies the raw bytes of <paramref name="inputTensorData"/> into the
    /// input tensor at <paramref name="inputTensorIndex"/>.
    /// </summary>
    /// <param name="inputTensorIndex">Index of the input tensor.</param>
    /// <param name="inputTensorData">
    /// Array of primitives; its full byte length is passed to the native copy.
    /// </param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="inputTensorData"/> is null.
    /// </exception>
    /// <exception cref="Exception">The native copy reported an error.</exception>
    public void SetInputTensorData(int inputTensorIndex, Array inputTensorData) {
      if (inputTensorData == null) throw new ArgumentNullException("inputTensorData");
      ThrowIfDisposed();

      GCHandle tensorDataHandle = GCHandle.Alloc(inputTensorData, GCHandleType.Pinned);
      try {
        IntPtr tensorDataPtr = tensorDataHandle.AddrOfPinnedObject();
        TFL_Tensor tensor = TFL_InterpreterGetInputTensor(interpreter, inputTensorIndex);
        ThrowIfError(TFL_TensorCopyFromBuffer(
            tensor, tensorDataPtr, Buffer.ByteLength(inputTensorData)));
      } finally {
        // The pin is only needed for the duration of the native copy.
        tensorDataHandle.Free();
      }
    }

    /// <summary>
    /// Resizes the input tensor at <paramref name="inputTensorIndex"/> to
    /// <paramref name="inputTensorShape"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="inputTensorShape"/> is null.
    /// </exception>
    /// <exception cref="Exception">The native call reported an error.</exception>
    public void ResizeInputTensor(int inputTensorIndex, int[] inputTensorShape) {
      if (inputTensorShape == null) throw new ArgumentNullException("inputTensorShape");
      ThrowIfDisposed();
      ThrowIfError(TFL_InterpreterResizeInputTensor(
          interpreter, inputTensorIndex, inputTensorShape, inputTensorShape.Length));
    }

    /// <summary>Asks the native interpreter to allocate its tensors.</summary>
    /// <exception cref="Exception">The native call reported an error.</exception>
    public void AllocateTensors() {
      ThrowIfDisposed();
      ThrowIfError(TFL_InterpreterAllocateTensors(interpreter));
    }

    /// <summary>Returns the output tensor count reported by the native interpreter.</summary>
    public int GetOutputTensorCount() {
      ThrowIfDisposed();
      return TFL_InterpreterGetOutputTensorCount(interpreter);
    }

    /// <summary>
    /// Copies the output tensor at <paramref name="outputTensorIndex"/> into
    /// <paramref name="outputTensorData"/>.
    /// </summary>
    /// <param name="outputTensorIndex">Index of the output tensor.</param>
    /// <param name="outputTensorData">
    /// Destination array of primitives; its full byte length is passed to the
    /// native copy.
    /// </param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="outputTensorData"/> is null.
    /// </exception>
    /// <exception cref="Exception">The native copy reported an error.</exception>
    public void GetOutputTensorData(int outputTensorIndex, Array outputTensorData) {
      if (outputTensorData == null) throw new ArgumentNullException("outputTensorData");
      ThrowIfDisposed();

      GCHandle tensorDataHandle = GCHandle.Alloc(outputTensorData, GCHandleType.Pinned);
      try {
        IntPtr tensorDataPtr = tensorDataHandle.AddrOfPinnedObject();
        TFL_Tensor tensor = TFL_InterpreterGetOutputTensor(interpreter, outputTensorIndex);
        ThrowIfError(TFL_TensorCopyToBuffer(
            tensor, tensorDataPtr, Buffer.ByteLength(outputTensorData)));
      } finally {
        // The pin is only needed for the duration of the native copy.
        tensorDataHandle.Free();
      }
    }

    // Releases the native handles and the model-data pin. Idempotent; shared
    // by Dispose, the finalizer and the constructor's failure path.
    private void ReleaseResources() {
      if (interpreter != IntPtr.Zero) TFL_DeleteInterpreter(interpreter);
      interpreter = IntPtr.Zero;
      if (model != IntPtr.Zero) TFL_DeleteModel(model);
      model = IntPtr.Zero;
      // Unpin only after the model that was built from the buffer is gone.
      if (modelDataHandle.IsAllocated) modelDataHandle.Free();
    }

    // Guards against handing a null interpreter handle to native code.
    private void ThrowIfDisposed() {
      if (interpreter == IntPtr.Zero) throw new ObjectDisposedException(GetType().FullName);
    }

    // Any non-zero native result code is treated as a failure.
    private static void ThrowIfError(int resultCode) {
      if (resultCode != 0) throw new Exception("TensorFlowLite operation failed.");
    }

    #region Externs

    [DllImport (TensorFlowLibrary)]
    private static extern TFL_Model TFL_NewModel(IntPtr model_data, int model_size);

    [DllImport (TensorFlowLibrary)]
    private static extern void TFL_DeleteModel(TFL_Model model);

    [DllImport (TensorFlowLibrary)]
    private static extern TFL_Interpreter TFL_NewInterpreter(
        TFL_Model model,
        TFL_InterpreterOptions optional_options);

    [DllImport (TensorFlowLibrary)]
    private static extern void TFL_DeleteInterpreter(TFL_Interpreter interpreter);

    [DllImport (TensorFlowLibrary)]
    private static extern int TFL_InterpreterGetInputTensorCount(
        TFL_Interpreter interpreter);

    [DllImport (TensorFlowLibrary)]
    private static extern TFL_Tensor TFL_InterpreterGetInputTensor(
        TFL_Interpreter interpreter,
        int input_index);

    [DllImport (TensorFlowLibrary)]
    private static extern int TFL_InterpreterResizeInputTensor(
        TFL_Interpreter interpreter,
        int input_index,
        int[] input_dims,
        int input_dims_size);

    [DllImport (TensorFlowLibrary)]
    private static extern int TFL_InterpreterAllocateTensors(
        TFL_Interpreter interpreter);

    [DllImport (TensorFlowLibrary)]
    private static extern int TFL_InterpreterInvoke(TFL_Interpreter interpreter);

    [DllImport (TensorFlowLibrary)]
    private static extern int TFL_InterpreterGetOutputTensorCount(
        TFL_Interpreter interpreter);

    [DllImport (TensorFlowLibrary)]
    private static extern TFL_Tensor TFL_InterpreterGetOutputTensor(
        TFL_Interpreter interpreter,
        int output_index);

    [DllImport (TensorFlowLibrary)]
    private static extern int TFL_TensorCopyFromBuffer(
        TFL_Tensor tensor,
        IntPtr input_data,
        int input_data_size);

    [DllImport (TensorFlowLibrary)]
    private static extern int TFL_TensorCopyToBuffer(
        TFL_Tensor tensor,
        IntPtr output_data,
        int output_data_size);

    #endregion
  }
}