/* * Copyright 2024 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package android.adservices.ondevicepersonalization; import static com.google.common.truth.Truth.assertThat; import static junit.framework.Assert.assertEquals; import static org.junit.Assert.assertThrows; import android.adservices.ondevicepersonalization.aidl.IDataAccessService; import android.adservices.ondevicepersonalization.aidl.IDataAccessServiceCallback; import android.os.Bundle; import org.junit.Before; import org.junit.Test; import java.util.HashMap; public class InferenceInputTest { private static final String MODEL_KEY = "model_key"; private RemoteDataImpl mRemoteData; @Before public void setup() { mRemoteData = new RemoteDataImpl( IDataAccessService.Stub.asInterface(new TestDataAccessService())); } @Test public void buildParams_reusable() { InferenceInput.Params.Builder builder = new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY); builder.build(); InferenceInput.Params params = builder.setModelKey("other_kay").build(); assertThat(params.getModelKey()).isEqualTo("other_kay"); } @Test public void buildInferenceInput_reusable() { HashMap outputData = new HashMap<>(); outputData.put(0, new float[1]); Object[] input = new Object[1]; input[0] = new float[] {1.2f}; InferenceInput.Params params = new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build(); InferenceInput.Builder builder = new InferenceInput.Builder( params, input, new InferenceOutput.Builder().setDataOutputs(outputData).build()) .setBatchSize(1); builder.build(); InferenceInput inferenceInput = builder.setBatchSize(10).build(); assertThat(inferenceInput.getBatchSize()).isEqualTo(10); } @Test public void buildInput_success() { HashMap outputData = new HashMap<>(); outputData.put(0, new float[1]); Object[] input = new Object[1]; input[0] = new float[] {1.2f}; InferenceInput.Params params = new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build(); InferenceInput inferenceInput = new InferenceInput.Builder( params, input, new InferenceOutput.Builder().setDataOutputs(outputData).build()) .setBatchSize(1) .build(); float[] inputData = (float[]) inferenceInput.getInputData()[0]; assertEquals(inputData[0], 1.2f); assertThat(inferenceInput.getBatchSize()).isEqualTo(1); assertThat(inferenceInput.getExpectedOutputStructure().getDataOutputs()).hasSize(1); assertThat(inferenceInput.getParams()).isEqualTo(params); } @Test public void buildInput_batchNotSet_success() { HashMap outputData = new HashMap<>(); outputData.put(0, new float[1]); Object[] input = new Object[1]; input[0] = new float[] {1.2f}; InferenceInput.Params params = new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build(); InferenceInput inferenceInput = new InferenceInput.Builder( params, input, new InferenceOutput.Builder().setDataOutputs(outputData).build()) .build(); assertThat(inferenceInput.getBatchSize()).isEqualTo(1); } @Test public void buildParams_success() { InferenceInput.Params params = new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build(); assertThat(params.getRecommendedNumThreads()).isEqualTo(1); assertThat(params.getDelegateType()).isEqualTo(InferenceInput.Params.DELEGATE_CPU); assertThat(params.getModelType()) .isEqualTo(InferenceInput.Params.MODEL_TYPE_TENSORFLOW_LITE); assertThat(params.getKeyValueStore()).isEqualTo(mRemoteData); assertThat(params.getModelKey()).isEqualTo(MODEL_KEY); } @Test public void buildParams_negativeThread_throws() { assertThrows( IllegalStateException.class, () -> new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY) .setRecommendedNumThreads(-2) .build()); } @Test public void buildParams_nullModelKey_throws() { assertThrows( NullPointerException.class, () -> new InferenceInput.Params.Builder(mRemoteData, null).build()); } static class TestDataAccessService extends IDataAccessService.Stub { @Override public void onRequest(int operation, Bundle params, IDataAccessServiceCallback callback) {} @Override public void logApiCallStats(int apiName, long latencyMillis, int responseCode) {} } }