1 /* 2 * Copyright (C) 2024 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.providers.media.mediacognitionservices; 18 19 import static android.provider.MediaCognitionService.ProcessingTypes; 20 21 import static org.junit.Assert.assertEquals; 22 import static org.junit.Assert.assertNotNull; 23 import static org.junit.Assert.assertTrue; 24 25 import android.content.ComponentName; 26 import android.content.Context; 27 import android.content.Intent; 28 import android.content.ServiceConnection; 29 import android.database.CursorWindow; 30 import android.net.Uri; 31 import android.os.IBinder; 32 import android.os.RemoteException; 33 import android.platform.test.annotations.RequiresFlagsEnabled; 34 import android.platform.test.flag.junit.CheckFlagsRule; 35 import android.platform.test.flag.junit.DeviceFlagsValueProvider; 36 import android.provider.MediaCognitionProcessingRequest; 37 import android.provider.MediaCognitionProcessingVersions; 38 import android.provider.MediaCognitionService; 39 import android.provider.mediacognitionutils.ICognitionGetVersionsCallbackInternal; 40 import android.provider.mediacognitionutils.ICognitionProcessMediaCallbackInternal; 41 import android.provider.mediacognitionutils.IMediaCognitionService; 42 43 import androidx.test.platform.app.InstrumentationRegistry; 44 import androidx.test.runner.AndroidJUnit4; 45 46 import com.android.providers.media.flags.Flags; 47 48 import org.junit.After; 49 import org.junit.Before; 50 import org.junit.Rule; 51 import org.junit.Test; 52 import org.junit.runner.RunWith; 53 54 import java.util.ArrayList; 55 import java.util.List; 56 import java.util.concurrent.CountDownLatch; 57 import java.util.concurrent.TimeUnit; 58 59 60 @RequiresFlagsEnabled(Flags.FLAG_MEDIA_COGNITION_SERVICE) 61 @RunWith(AndroidJUnit4.class) 62 public class MediaCognitionServiceTest { 63 64 @Rule 65 public final CheckFlagsRule mCheckFlagsRule = DeviceFlagsValueProvider.createCheckFlagsRule(); 66 67 private CountDownLatch mServiceLatch = new CountDownLatch(1); 68 private IMediaCognitionService mPrimaryService; 69 private Context mContext; 70 71 @Before setUp()72 public void setUp() throws Exception { 73 mContext = InstrumentationRegistry.getInstrumentation().getTargetContext(); 74 Intent intent = new Intent(MediaCognitionService.SERVICE_INTERFACE); 75 intent.setClassName("com.android.providers.media.tests", 76 "com.android.providers.media.mediacognitionservices.TestMediaCognitionService"); 77 mContext.bindService(intent, mServiceConnection, Context.BIND_AUTO_CREATE); 78 mServiceLatch.await(3, TimeUnit.SECONDS); 79 } 80 81 private ServiceConnection mServiceConnection = new ServiceConnection() { 82 @Override 83 public void onServiceConnected(ComponentName componentName, IBinder iBinder) { 84 mPrimaryService = IMediaCognitionService.Stub.asInterface(iBinder); // Update interface 85 mServiceLatch.countDown(); 86 } 87 @Override 88 public void onServiceDisconnected(ComponentName componentName) { 89 mPrimaryService = null; 90 } 91 }; 92 93 @After tearDown()94 public void tearDown() throws Exception { 95 mContext.unbindService(mServiceConnection); 96 } 97 98 @Test testProcessMedia()99 public void testProcessMedia() throws Exception { 100 assertNotNull(mPrimaryService); 101 List<MediaCognitionProcessingRequest> requests = 102 new ArrayList<MediaCognitionProcessingRequest>(); 103 requests.add(new MediaCognitionProcessingRequest 104 .Builder(Uri.parse("content://media/test_image/1")) 105 .setProcessingCombination( 106 ProcessingTypes.IMAGE_OCR_LATIN | ProcessingTypes.IMAGE_LABEL) 107 .build()); 108 109 final TestProcessMediaCallback callback = new TestProcessMediaCallback(); 110 mPrimaryService.processMedia(requests, callback); 111 callback.await(3, TimeUnit.SECONDS); 112 final CursorWindow[] windows = callback.mWindows; 113 assertTrue(windows.length > 0); 114 // first column id should be 1 115 assertTrue(windows[0].getString(0, 0).equals("1")); 116 assertTrue(windows[0].getString(0, 1).equals("image_ocr_latin_1")); 117 assertTrue(windows[0].getString(0, 2).equals("image_label_1")); 118 windows[0].close(); 119 } 120 121 @Test testProcessMediaLargeData()122 public void testProcessMediaLargeData() throws Exception { 123 assertNotNull(mPrimaryService); 124 List<MediaCognitionProcessingRequest> requests = 125 new ArrayList<MediaCognitionProcessingRequest>(); 126 int totalCount = 100; 127 for (int count = 1; count <= totalCount; count++) { 128 requests.add( 129 new MediaCognitionProcessingRequest 130 .Builder(Uri.parse("content://media/test_image_large_data/" + count)) 131 .setProcessingCombination( 132 ProcessingTypes.IMAGE_OCR_LATIN | ProcessingTypes.IMAGE_LABEL) 133 .build()); 134 } 135 136 final TestProcessMediaCallback callback = new TestProcessMediaCallback(); 137 mPrimaryService.processMedia(requests, callback); 138 callback.await(3, TimeUnit.SECONDS); 139 final CursorWindow[] windows = callback.mWindows; 140 int count = 0; 141 assertTrue(windows.length > 0); 142 for (int index = 0; index < windows.length; index++) { 143 for (int row = 0; row < windows[index].getNumRows(); row++) { 144 count++; 145 // matching id 146 assertTrue(windows[index].getString(count - 1, 0).equals(String.valueOf(count))); 147 assertNotNull(windows[index].getString(count - 1, 1)); 148 assertNotNull(windows[index].getString(count - 1, 2)); 149 } 150 windows[index].close(); 151 } 152 // making sure got all results back 153 assertEquals(count, totalCount); 154 } 155 156 @Test testGetVersions()157 public void testGetVersions() throws Exception { 158 assertNotNull(mPrimaryService); 159 final TestGetVersionsCallback callback = new TestGetVersionsCallback(); 160 mPrimaryService.getProcessingVersions(callback); 161 callback.await(3, TimeUnit.SECONDS); 162 assertNotNull(callback.mVersions); 163 assertEquals(callback.mVersions.getProcessingVersion(ProcessingTypes.IMAGE_LABEL), 1); 164 assertEquals(callback.mVersions.getProcessingVersion(ProcessingTypes.IMAGE_OCR_LATIN), 1); 165 } 166 167 private static class TestProcessMediaCallback 168 extends ICognitionProcessMediaCallbackInternal.Stub { 169 170 private CountDownLatch mLatch = new CountDownLatch(1); 171 private CursorWindow[] mWindows; 172 private String mFailureMessage; 173 174 @Override onProcessMediaSuccess(CursorWindow[] cursorWindows)175 public void onProcessMediaSuccess(CursorWindow[] cursorWindows) throws RemoteException { 176 mWindows = cursorWindows; 177 mLatch.countDown(); 178 } 179 180 @Override onProcessMediaFailure(String s)181 public void onProcessMediaFailure(String s) throws RemoteException { 182 mFailureMessage = s; 183 mLatch.countDown(); 184 } 185 await(int time, TimeUnit unit)186 public void await(int time, TimeUnit unit) throws InterruptedException { 187 mLatch.await(time, unit); 188 } 189 190 } 191 192 private static class TestGetVersionsCallback 193 extends ICognitionGetVersionsCallbackInternal.Stub { 194 195 private CountDownLatch mLatch = new CountDownLatch(1); 196 197 private String mFailureMessage; 198 199 private MediaCognitionProcessingVersions mVersions; 200 201 @Override onGetProcessingVersionsSuccess( MediaCognitionProcessingVersions mediaCognitionProcessingVersions)202 public void onGetProcessingVersionsSuccess( 203 MediaCognitionProcessingVersions mediaCognitionProcessingVersions) 204 throws RemoteException { 205 mVersions = mediaCognitionProcessingVersions; 206 mLatch.countDown(); 207 } 208 209 @Override onGetProcessingVersionsFailure(String s)210 public void onGetProcessingVersionsFailure(String s) throws RemoteException { 211 mFailureMessage = s; 212 mLatch.countDown(); 213 } 214 await(int time, TimeUnit unit)215 public void await(int time, TimeUnit unit) throws InterruptedException { 216 mLatch.await(time, unit); 217 } 218 219 } 220 221 } 222