• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 package android.apppredictionservice.cts;
17 
18 import static org.junit.Assert.assertEquals;
19 import static org.junit.Assert.assertFalse;
20 import static org.junit.Assert.assertNotNull;
21 import static org.junit.Assert.assertTrue;
22 
23 import android.app.prediction.AppPredictionContext;
24 import android.app.prediction.AppPredictionSessionId;
25 import android.app.prediction.AppPredictor;
26 import android.app.prediction.AppTarget;
27 import android.app.prediction.AppTargetEvent;
28 import android.app.prediction.AppTargetId;
29 import android.os.Binder;
30 import android.os.Bundle;
31 
32 import java.util.ArrayList;
33 import java.util.HashMap;
34 import java.util.List;
35 import java.util.concurrent.CountDownLatch;
36 import java.util.concurrent.TimeUnit;
37 import java.util.function.Consumer;
38 import java.util.function.Supplier;
39 
40 /**
41  * Reports calls from the CTS prediction service back to the tests.
42  */
43 public class ServiceReporter extends Binder {
44 
45     public HashMap<AppPredictionSessionId, AppPredictionContext> mSessions = new HashMap<>();
46 
47     public ArrayList<AppTargetEvent> mEvents = new ArrayList<>();
48     public String mLocationsShown;
49     public ArrayList<AppTargetId> mLocationsShownTargets = new ArrayList<>();
50     public int mNumRequestedUpdates = 0;
51     public boolean mPredictionUpdatesStarted = false;
52 
53     private CountDownLatch mCreateSessionLatch = new CountDownLatch(1);
54     private CountDownLatch mEventLatch = new CountDownLatch(1);
55     private CountDownLatch mLocationShownLatch = new CountDownLatch(1);
56     private CountDownLatch mSortLatch = new CountDownLatch(1);
57     private CountDownLatch mStartPredictionUpdatesLatch = new CountDownLatch(1);
58     private CountDownLatch mStopPredictionUpdatesLatch = new CountDownLatch(1);
59     private CountDownLatch mPredictionUpdateLatch = new CountDownLatch(1);
60     private CountDownLatch mDestroyLatch = new CountDownLatch(1);
61     private CountDownLatch mRequestServiceFeaturesLatch = new CountDownLatch(1);
62 
63     private PredictionsProvider mPredictionsProvider;
64     private SortedPredictionsProvider mSortedPredictionsProvider;
65     private Supplier<Bundle> mServiceFeaturesProvider;
66 
setPredictionsProvider(PredictionsProvider cb)67     void setPredictionsProvider(PredictionsProvider cb) {
68         mPredictionsProvider = cb;
69     }
70 
getPredictionsProvider()71     PredictionsProvider getPredictionsProvider() {
72         return mPredictionsProvider;
73     }
74 
setSortedPredictionsProvider(SortedPredictionsProvider cb)75     void setSortedPredictionsProvider(SortedPredictionsProvider cb) {
76         mSortedPredictionsProvider = cb;
77     }
78 
getSortedPredictionsProvider()79     SortedPredictionsProvider getSortedPredictionsProvider() {
80         return mSortedPredictionsProvider;
81     }
82 
setServiceFeaturesProvider(Supplier<Bundle> cb)83     void setServiceFeaturesProvider(Supplier<Bundle> cb) {
84         mServiceFeaturesProvider = cb;
85     }
86 
getServiceFeaturesProvider()87     Supplier<Bundle> getServiceFeaturesProvider() {
88         return mServiceFeaturesProvider;
89     }
90 
assertActiveSession(AppPredictionSessionId sessionId)91     void assertActiveSession(AppPredictionSessionId sessionId) {
92         assertTrue(mSessions.containsKey(sessionId));
93     }
94 
getPredictionContext(AppPredictionSessionId sessionId)95     AppPredictionContext getPredictionContext(AppPredictionSessionId sessionId) {
96         assertTrue(mSessions.containsKey(sessionId));
97         return mSessions.get(sessionId);
98     }
99 
onCreatePredictionSession(AppPredictionContext context, AppPredictionSessionId sessionId)100     void onCreatePredictionSession(AppPredictionContext context,
101             AppPredictionSessionId sessionId) {
102         assertNotNull(context);
103         assertNotNull(sessionId);
104         assertFalse(mSessions.containsKey(sessionId));
105         mSessions.put(sessionId, context);
106         mCreateSessionLatch.countDown();
107     }
108 
awaitOnCreatePredictionSession()109     boolean awaitOnCreatePredictionSession() {
110         try {
111             return await(mCreateSessionLatch);
112         } finally {
113             mCreateSessionLatch = new CountDownLatch(1);
114         }
115     }
116 
onAppTargetEvent(AppPredictionSessionId sessionId, AppTargetEvent event)117     void onAppTargetEvent(AppPredictionSessionId sessionId, AppTargetEvent event) {
118         assertTrue(mSessions.containsKey(sessionId));
119         mEvents.add(event);
120         mEventLatch.countDown();
121     }
122 
awaitOnAppTargetEvent()123     boolean awaitOnAppTargetEvent() {
124         try {
125             return await(mEventLatch);
126         } finally {
127             mEventLatch = new CountDownLatch(1);
128         }
129     }
130 
onLocationShown(AppPredictionSessionId sessionId, String launchLocation, List<AppTargetId> targetIds)131     void onLocationShown(AppPredictionSessionId sessionId, String launchLocation,
132             List<AppTargetId> targetIds) {
133         assertTrue(mSessions.containsKey(sessionId));
134         mLocationsShown = launchLocation;
135         mLocationsShownTargets.addAll(targetIds);
136         mLocationShownLatch.countDown();
137     }
138 
awaitOnLocationShown()139     boolean awaitOnLocationShown() {
140         try {
141             return await(mLocationShownLatch);
142         } finally {
143             mLocationShownLatch = new CountDownLatch(1);
144         }
145     }
146 
onSortAppTargets(AppPredictionSessionId sessionId, List<AppTarget> targets, Consumer<List<AppTarget>> callback)147     void onSortAppTargets(AppPredictionSessionId sessionId, List<AppTarget> targets,
148             Consumer<List<AppTarget>> callback) {
149         assertTrue(mSessions.containsKey(sessionId));
150         assertNotNull(targets);
151         assertNotNull(callback);
152         mSortLatch.countDown();
153     }
154 
awaitOnSortAppTargets()155     boolean awaitOnSortAppTargets() {
156         try {
157             return await(mSortLatch);
158         } finally {
159             mSortLatch = new CountDownLatch(1);
160         }
161     }
162 
onStartPredictionUpdates()163     void onStartPredictionUpdates() {
164         mPredictionUpdatesStarted = true;
165     }
166 
awaitOnStartPredictionUpdates()167     boolean awaitOnStartPredictionUpdates() {
168         try {
169             return await(mStartPredictionUpdatesLatch);
170         } finally {
171             mStartPredictionUpdatesLatch = new CountDownLatch(1);
172         }
173     }
174 
onStopPredictionUpdates()175     void onStopPredictionUpdates() {
176         mPredictionUpdatesStarted = false;
177     }
178 
awaitOnStopPredictionUpdates()179     boolean awaitOnStopPredictionUpdates() {
180         try {
181             return await(mStopPredictionUpdatesLatch);
182         } finally {
183             mStopPredictionUpdatesLatch = new CountDownLatch(1);
184         }
185     }
186 
onRequestPredictionUpdate(AppPredictionSessionId sessionId)187     void onRequestPredictionUpdate(AppPredictionSessionId sessionId) {
188         assertTrue(mSessions.containsKey(sessionId));
189         mNumRequestedUpdates++;
190         mPredictionUpdateLatch.countDown();
191     }
192 
awaitOnRequestPredictionUpdate()193     boolean awaitOnRequestPredictionUpdate() {
194         try {
195             return await(mPredictionUpdateLatch);
196         } finally {
197             mPredictionUpdateLatch = new CountDownLatch(1);
198         }
199     }
200 
onDestroyPredictionSession(AppPredictionSessionId sessionId)201     void onDestroyPredictionSession(AppPredictionSessionId sessionId) {
202         assertTrue(mSessions.containsKey(sessionId));
203         mSessions.remove(sessionId);
204         mDestroyLatch.countDown();
205     }
206 
awaitOnDestroyPredictionSession()207     boolean awaitOnDestroyPredictionSession() {
208         try {
209             return await(mDestroyLatch);
210         } finally {
211             mDestroyLatch = new CountDownLatch(1);
212         }
213     }
214 
onRequestServiceFeatures(AppPredictionSessionId sessionId, Consumer<Bundle> callback)215     void onRequestServiceFeatures(AppPredictionSessionId sessionId,
216             Consumer<Bundle> callback) {
217         assertTrue(mSessions.containsKey(sessionId));
218         assertNotNull(callback);
219         mRequestServiceFeaturesLatch.countDown();
220     }
221 
222     public class Event {
223         final AppTarget target;
224         final int launchLocation;
225         final int eventType;
226 
Event(AppTarget target, int launchLocation, int eventType)227         public Event(AppTarget target, int launchLocation, int eventType) {
228             this.target = target;
229             this.launchLocation = launchLocation;
230             this.eventType = eventType;
231         }
232     }
233 
await(CountDownLatch latch)234     private boolean await(CountDownLatch latch) {
235         try {
236             latch.await(500, TimeUnit.MILLISECONDS);
237             return true;
238         } catch (InterruptedException e) {
239             return false;
240         }
241     }
242 
243     public static class RequestVerifier implements AppPredictor.Callback, PredictionsProvider,
244             Consumer<List<AppTarget>> {
245 
246         private ServiceReporter mReporter;
247         private CountDownLatch mReceivedLatch;
248         private List<AppTarget> mTargets;
249 
RequestVerifier(ServiceReporter reporter)250         public RequestVerifier(ServiceReporter reporter) {
251             mReporter = reporter;
252             mReceivedLatch = new CountDownLatch(1);
253         }
254 
255         @Override
getTargets(AppPredictionSessionId sessionId)256         public List<AppTarget> getTargets(AppPredictionSessionId sessionId) {
257             return mTargets;
258         }
259 
260         @Override
onTargetsAvailable(List<AppTarget> targets)261         public void onTargetsAvailable(List<AppTarget> targets) {
262             if (mTargets != null) {
263                 // Verify that the targets match
264                 assertEquals(targets, mTargets);
265             } else {
266                 // For the case where we didn't setup the request, save the targets so we can verify
267                 // them in awaitTargets()
268                 mTargets = targets;
269             }
270             mReceivedLatch.countDown();
271         }
272 
273         @Override
accept(List<AppTarget> appTargets)274         public void accept(List<AppTarget> appTargets) {
275             onTargetsAvailable(appTargets);
276         }
277 
278         /**
279          * @param requestUpdateCb Callback called when the request is setup
280          */
requestAndWaitForTargets(List<AppTarget> targets, Runnable requestUpdateCb)281         boolean requestAndWaitForTargets(List<AppTarget> targets, Runnable requestUpdateCb) {
282             mTargets = targets;
283             mReceivedLatch = new CountDownLatch(1);
284             mReporter.setPredictionsProvider(this);
285             requestUpdateCb.run();
286             try {
287                 return awaitTargets(targets);
288             } finally {
289                 mReporter.setPredictionsProvider(null);
290             }
291         }
292 
awaitTargets(List<AppTarget> targets)293         boolean awaitTargets(List<AppTarget> targets) {
294             try {
295                 boolean result = mReceivedLatch.await(500, TimeUnit.MILLISECONDS);
296                 assertEquals(targets, mTargets);
297                 return result;
298             } catch (InterruptedException e) {
299                 return false;
300             }
301         }
302     }
303 
304     public static class RequestServiceFeaturesVerifier implements Consumer<Bundle> {
305 
306         private CountDownLatch mReceivedLatch;
307         private Bundle mBundle;
308 
309         @Override
accept(Bundle bundle)310         public void accept(Bundle bundle) {
311             mBundle = bundle;
312             mReceivedLatch.countDown();
313         }
314 
315         /**
316          * @param requestCb Callback called when the request is setup
317          */
requestAndWaitForTargets(Bundle bundle, Runnable requestCb)318         boolean requestAndWaitForTargets(Bundle bundle, Runnable requestCb) {
319             mReceivedLatch = new CountDownLatch(1);
320             requestCb.run();
321             return awaitTargets(bundle);
322         }
323 
awaitTargets(Bundle bundle)324         boolean awaitTargets(Bundle bundle) {
325             try {
326                 boolean result = mReceivedLatch.await(500, TimeUnit.MILLISECONDS);
327                 assertEquals(bundle, mBundle);
328                 return result;
329             } catch (InterruptedException e) {
330                 return false;
331             }
332         }
333     }
334 }
335