/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package android.apppredictionservice.cts;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

import android.app.prediction.AppPredictionContext;
import android.app.prediction.AppPredictionSessionId;
import android.app.prediction.AppPredictor;
import android.app.prediction.AppTarget;
import android.app.prediction.AppTargetEvent;
import android.app.prediction.AppTargetId;
import android.os.Binder;
import android.os.Bundle;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.Supplier;

/**
 * Reports calls from the CTS prediction service back to the tests.
 *
 * <p>Each {@code onXxx} method records what it was called with and counts down a dedicated
 * latch; the matching {@code awaitOnXxx} method blocks for up to {@link #AWAIT_TIMEOUT_MS}
 * milliseconds until that latch is released and then re-arms it for the next wait.
 */
public class ServiceReporter extends Binder {

    /** Maximum time, in milliseconds, that any await in this class blocks. */
    private static final long AWAIT_TIMEOUT_MS = 500;

    /** Sessions that have been created and not yet destroyed, keyed by session id. */
    public HashMap<AppPredictionSessionId, AppPredictionContext> mSessions = new HashMap<>();

    /** Every event passed to {@link #onAppTargetEvent}, in arrival order. */
    public ArrayList<AppTargetEvent> mEvents = new ArrayList<>();
    /** Launch location passed to the most recent {@link #onLocationShown} call. */
    public String mLocationsShown;
    /** Target ids accumulated over all {@link #onLocationShown} calls. */
    public ArrayList<AppTargetId> mLocationsShownTargets = new ArrayList<>();
    /** Number of {@link #onRequestPredictionUpdate} calls seen so far. */
    public int mNumRequestedUpdates = 0;
    /** Set by {@link #onStartPredictionUpdates}, cleared by {@link #onStopPredictionUpdates}. */
    public boolean mPredictionUpdatesStarted = false;

    // The latches are replaced by the waiting thread (in the awaitOnXxx finally blocks) and
    // counted down by whichever thread delivers the callback, hence volatile.
    private volatile CountDownLatch mCreateSessionLatch = new CountDownLatch(1);
    private volatile CountDownLatch mEventLatch = new CountDownLatch(1);
    private volatile CountDownLatch mLocationShownLatch = new CountDownLatch(1);
    private volatile CountDownLatch mSortLatch = new CountDownLatch(1);
    private volatile CountDownLatch mStartPredictionUpdatesLatch = new CountDownLatch(1);
    private volatile CountDownLatch mStopPredictionUpdatesLatch = new CountDownLatch(1);
    private volatile CountDownLatch mPredictionUpdateLatch = new CountDownLatch(1);
    private volatile CountDownLatch mDestroyLatch = new CountDownLatch(1);
    private volatile CountDownLatch mRequestServiceFeaturesLatch = new CountDownLatch(1);

    private PredictionsProvider mPredictionsProvider;
    private SortedPredictionsProvider mSortedPredictionsProvider;
    private Supplier<Bundle> mServiceFeaturesProvider;

    /** Stores the provider returned by {@link #getPredictionsProvider()}; may be null. */
    void setPredictionsProvider(PredictionsProvider cb) {
        mPredictionsProvider = cb;
    }

    /** Returns the provider last passed to {@link #setPredictionsProvider}, or null. */
    PredictionsProvider getPredictionsProvider() {
        return mPredictionsProvider;
    }

    /** Stores the provider returned by {@link #getSortedPredictionsProvider()}; may be null. */
    void setSortedPredictionsProvider(SortedPredictionsProvider cb) {
        mSortedPredictionsProvider = cb;
    }

    /** Returns the provider last passed to {@link #setSortedPredictionsProvider}, or null. */
    SortedPredictionsProvider getSortedPredictionsProvider() {
        return mSortedPredictionsProvider;
    }

    /** Stores the supplier returned by {@link #getServiceFeaturesProvider()}; may be null. */
    void setServiceFeaturesProvider(Supplier<Bundle> cb) {
        mServiceFeaturesProvider = cb;
    }

    /** Returns the supplier last passed to {@link #setServiceFeaturesProvider}, or null. */
    Supplier<Bundle> getServiceFeaturesProvider() {
        return mServiceFeaturesProvider;
    }

