• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.adservices.ondevicepersonalization;
18 
19 import static com.android.ondevicepersonalization.internal.util.ByteArrayUtil.deserializeObject;
20 
21 import static com.google.common.truth.Truth.assertThat;
22 
23 import static junit.framework.Assert.assertEquals;
24 
25 import static org.junit.Assert.assertThrows;
26 
27 import android.adservices.ondevicepersonalization.aidl.IDataAccessService;
28 import android.adservices.ondevicepersonalization.aidl.IDataAccessServiceCallback;
29 import android.os.Bundle;
30 
31 import com.android.ondevicepersonalization.internal.util.ByteArrayUtil;
32 
33 import org.junit.Before;
34 import org.junit.Test;
35 
36 import java.io.NotSerializableException;
37 import java.util.HashMap;
38 
39 public class InferenceInputTest {
40     private static final String MODEL_KEY = "model_key";
41     private RemoteDataImpl mRemoteData;
42 
43     @Before
setup()44     public void setup() {
45         mRemoteData =
46                 new RemoteDataImpl(
47                         IDataAccessService.Stub.asInterface(new TestDataAccessService()));
48     }
49 
50     @Test
buildParams_reusable()51     public void buildParams_reusable() {
52         InferenceInput.Params.Builder builder =
53                 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY);
54         builder.build();
55 
56         InferenceInput.Params params = builder.setModelKey("other_kay").build();
57 
58         assertThat(params.getModelKey()).isEqualTo("other_kay");
59     }
60 
61     @Test
buildInferenceInput_reusable()62     public void buildInferenceInput_reusable() {
63         HashMap<Integer, Object> outputData = new HashMap<>();
64         outputData.put(0, new float[1]);
65         Object[] input = new Object[1];
66         input[0] = new float[] {1.2f};
67         InferenceInput.Params params =
68                 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build();
69 
70         InferenceInput.Builder builder =
71                 new InferenceInput.Builder(
72                                 params,
73                                 input,
74                                 new InferenceOutput.Builder().setDataOutputs(outputData).build())
75                         .setBatchSize(1);
76         builder.build();
77         InferenceInput inferenceInput = builder.setBatchSize(10).build();
78         assertThat(inferenceInput.getBatchSize()).isEqualTo(10);
79     }
80 
81     @Test
buildInput_success()82     public void buildInput_success() {
83         HashMap<Integer, Object> outputData = new HashMap<>();
84         outputData.put(0, new float[1]);
85         Object[] input = new Object[1];
86         input[0] = new float[] {1.2f};
87         InferenceInput.Params params =
88                 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build();
89 
90         InferenceInput inferenceInput =
91                 new InferenceInput.Builder(
92                                 params,
93                                 input,
94                                 new InferenceOutput.Builder().setDataOutputs(outputData).build())
95                         .setBatchSize(1)
96                         .build();
97 
98         float[] inputData = (float[]) inferenceInput.getInputData()[0];
99         assertEquals(inputData[0], 1.2f);
100         assertThat(inferenceInput.getBatchSize()).isEqualTo(1);
101         assertThat(inferenceInput.getExpectedOutputStructure().getDataOutputs()).hasSize(1);
102         assertThat(inferenceInput.getParams()).isEqualTo(params);
103     }
104 
105     @Test
buildInput_batchNotSet_success()106     public void buildInput_batchNotSet_success() {
107         HashMap<Integer, Object> outputData = new HashMap<>();
108         outputData.put(0, new float[1]);
109         Object[] input = new Object[1];
110         input[0] = new float[] {1.2f};
111         InferenceInput.Params params =
112                 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build();
113 
114         InferenceInput inferenceInput =
115                 new InferenceInput.Builder(
116                                 params,
117                                 input,
118                                 new InferenceOutput.Builder().setDataOutputs(outputData).build())
119                         .build();
120 
121         assertThat(inferenceInput.getBatchSize()).isEqualTo(1);
122     }
123 
124     @Test
buildParams_success()125     public void buildParams_success() {
126         InferenceInput.Params params =
127                 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build();
128 
129         assertThat(params.getRecommendedNumThreads()).isEqualTo(1);
130         assertThat(params.getDelegateType()).isEqualTo(InferenceInput.Params.DELEGATE_CPU);
131         assertThat(params.getModelType())
132                 .isEqualTo(InferenceInput.Params.MODEL_TYPE_TENSORFLOW_LITE);
133         assertThat(params.getKeyValueStore()).isEqualTo(mRemoteData);
134         assertThat(params.getModelKey()).isEqualTo(MODEL_KEY);
135     }
136 
137     @Test
buildParams_negativeThread_throws()138     public void buildParams_negativeThread_throws() {
139         assertThrows(
140                 IllegalStateException.class,
141                 () ->
142                         new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY)
143                                 .setRecommendedNumThreads(-2)
144                                 .build());
145     }
146 
147     @Test
buildParams_nullModelKey_throws()148     public void buildParams_nullModelKey_throws() {
149         assertThrows(
150                 NullPointerException.class,
151                 () -> new InferenceInput.Params.Builder(mRemoteData, null).build());
152     }
153 
154     @Test
buildLiteRT_success()155     public void buildLiteRT_success() {
156         HashMap<Integer, Object> outputData = new HashMap<>();
157         outputData.put(0, new float[1]);
158         Object[] input = new Object[1];
159         input[0] = new float[] {1.2f};
160         InferenceInput.Params params =
161                 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build();
162 
163         InferenceInput result =
164                 new InferenceInput.Builder(params, ByteArrayUtil.serializeObject(input))
165                         .setExpectedOutputStructure(
166                                 new InferenceOutput.Builder().setDataOutputs(outputData).build())
167                         .build();
168 
169         Object[] obj = (Object[]) deserializeObject(result.getData());
170         assertThat(obj).isEqualTo(input);
171     }
172 
173     @Test
buildExecuTorch_success()174     public void buildExecuTorch_success() {
175         // TODO(b/376902350): update input with EValue.
176         byte[] input = {1, 2, 3};
177         InferenceInput.Params params =
178                 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY)
179                         .setModelType(InferenceInput.Params.MODEL_TYPE_EXECUTORCH)
180                         .build();
181 
182         InferenceInput result = new InferenceInput.Builder(params, input).build();
183 
184         assertThat(result.getData()).isEqualTo(input);
185     }
186 
187     @Test
buildExecutorchInput_missingInputData()188     public void buildExecutorchInput_missingInputData() {
189         InferenceInput.Params params =
190                 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY)
191                         .setModelType(InferenceInput.Params.MODEL_TYPE_EXECUTORCH)
192                         .build();
193 
194         assertThrows(
195                 IllegalArgumentException.class,
196                 () -> new InferenceInput.Builder(params, new byte[] {}).build());
197     }
198 
199     @Test
buildLiteRTInput_missingInputData()200     public void buildLiteRTInput_missingInputData() {
201         HashMap<Integer, Object> outputData = new HashMap<>();
202         outputData.put(0, new float[1]);
203         InferenceInput.Params params =
204                 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build();
205 
206         assertThrows(
207                 IllegalArgumentException.class,
208                 () ->
209                         new InferenceInput.Builder(params, new byte[] {})
210                                 .setExpectedOutputStructure(
211                                         new InferenceOutput.Builder()
212                                                 .setDataOutputs(outputData)
213                                                 .build())
214                                 .build());
215     }
216 
217     @Test
buildLiteRTInput_missingOutputStructure()218     public void buildLiteRTInput_missingOutputStructure() {
219         InferenceInput.Params params =
220                 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build();
221 
222         assertThrows(
223                 IllegalArgumentException.class,
224                 () ->
225                         new InferenceInput.Builder(params, new byte[] {})
226                                 .setExpectedOutputStructure(
227                                         new InferenceOutput.Builder()
228                                                 .setDataOutputs(new HashMap<>())
229                                                 .build())
230                                 .build());
231     }
232 
233     @Test
nonSerializable()234     public void nonSerializable() {
235         NonSerializableObject obj = new NonSerializableObject(123);
236         InferenceInput.Params params =
237                 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build();
238 
239         IllegalArgumentException exp =
240                 assertThrows(
241                         IllegalArgumentException.class,
242                         () ->
243                                 new InferenceInput.Builder(
244                                                 params,
245                                                 new Object[] {obj},
246                                                 new InferenceOutput.Builder()
247                                                         .setDataOutputs(new HashMap<>())
248                                                         .build())
249                                         .build());
250 
251         assertThat(exp.getCause()).isInstanceOf(NotSerializableException.class);
252     }
253 
254     /** A class used for serializable exception test. */
255     class NonSerializableObject {
256         private final int mData;
257 
NonSerializableObject(int data)258         NonSerializableObject(int data) {
259             this.mData = data;
260         }
261     }
262 
263     static class TestDataAccessService extends IDataAccessService.Stub {
264         @Override
onRequest(int operation, Bundle params, IDataAccessServiceCallback callback)265         public void onRequest(int operation, Bundle params, IDataAccessServiceCallback callback) {}
266 
267         @Override
logApiCallStats(int apiName, long latencyMillis, int responseCode)268         public void logApiCallStats(int apiName, long latencyMillis, int responseCode) {}
269     }
270 }
271