1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.lite.task.core; 17 18 import android.content.Context; 19 import android.content.res.AssetFileDescriptor; 20 import android.util.Log; 21 import java.io.FileInputStream; 22 import java.io.IOException; 23 import java.nio.ByteBuffer; 24 import java.nio.MappedByteBuffer; 25 import java.nio.channels.FileChannel; 26 27 /** JNI utils for Task API. */ 28 public class TaskJniUtils { 29 public static final long INVALID_POINTER = 0; 30 private static final String TAG = TaskJniUtils.class.getSimpleName(); 31 /** Syntax sugar to get nativeHandle from empty param list. */ 32 public interface EmptyHandleProvider { createHandle()33 long createHandle(); 34 } 35 36 /** Syntax sugar to get nativeHandle from an array of {@link ByteBuffer}s. */ 37 public interface MultipleBuffersHandleProvider { createHandle(ByteBuffer... buffers)38 long createHandle(ByteBuffer... buffers); 39 } 40 41 /** Syntax sugar to get nativeHandle from file descriptor and options. */ 42 public interface FdAndOptionsHandleProvider<T> { createHandle( int fileDescriptor, long fileDescriptorLength, long fileDescriptorOffset, T options)43 long createHandle( 44 int fileDescriptor, long fileDescriptorLength, long fileDescriptorOffset, T options); 45 } 46 47 /** 48 * Initializes the JNI and returns C++ handle with file descriptor and options for task API. 49 * 50 * @param context the Android app context 51 * @param provider provider to get C++ handle, usually returned from native call 52 * @param libName name of C++ lib to be loaded 53 * @param filePath path of the file to be loaded 54 * @param options options to set up the task API, used by the provider 55 * @return C++ handle as long 56 * @throws IOException If model file fails to load. 57 */ createHandleFromFdAndOptions( Context context, final FdAndOptionsHandleProvider<T> provider, String libName, String filePath, final T options)58 public static <T> long createHandleFromFdAndOptions( 59 Context context, 60 final FdAndOptionsHandleProvider<T> provider, 61 String libName, 62 String filePath, 63 final T options) 64 throws IOException { 65 try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(filePath)) { 66 return createHandleFromLibrary( 67 new EmptyHandleProvider() { 68 @Override 69 public long createHandle() { 70 return provider.createHandle( 71 /*fileDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(), 72 /*fileDescriptorLength=*/ assetFileDescriptor.getLength(), 73 /*fileDescriptorOffset=*/ assetFileDescriptor.getStartOffset(), 74 options); 75 } 76 }, 77 libName); 78 } 79 } 80 81 /** 82 * Initializes the JNI and returns C++ handle by first loading the C++ library and then invokes 83 * {@link EmptyHandleProvider#createHandle()}. 84 * 85 * @param provider provider to get C++ handle, usually returned from native call 86 * @return C++ handle as long 87 */ 88 public static long createHandleFromLibrary(EmptyHandleProvider provider, String libName) { 89 tryLoadLibrary(libName); 90 try { 91 return provider.createHandle(); 92 } catch (Exception e) { 93 String errorMessage = "Error getting native address of native library: " + libName; 94 Log.e(TAG, errorMessage, e); 95 throw new IllegalStateException(errorMessage, e); 96 } 97 } 98 99 /** 100 * Initializes the JNI and returns C++ handle by first loading the C++ library and then invokes 101 * {@link MultipleBuffersHandleProvider#createHandle(ByteBuffer...)}. 102 * 103 * @param context app context 104 * @param provider provider to get C++ pointer, usually returned from native call 105 * @param libName name of C++ lib to load 106 * @param filePaths file paths to load 107 * @return C++ pointer as long 108 * @throws IOException If model file fails to load. 109 */ 110 public static long createHandleWithMultipleAssetFilesFromLibrary( 111 Context context, 112 final MultipleBuffersHandleProvider provider, 113 String libName, 114 String... filePaths) 115 throws IOException { 116 final MappedByteBuffer[] buffers = new MappedByteBuffer[filePaths.length]; 117 for (int i = 0; i < filePaths.length; i++) { 118 buffers[i] = loadMappedFile(context, filePaths[i]); 119 } 120 return createHandleFromLibrary( 121 new EmptyHandleProvider() { 122 @Override 123 public long createHandle() { 124 return provider.createHandle(buffers); 125 } 126 }, 127 libName); 128 } 129 130 /** 131 * Loads a file from the asset folder through memory mapping. 132 * 133 * @param context Application context to access assets. 134 * @param filePath Asset path of the file. 135 * @return the loaded memory mapped file. 136 * @throws IOException If model file fails to load. 137 */ 138 public static MappedByteBuffer loadMappedFile(Context context, String filePath) 139 throws IOException { 140 try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath); 141 FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) { 142 FileChannel fileChannel = inputStream.getChannel(); 143 long startOffset = fileDescriptor.getStartOffset(); 144 long declaredLength = fileDescriptor.getDeclaredLength(); 145 return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); 146 } 147 } 148 149 private TaskJniUtils() {} 150 151 /** 152 * Try load a native library, if it's already loaded return directly. 153 * 154 * @param libName name of the lib 155 */ 156 static void tryLoadLibrary(String libName) { 157 try { 158 System.loadLibrary(libName); 159 } catch (UnsatisfiedLinkError e) { 160 String errorMessage = "Error loading native library: " + libName; 161 Log.e(TAG, errorMessage, e); 162 throw new UnsatisfiedLinkError(errorMessage); 163 } 164 } 165 } 166