• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 package org.tensorflow.lite.task.core;
17 
18 import android.content.Context;
19 import android.content.res.AssetFileDescriptor;
20 import android.util.Log;
21 import java.io.FileInputStream;
22 import java.io.IOException;
23 import java.nio.ByteBuffer;
24 import java.nio.MappedByteBuffer;
25 import java.nio.channels.FileChannel;
26 
27 /** JNI utils for Task API. */
28 public class TaskJniUtils {
29   public static final long INVALID_POINTER = 0;
30   private static final String TAG = TaskJniUtils.class.getSimpleName();
31   /** Syntax sugar to get nativeHandle from empty param list. */
32   public interface EmptyHandleProvider {
createHandle()33     long createHandle();
34   }
35 
36   /** Syntax sugar to get nativeHandle from an array of {@link ByteBuffer}s. */
37   public interface MultipleBuffersHandleProvider {
createHandle(ByteBuffer... buffers)38     long createHandle(ByteBuffer... buffers);
39   }
40 
41   /** Syntax sugar to get nativeHandle from file descriptor and options. */
42   public interface FdAndOptionsHandleProvider<T> {
createHandle( int fileDescriptor, long fileDescriptorLength, long fileDescriptorOffset, T options)43     long createHandle(
44         int fileDescriptor, long fileDescriptorLength, long fileDescriptorOffset, T options);
45   }
46 
47   /**
48    * Initializes the JNI and returns C++ handle with file descriptor and options for task API.
49    *
50    * @param context the Android app context
51    * @param provider provider to get C++ handle, usually returned from native call
52    * @param libName name of C++ lib to be loaded
53    * @param filePath path of the file to be loaded
54    * @param options options to set up the task API, used by the provider
55    * @return C++ handle as long
56    * @throws IOException If model file fails to load.
57    */
createHandleFromFdAndOptions( Context context, final FdAndOptionsHandleProvider<T> provider, String libName, String filePath, final T options)58   public static <T> long createHandleFromFdAndOptions(
59       Context context,
60       final FdAndOptionsHandleProvider<T> provider,
61       String libName,
62       String filePath,
63       final T options)
64       throws IOException {
65     try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(filePath)) {
66       return createHandleFromLibrary(
67           new EmptyHandleProvider() {
68             @Override
69             public long createHandle() {
70               return provider.createHandle(
71                   /*fileDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(),
72                   /*fileDescriptorLength=*/ assetFileDescriptor.getLength(),
73                   /*fileDescriptorOffset=*/ assetFileDescriptor.getStartOffset(),
74                   options);
75             }
76           },
77           libName);
78     }
79   }
80 
81   /**
82    * Initializes the JNI and returns C++ handle by first loading the C++ library and then invokes
83    * {@link EmptyHandleProvider#createHandle()}.
84    *
85    * @param provider provider to get C++ handle, usually returned from native call
86    * @return C++ handle as long
87    */
88   public static long createHandleFromLibrary(EmptyHandleProvider provider, String libName) {
89     tryLoadLibrary(libName);
90     try {
91       return provider.createHandle();
92     } catch (Exception e) {
93       String errorMessage = "Error getting native address of native library: " + libName;
94       Log.e(TAG, errorMessage, e);
95       throw new IllegalStateException(errorMessage, e);
96     }
97   }
98 
99   /**
100    * Initializes the JNI and returns C++ handle by first loading the C++ library and then invokes
101    * {@link MultipleBuffersHandleProvider#createHandle(ByteBuffer...)}.
102    *
103    * @param context app context
104    * @param provider provider to get C++ pointer, usually returned from native call
105    * @param libName name of C++ lib to load
106    * @param filePaths file paths to load
107    * @return C++ pointer as long
108    * @throws IOException If model file fails to load.
109    */
110   public static long createHandleWithMultipleAssetFilesFromLibrary(
111       Context context,
112       final MultipleBuffersHandleProvider provider,
113       String libName,
114       String... filePaths)
115       throws IOException {
116     final MappedByteBuffer[] buffers = new MappedByteBuffer[filePaths.length];
117     for (int i = 0; i < filePaths.length; i++) {
118       buffers[i] = loadMappedFile(context, filePaths[i]);
119     }
120     return createHandleFromLibrary(
121         new EmptyHandleProvider() {
122           @Override
123           public long createHandle() {
124             return provider.createHandle(buffers);
125           }
126         },
127         libName);
128   }
129 
130   /**
131    * Loads a file from the asset folder through memory mapping.
132    *
133    * @param context Application context to access assets.
134    * @param filePath Asset path of the file.
135    * @return the loaded memory mapped file.
136    * @throws IOException If model file fails to load.
137    */
138   public static MappedByteBuffer loadMappedFile(Context context, String filePath)
139       throws IOException {
140     try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath);
141         FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
142       FileChannel fileChannel = inputStream.getChannel();
143       long startOffset = fileDescriptor.getStartOffset();
144       long declaredLength = fileDescriptor.getDeclaredLength();
145       return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
146     }
147   }
148 
149   private TaskJniUtils() {}
150 
151   /**
152    * Try load a native library, if it's already loaded return directly.
153    *
154    * @param libName name of the lib
155    */
156   static void tryLoadLibrary(String libName) {
157     try {
158       System.loadLibrary(libName);
159     } catch (UnsatisfiedLinkError e) {
160       String errorMessage = "Error loading native library: " + libName;
161       Log.e(TAG, errorMessage, e);
162       throw new UnsatisfiedLinkError(errorMessage);
163     }
164   }
165 }
166