1 /* 2 * Copyright 2024 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package android.adservices.ondevicepersonalization; 18 19 import static com.android.ondevicepersonalization.internal.util.ByteArrayUtil.deserializeObject; 20 21 import static com.google.common.truth.Truth.assertThat; 22 23 import static junit.framework.Assert.assertEquals; 24 25 import static org.junit.Assert.assertThrows; 26 27 import android.adservices.ondevicepersonalization.aidl.IDataAccessService; 28 import android.adservices.ondevicepersonalization.aidl.IDataAccessServiceCallback; 29 import android.os.Bundle; 30 31 import com.android.ondevicepersonalization.internal.util.ByteArrayUtil; 32 33 import org.junit.Before; 34 import org.junit.Test; 35 36 import java.io.NotSerializableException; 37 import java.util.HashMap; 38 39 public class InferenceInputTest { 40 private static final String MODEL_KEY = "model_key"; 41 private RemoteDataImpl mRemoteData; 42 43 @Before setup()44 public void setup() { 45 mRemoteData = 46 new RemoteDataImpl( 47 IDataAccessService.Stub.asInterface(new TestDataAccessService())); 48 } 49 50 @Test buildParams_reusable()51 public void buildParams_reusable() { 52 InferenceInput.Params.Builder builder = 53 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY); 54 builder.build(); 55 56 InferenceInput.Params params = builder.setModelKey("other_kay").build(); 57 58 assertThat(params.getModelKey()).isEqualTo("other_kay"); 59 } 60 61 @Test buildInferenceInput_reusable()62 public void buildInferenceInput_reusable() { 63 HashMap<Integer, Object> outputData = new HashMap<>(); 64 outputData.put(0, new float[1]); 65 Object[] input = new Object[1]; 66 input[0] = new float[] {1.2f}; 67 InferenceInput.Params params = 68 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build(); 69 70 InferenceInput.Builder builder = 71 new InferenceInput.Builder( 72 params, 73 input, 74 new InferenceOutput.Builder().setDataOutputs(outputData).build()) 75 .setBatchSize(1); 76 builder.build(); 77 InferenceInput inferenceInput = builder.setBatchSize(10).build(); 78 assertThat(inferenceInput.getBatchSize()).isEqualTo(10); 79 } 80 81 @Test buildInput_success()82 public void buildInput_success() { 83 HashMap<Integer, Object> outputData = new HashMap<>(); 84 outputData.put(0, new float[1]); 85 Object[] input = new Object[1]; 86 input[0] = new float[] {1.2f}; 87 InferenceInput.Params params = 88 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build(); 89 90 InferenceInput inferenceInput = 91 new InferenceInput.Builder( 92 params, 93 input, 94 new InferenceOutput.Builder().setDataOutputs(outputData).build()) 95 .setBatchSize(1) 96 .build(); 97 98 float[] inputData = (float[]) inferenceInput.getInputData()[0]; 99 assertEquals(inputData[0], 1.2f); 100 assertThat(inferenceInput.getBatchSize()).isEqualTo(1); 101 assertThat(inferenceInput.getExpectedOutputStructure().getDataOutputs()).hasSize(1); 102 assertThat(inferenceInput.getParams()).isEqualTo(params); 103 } 104 105 @Test buildInput_batchNotSet_success()106 public void buildInput_batchNotSet_success() { 107 HashMap<Integer, Object> outputData = new HashMap<>(); 108 outputData.put(0, new float[1]); 109 Object[] input = new Object[1]; 110 input[0] = new float[] {1.2f}; 111 InferenceInput.Params params = 112 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build(); 113 114 InferenceInput inferenceInput = 115 new InferenceInput.Builder( 116 params, 117 input, 118 new InferenceOutput.Builder().setDataOutputs(outputData).build()) 119 .build(); 120 121 assertThat(inferenceInput.getBatchSize()).isEqualTo(1); 122 } 123 124 @Test buildParams_success()125 public void buildParams_success() { 126 InferenceInput.Params params = 127 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build(); 128 129 assertThat(params.getRecommendedNumThreads()).isEqualTo(1); 130 assertThat(params.getDelegateType()).isEqualTo(InferenceInput.Params.DELEGATE_CPU); 131 assertThat(params.getModelType()) 132 .isEqualTo(InferenceInput.Params.MODEL_TYPE_TENSORFLOW_LITE); 133 assertThat(params.getKeyValueStore()).isEqualTo(mRemoteData); 134 assertThat(params.getModelKey()).isEqualTo(MODEL_KEY); 135 } 136 137 @Test buildParams_negativeThread_throws()138 public void buildParams_negativeThread_throws() { 139 assertThrows( 140 IllegalStateException.class, 141 () -> 142 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY) 143 .setRecommendedNumThreads(-2) 144 .build()); 145 } 146 147 @Test buildParams_nullModelKey_throws()148 public void buildParams_nullModelKey_throws() { 149 assertThrows( 150 NullPointerException.class, 151 () -> new InferenceInput.Params.Builder(mRemoteData, null).build()); 152 } 153 154 @Test buildLiteRT_success()155 public void buildLiteRT_success() { 156 HashMap<Integer, Object> outputData = new HashMap<>(); 157 outputData.put(0, new float[1]); 158 Object[] input = new Object[1]; 159 input[0] = new float[] {1.2f}; 160 InferenceInput.Params params = 161 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build(); 162 163 InferenceInput result = 164 new InferenceInput.Builder(params, ByteArrayUtil.serializeObject(input)) 165 .setExpectedOutputStructure( 166 new InferenceOutput.Builder().setDataOutputs(outputData).build()) 167 .build(); 168 169 Object[] obj = (Object[]) deserializeObject(result.getData()); 170 assertThat(obj).isEqualTo(input); 171 } 172 173 @Test buildExecuTorch_success()174 public void buildExecuTorch_success() { 175 // TODO(b/376902350): update input with EValue. 176 byte[] input = {1, 2, 3}; 177 InferenceInput.Params params = 178 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY) 179 .setModelType(InferenceInput.Params.MODEL_TYPE_EXECUTORCH) 180 .build(); 181 182 InferenceInput result = new InferenceInput.Builder(params, input).build(); 183 184 assertThat(result.getData()).isEqualTo(input); 185 } 186 187 @Test buildExecutorchInput_missingInputData()188 public void buildExecutorchInput_missingInputData() { 189 InferenceInput.Params params = 190 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY) 191 .setModelType(InferenceInput.Params.MODEL_TYPE_EXECUTORCH) 192 .build(); 193 194 assertThrows( 195 IllegalArgumentException.class, 196 () -> new InferenceInput.Builder(params, new byte[] {}).build()); 197 } 198 199 @Test buildLiteRTInput_missingInputData()200 public void buildLiteRTInput_missingInputData() { 201 HashMap<Integer, Object> outputData = new HashMap<>(); 202 outputData.put(0, new float[1]); 203 InferenceInput.Params params = 204 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build(); 205 206 assertThrows( 207 IllegalArgumentException.class, 208 () -> 209 new InferenceInput.Builder(params, new byte[] {}) 210 .setExpectedOutputStructure( 211 new InferenceOutput.Builder() 212 .setDataOutputs(outputData) 213 .build()) 214 .build()); 215 } 216 217 @Test buildLiteRTInput_missingOutputStructure()218 public void buildLiteRTInput_missingOutputStructure() { 219 InferenceInput.Params params = 220 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build(); 221 222 assertThrows( 223 IllegalArgumentException.class, 224 () -> 225 new InferenceInput.Builder(params, new byte[] {}) 226 .setExpectedOutputStructure( 227 new InferenceOutput.Builder() 228 .setDataOutputs(new HashMap<>()) 229 .build()) 230 .build()); 231 } 232 233 @Test nonSerializable()234 public void nonSerializable() { 235 NonSerializableObject obj = new NonSerializableObject(123); 236 InferenceInput.Params params = 237 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build(); 238 239 IllegalArgumentException exp = 240 assertThrows( 241 IllegalArgumentException.class, 242 () -> 243 new InferenceInput.Builder( 244 params, 245 new Object[] {obj}, 246 new InferenceOutput.Builder() 247 .setDataOutputs(new HashMap<>()) 248 .build()) 249 .build()); 250 251 assertThat(exp.getCause()).isInstanceOf(NotSerializableException.class); 252 } 253 254 /** A class used for serializable exception test. */ 255 class NonSerializableObject { 256 private final int mData; 257 NonSerializableObject(int data)258 NonSerializableObject(int data) { 259 this.mData = data; 260 } 261 } 262 263 static class TestDataAccessService extends IDataAccessService.Stub { 264 @Override onRequest(int operation, Bundle params, IDataAccessServiceCallback callback)265 public void onRequest(int operation, Bundle params, IDataAccessServiceCallback callback) {} 266 267 @Override logApiCallStats(int apiName, long latencyMillis, int responseCode)268 public void logApiCallStats(int apiName, long latencyMillis, int responseCode) {} 269 } 270 } 271