• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow.lite.gpu;
17 
18 import java.lang.reflect.Constructor;
19 import java.lang.reflect.InvocationTargetException;
20 import org.tensorflow.lite.Delegate;
21 import org.tensorflow.lite.DelegateFactory;
22 import org.tensorflow.lite.RuntimeFlavor;
23 
24 /** {@link DelegateFactory} for creating a {@link GpuDelegate}. */
25 public class GpuDelegateFactory implements DelegateFactory {
26 
27   private static final String GPU_DELEGATE_CLASS_NAME = "GpuDelegate";
28 
29   private final Options options;
30 
31   /** Delegate options. */
32   public static class Options {
Options()33     public Options() {}
34 
35     /**
36      * Delegate will be used only once, therefore, bootstrap/init time should be taken into account.
37      */
38     public static final int INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER = 0;
39 
40     /**
41      * Prefer maximizing the throughput. Same delegate will be used repeatedly on multiple inputs.
42      */
43     public static final int INFERENCE_PREFERENCE_SUSTAINED_SPEED = 1;
44 
45     /**
46      * Sets whether precision loss is allowed.
47      *
48      * @param precisionLossAllowed When `true` (default), the GPU may quantify tensors, downcast
49      *     values, process in FP16. When `false`, computations are carried out in 32-bit floating
50      *     point.
51      */
setPrecisionLossAllowed(boolean precisionLossAllowed)52     public Options setPrecisionLossAllowed(boolean precisionLossAllowed) {
53       this.precisionLossAllowed = precisionLossAllowed;
54       return this;
55     }
56 
57     /**
58      * Enables running quantized models with the delegate.
59      *
60      * <p>WARNING: This is an experimental API and subject to change.
61      *
62      * @param quantizedModelsAllowed When {@code true} (default), the GPU may run quantized models.
63      */
setQuantizedModelsAllowed(boolean quantizedModelsAllowed)64     public Options setQuantizedModelsAllowed(boolean quantizedModelsAllowed) {
65       this.quantizedModelsAllowed = quantizedModelsAllowed;
66       return this;
67     }
68 
69     /**
70      * Sets the inference preference for precision/compilation/runtime tradeoffs.
71      *
72      * @param preference One of `INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER` (default),
73      *     `INFERENCE_PREFERENCE_SUSTAINED_SPEED`.
74      */
setInferencePreference(int preference)75     public Options setInferencePreference(int preference) {
76       this.inferencePreference = preference;
77       return this;
78     }
79 
80     /**
81      * Enables serialization on the delegate. Note non-null {@code serializationDir} and {@code
82      * modelToken} are required for serialization.
83      *
84      * <p>WARNING: This is an experimental API and subject to change.
85      *
86      * @param serializationDir The directory to use for storing data. Caller is responsible to
87      *     ensure the model is not stored in a public directory. It's recommended to use {@link
88      *     android.content.Context#getCodeCacheDir()} to provide a private location for the
89      *     application on Android.
90      * @param modelToken The token to be used to identify the model. Caller is responsible to ensure
91      *     the token is unique to the model graph and data.
92      */
setSerializationParams(String serializationDir, String modelToken)93     public Options setSerializationParams(String serializationDir, String modelToken) {
94       this.serializationDir = serializationDir;
95       this.modelToken = modelToken;
96       return this;
97     }
98 
isPrecisionLossAllowed()99     public boolean isPrecisionLossAllowed() {
100       return precisionLossAllowed;
101     }
102 
areQuantizedModelsAllowed()103     public boolean areQuantizedModelsAllowed() {
104       return quantizedModelsAllowed;
105     }
106 
getInferencePreference()107     public int getInferencePreference() {
108       return inferencePreference;
109     }
110 
getSerializationDir()111     public String getSerializationDir() {
112       return serializationDir;
113     }
114 
getModelToken()115     public String getModelToken() {
116       return modelToken;
117     }
118 
119     private boolean precisionLossAllowed = true;
120     boolean quantizedModelsAllowed = true;
121     int inferencePreference = INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER;
122     String serializationDir = null;
123     String modelToken = null;
124   }
125 
GpuDelegateFactory()126   public GpuDelegateFactory() {
127     this(new Options());
128   }
129 
GpuDelegateFactory(Options options)130   public GpuDelegateFactory(Options options) {
131     this.options = options;
132   }
133 
134   @Override
create(RuntimeFlavor runtimeFlavor)135   public Delegate create(RuntimeFlavor runtimeFlavor) {
136     String packageName;
137     switch (runtimeFlavor) {
138       case APPLICATION:
139         packageName = "org.tensorflow.lite.gpu";
140         break;
141       case SYSTEM:
142         packageName = "com.google.android.gms.tflite.gpu";
143         break;
144       default:
145         throw new IllegalArgumentException("Unsupported runtime flavor " + runtimeFlavor);
146     }
147     try {
148       Class<?> delegateClass = Class.forName(packageName + "." + GPU_DELEGATE_CLASS_NAME);
149       Constructor<?> constructor = delegateClass.getDeclaredConstructor(Options.class);
150       return (Delegate) constructor.newInstance(options);
151     } catch (ClassNotFoundException
152         | IllegalAccessException
153         | InstantiationException
154         | NoSuchMethodException
155         | InvocationTargetException e) {
156       throw new IllegalStateException("Error creating GPU delegate", e);
157     }
158   }
159 }
160