    /** Fails the test unless {@code sessionId} belongs to a currently tracked session. */
    void assertActiveSession(AppPredictionSessionId sessionId) {
        assertTrue(mSessions.containsKey(sessionId));
    }

    /**
     * Returns the context recorded for {@code sessionId}, failing the test if the session is
     * not currently tracked.
     */
    AppPredictionContext getPredictionContext(AppPredictionSessionId sessionId) {
        assertTrue(mSessions.containsKey(sessionId));
        return mSessions.get(sessionId);
    }

    /**
     * Records a newly created session and releases {@link #awaitOnCreatePredictionSession()}.
     * Asserts that both arguments are non-null and that the session id is not already tracked.
     */
    void onCreatePredictionSession(AppPredictionContext context,
            AppPredictionSessionId sessionId) {
        assertNotNull(context);
        assertNotNull(sessionId);
        assertFalse(mSessions.containsKey(sessionId));
        mSessions.put(sessionId, context);
        mCreateSessionLatch.countDown();
    }

    /**
     * Waits for {@link #onCreatePredictionSession} and re-arms the latch afterwards.
     *
     * @return true if the callback arrived within the timeout, false on timeout or interrupt
     */
    boolean awaitOnCreatePredictionSession() {
        try {
            return await(mCreateSessionLatch);
        } finally {
            mCreateSessionLatch = new CountDownLatch(1);
        }
    }

    /**
     * Appends {@code event} to {@link #mEvents} and releases {@link #awaitOnAppTargetEvent()}.
     * Asserts that {@code sessionId} is tracked.
     */
    void onAppTargetEvent(AppPredictionSessionId sessionId, AppTargetEvent event) {
        assertTrue(mSessions.containsKey(sessionId));
        mEvents.add(event);
        mEventLatch.countDown();
    }

    /**
     * Waits for {@link #onAppTargetEvent} and re-arms the latch afterwards.
     *
     * @return true if the callback arrived within the timeout, false on timeout or interrupt
     */
    boolean awaitOnAppTargetEvent() {
        try {
            return await(mEventLatch);
        } finally {
            mEventLatch = new CountDownLatch(1);
        }
    }

    /**
     * Records the launch location (overwriting the previous one), appends the target ids to
     * {@link #mLocationsShownTargets} and releases {@link #awaitOnLocationShown()}.
     * Asserts that {@code sessionId} is tracked.
     */
    void onLocationShown(AppPredictionSessionId sessionId, String launchLocation,
            List<AppTargetId> targetIds) {
        assertTrue(mSessions.containsKey(sessionId));
        mLocationsShown = launchLocation;
        mLocationsShownTargets.addAll(targetIds);
        mLocationShownLatch.countDown();
    }

    /**
     * Waits for {@link #onLocationShown} and re-arms the latch afterwards.
     *
     * @return true if the callback arrived within the timeout, false on timeout or interrupt
     */
    boolean awaitOnLocationShown() {
        try {
            return await(mLocationShownLatch);
        } finally {
            mLocationShownLatch = new CountDownLatch(1);
        }
    }

    /**
     * Validates a sort request and releases {@link #awaitOnSortAppTargets()}. The callback is
     * only checked for non-null here; this method does not invoke it.
     */
    void onSortAppTargets(AppPredictionSessionId sessionId, List<AppTarget> targets,
            Consumer<List<AppTarget>> callback) {
        assertTrue(mSessions.containsKey(sessionId));
        assertNotNull(targets);
        assertNotNull(callback);
        mSortLatch.countDown();
    }

    /**
     * Waits for {@link #onSortAppTargets} and re-arms the latch afterwards.
     *
     * @return true if the callback arrived within the timeout, false on timeout or interrupt
     */
    boolean awaitOnSortAppTargets() {
        try {
            return await(mSortLatch);
        } finally {
            mSortLatch = new CountDownLatch(1);
        }
    }

    /** Marks prediction updates as started and releases {@link #awaitOnStartPredictionUpdates()}. */
    void onStartPredictionUpdates() {
        mPredictionUpdatesStarted = true;
        // Without this the waiter could only ever time out.
        mStartPredictionUpdatesLatch.countDown();
    }

    /**
     * Waits for {@link #onStartPredictionUpdates} and re-arms the latch afterwards.
     *
     * @return true if the callback arrived within the timeout, false on timeout or interrupt
     */
    boolean awaitOnStartPredictionUpdates() {
        try {
            return await(mStartPredictionUpdatesLatch);
        } finally {
            mStartPredictionUpdatesLatch = new CountDownLatch(1);
        }
    }

    /** Marks prediction updates as stopped and releases {@link #awaitOnStopPredictionUpdates()}. */
    void onStopPredictionUpdates() {
        mPredictionUpdatesStarted = false;
        // Without this the waiter could only ever time out.
        mStopPredictionUpdatesLatch.countDown();
    }

    /**
     * Waits for {@link #onStopPredictionUpdates} and re-arms the latch afterwards.
     *
     * @return true if the callback arrived within the timeout, false on timeout or interrupt
     */
    boolean awaitOnStopPredictionUpdates() {
        try {
            return await(mStopPredictionUpdatesLatch);
        } finally {
            mStopPredictionUpdatesLatch = new CountDownLatch(1);
        }
    }

    /**
     * Increments {@link #mNumRequestedUpdates} and releases
     * {@link #awaitOnRequestPredictionUpdate()}. Asserts that {@code sessionId} is tracked.
     */
    void onRequestPredictionUpdate(AppPredictionSessionId sessionId) {
        assertTrue(mSessions.containsKey(sessionId));
        mNumRequestedUpdates++;
        mPredictionUpdateLatch.countDown();
    }

    /**
     * Waits for {@link #onRequestPredictionUpdate} and re-arms the latch afterwards.
     *
     * @return true if the callback arrived within the timeout, false on timeout or interrupt
     */
    boolean awaitOnRequestPredictionUpdate() {
        try {
            return await(mPredictionUpdateLatch);
        } finally {
            mPredictionUpdateLatch = new CountDownLatch(1);
        }
    }

    /**
     * Stops tracking {@code sessionId} and releases {@link #awaitOnDestroyPredictionSession()}.
     * Asserts that the session was tracked.
     */
    void onDestroyPredictionSession(AppPredictionSessionId sessionId) {
        assertTrue(mSessions.containsKey(sessionId));
        mSessions.remove(sessionId);
        mDestroyLatch.countDown();
    }

    /**
     * Waits for {@link #onDestroyPredictionSession} and re-arms the latch afterwards.
     *
     * @return true if the callback arrived within the timeout, false on timeout or interrupt
     */
    boolean awaitOnDestroyPredictionSession() {
        try {
            return await(mDestroyLatch);
        } finally {
            mDestroyLatch = new CountDownLatch(1);
        }
    }

    /**
     * Validates a service-features request and releases {@link #awaitOnRequestServiceFeatures()}.
     * The callback is only checked for non-null here; this method does not invoke it.
     */
    void onRequestServiceFeatures(AppPredictionSessionId sessionId,
            Consumer<Bundle> callback) {
        assertTrue(mSessions.containsKey(sessionId));
        assertNotNull(callback);
        mRequestServiceFeaturesLatch.countDown();
    }

    /**
     * Waits for {@link #onRequestServiceFeatures} and re-arms the latch afterwards.
     *
     * @return true if the callback arrived within the timeout, false on timeout or interrupt
     */
    boolean awaitOnRequestServiceFeatures() {
        try {
            return await(mRequestServiceFeaturesLatch);
        } finally {
            mRequestServiceFeaturesLatch = new CountDownLatch(1);
        }
    }

    /** Plain value holder pairing an app target with a launch location and an event type. */
    public class Event {
        final AppTarget target;
        final int launchLocation;
        final int eventType;

        public Event(AppTarget target, int launchLocation, int eventType) {
            this.target = target;
            this.launchLocation = launchLocation;
            this.eventType = eventType;
        }
    }

    /**
     * Blocks until {@code latch} reaches zero or {@link #AWAIT_TIMEOUT_MS} elapses.
     *
     * @return true only if the latch reached zero in time; false if the wait timed out or the
     *         thread was interrupted (the interrupt flag is restored in that case)
     */
    private boolean await(CountDownLatch latch) {
        try {
            // Propagate the real outcome so a missing callback is visible to the caller.
            return latch.await(AWAIT_TIMEOUT_MS, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }

    /**
     * Serves a fixed list of targets as a {@link PredictionsProvider} and, as a callback,
     * checks that the targets delivered back match that list.
     */
    public static class RequestVerifier implements AppPredictor.Callback, PredictionsProvider,
            Consumer<List<AppTarget>> {

        private ServiceReporter mReporter;
        private volatile CountDownLatch mReceivedLatch;
        private List<AppTarget> mTargets;

        public RequestVerifier(ServiceReporter reporter) {
            mReporter = reporter;
            mReceivedLatch = new CountDownLatch(1);
        }

        /** Returns the currently stored targets (possibly null); {@code sessionId} is unused. */
        @Override
        public List<AppTarget> getTargets(AppPredictionSessionId sessionId) {
            return mTargets;
        }

        /**
         * Checks {@code targets} against the stored expectation, or stores them when no
         * expectation exists yet, then releases any thread blocked in {@link #awaitTargets}.
         */
        @Override
        public void onTargetsAvailable(List<AppTarget> targets) {
            if (mTargets != null) {
                // Verify that the targets match
                assertEquals(targets, mTargets);
            } else {
                // For the case where we didn't setup the request, save the targets so we can verify
                // them in awaitTargets()
                mTargets = targets;
            }
            mReceivedLatch.countDown();
        }

        /** {@link Consumer} entry point; delegates to {@link #onTargetsAvailable}. */
        @Override
        public void accept(List<AppTarget> appTargets) {
            onTargetsAvailable(appTargets);
        }

        /**
         * Installs this verifier as the reporter's predictions provider, runs
         * {@code requestUpdateCb} and waits for the targets to be delivered. The provider is
         * cleared again before returning.
         *
         * @param targets the targets to serve and to expect back
         * @param requestUpdateCb Callback called when the request is setup
         * @return true if targets were delivered within the timeout
         */
        boolean requestAndWaitForTargets(List<AppTarget> targets, Runnable requestUpdateCb) {
            mTargets = targets;
            mReceivedLatch = new CountDownLatch(1);
            mReporter.setPredictionsProvider(this);
            requestUpdateCb.run();
            try {
                return awaitTargets(targets);
            } finally {
                mReporter.setPredictionsProvider(null);
            }
        }

        /**
         * Waits for targets to be delivered and asserts that the stored targets equal
         * {@code targets}. The assertion also runs when the wait timed out.
         *
         * @return true if targets were delivered within the timeout; false on timeout or if
         *         the thread was interrupted (the interrupt flag is restored)
         */
        boolean awaitTargets(List<AppTarget> targets) {
            try {
                boolean result = mReceivedLatch.await(AWAIT_TIMEOUT_MS, TimeUnit.MILLISECONDS);
                assertEquals(targets, mTargets);
                return result;
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return false;
            }
        }
    }

    /** Captures the {@link Bundle} delivered to it and checks it against an expected value. */
    public static class RequestServiceFeaturesVerifier implements Consumer<Bundle> {

        // Armed up front so accept()/awaitTargets() are safe before the first request.
        private volatile CountDownLatch mReceivedLatch = new CountDownLatch(1);
        private Bundle mBundle;

        /** Stores {@code bundle} and releases any thread blocked in {@link #awaitTargets}. */
        @Override
        public void accept(Bundle bundle) {
            mBundle = bundle;
            mReceivedLatch.countDown();
        }

        /**
         * Re-arms the latch, runs {@code requestCb} and waits for a bundle to be delivered.
         *
         * @param bundle the bundle expected to be delivered
         * @param requestCb Callback called when the request is setup
         * @return true if a bundle was delivered within the timeout
         */
        boolean requestAndWaitForTargets(Bundle bundle, Runnable requestCb) {
            mReceivedLatch = new CountDownLatch(1);
            requestCb.run();
            return awaitTargets(bundle);
        }

        /**
         * Waits for a bundle to be delivered and asserts that it equals {@code bundle}. The
         * assertion also runs when the wait timed out.
         *
         * @return true if a bundle was delivered within the timeout; false on timeout or if
         *         the thread was interrupted (the interrupt flag is restored)
         */
        boolean awaitTargets(Bundle bundle) {
            try {
                boolean result = mReceivedLatch.await(AWAIT_TIMEOUT_MS, TimeUnit.MILLISECONDS);
                // NOTE(review): relies on Bundle.equals(); confirm that is the intended
                // notion of equality for the bundles used by callers.
                assertEquals(bundle, mBundle);
                return result;
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return false;
            }
        }
    }
}