1 /* 2 * Copyright (C) 2015 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 18 package com.android.intentresolver.model; 19 20 import android.app.usage.UsageStats; 21 import android.content.ComponentName; 22 import android.content.Context; 23 import android.content.Intent; 24 import android.content.ServiceConnection; 25 import android.content.pm.ApplicationInfo; 26 import android.content.pm.PackageManager; 27 import android.content.pm.PackageManager.NameNotFoundException; 28 import android.content.pm.ResolveInfo; 29 import android.metrics.LogMaker; 30 import android.os.IBinder; 31 import android.os.Message; 32 import android.os.RemoteException; 33 import android.os.UserHandle; 34 import android.service.resolver.IResolverRankerResult; 35 import android.service.resolver.IResolverRankerService; 36 import android.service.resolver.ResolverRankerService; 37 import android.service.resolver.ResolverTarget; 38 import android.util.Log; 39 40 import com.android.intentresolver.ChooserActivityLogger; 41 import com.android.intentresolver.ResolvedComponentInfo; 42 import com.android.internal.logging.MetricsLogger; 43 import com.android.internal.logging.nano.MetricsProto.MetricsEvent; 44 45 import java.text.Collator; 46 import java.util.ArrayList; 47 import java.util.Comparator; 48 import java.util.LinkedHashMap; 49 import java.util.List; 50 import java.util.Map; 51 import java.util.concurrent.CountDownLatch; 52 import java.util.concurrent.TimeUnit; 53 54 /** 55 * Ranks and compares packages based on usage stats and uses the {@link ResolverRankerService}. 56 */ 57 public class ResolverRankerServiceResolverComparator extends AbstractResolverComparator { 58 private static final String TAG = "RRSResolverComparator"; 59 60 private static final boolean DEBUG = false; 61 62 // One week 63 private static final long USAGE_STATS_PERIOD = 1000 * 60 * 60 * 24 * 7; 64 65 private static final long RECENCY_TIME_PERIOD = 1000 * 60 * 60 * 12; 66 67 private static final float RECENCY_MULTIPLIER = 2.f; 68 69 // timeout for establishing connections with a ResolverRankerService. 70 private static final int CONNECTION_COST_TIMEOUT_MILLIS = 200; 71 72 private final Collator mCollator; 73 private final Map<String, UsageStats> mStats; 74 private final long mCurrentTime; 75 private final long mSinceTime; 76 private final LinkedHashMap<ComponentName, ResolverTarget> mTargetsDict = new LinkedHashMap<>(); 77 private final String mReferrerPackage; 78 private final Object mLock = new Object(); 79 private ArrayList<ResolverTarget> mTargets; 80 private String mAction; 81 private ComponentName mResolvedRankerName; 82 private ComponentName mRankerServiceName; 83 private IResolverRankerService mRanker; 84 private ResolverRankerServiceConnection mConnection; 85 private Context mContext; 86 private CountDownLatch mConnectSignal; 87 private ResolverRankerServiceComparatorModel mComparatorModel; 88 ResolverRankerServiceResolverComparator(Context context, Intent intent, String referrerPackage, Runnable afterCompute, ChooserActivityLogger chooserActivityLogger)89 public ResolverRankerServiceResolverComparator(Context context, Intent intent, 90 String referrerPackage, Runnable afterCompute, 91 ChooserActivityLogger chooserActivityLogger) { 92 super(context, intent); 93 mCollator = Collator.getInstance(context.getResources().getConfiguration().locale); 94 mReferrerPackage = referrerPackage; 95 mContext = context; 96 97 mCurrentTime = System.currentTimeMillis(); 98 mSinceTime = mCurrentTime - USAGE_STATS_PERIOD; 99 mStats = mUsm.queryAndAggregateUsageStats(mSinceTime, mCurrentTime); 100 mAction = intent.getAction(); 101 mRankerServiceName = new ComponentName(mContext, this.getClass()); 102 setCallBack(afterCompute); 103 setChooserActivityLogger(chooserActivityLogger); 104 105 mComparatorModel = buildUpdatedModel(); 106 } 107 108 @Override handleResultMessage(Message msg)109 public void handleResultMessage(Message msg) { 110 if (msg.what != RANKER_SERVICE_RESULT) { 111 return; 112 } 113 if (msg.obj == null) { 114 Log.e(TAG, "Receiving null prediction results."); 115 return; 116 } 117 final List<ResolverTarget> receivedTargets = (List<ResolverTarget>) msg.obj; 118 if (receivedTargets != null && mTargets != null 119 && receivedTargets.size() == mTargets.size()) { 120 final int size = mTargets.size(); 121 boolean isUpdated = false; 122 for (int i = 0; i < size; ++i) { 123 final float predictedProb = 124 receivedTargets.get(i).getSelectProbability(); 125 if (predictedProb != mTargets.get(i).getSelectProbability()) { 126 mTargets.get(i).setSelectProbability(predictedProb); 127 isUpdated = true; 128 } 129 } 130 if (isUpdated) { 131 mRankerServiceName = mResolvedRankerName; 132 mComparatorModel = buildUpdatedModel(); 133 } 134 } else { 135 Log.e(TAG, "Sizes of sent and received ResolverTargets diff."); 136 } 137 } 138 139 // compute features for each target according to usage stats of targets. 140 @Override doCompute(List<ResolvedComponentInfo> targets)141 public void doCompute(List<ResolvedComponentInfo> targets) { 142 final long recentSinceTime = mCurrentTime - RECENCY_TIME_PERIOD; 143 144 float mostRecencyScore = 1.0f; 145 float mostTimeSpentScore = 1.0f; 146 float mostLaunchScore = 1.0f; 147 float mostChooserScore = 1.0f; 148 149 for (ResolvedComponentInfo target : targets) { 150 final ResolverTarget resolverTarget = new ResolverTarget(); 151 mTargetsDict.put(target.name, resolverTarget); 152 final UsageStats pkStats = mStats.get(target.name.getPackageName()); 153 if (pkStats != null) { 154 // Only count recency for apps that weren't the caller 155 // since the caller is always the most recent. 156 // Persistent processes muck this up, so omit them too. 157 if (!target.name.getPackageName().equals(mReferrerPackage) 158 && !isPersistentProcess(target)) { 159 final float recencyScore = 160 (float) Math.max(pkStats.getLastTimeUsed() - recentSinceTime, 0); 161 resolverTarget.setRecencyScore(recencyScore); 162 if (recencyScore > mostRecencyScore) { 163 mostRecencyScore = recencyScore; 164 } 165 } 166 final float timeSpentScore = (float) pkStats.getTotalTimeInForeground(); 167 resolverTarget.setTimeSpentScore(timeSpentScore); 168 if (timeSpentScore > mostTimeSpentScore) { 169 mostTimeSpentScore = timeSpentScore; 170 } 171 final float launchScore = (float) pkStats.mLaunchCount; 172 resolverTarget.setLaunchScore(launchScore); 173 if (launchScore > mostLaunchScore) { 174 mostLaunchScore = launchScore; 175 } 176 177 float chooserScore = 0.0f; 178 if (pkStats.mChooserCounts != null && mAction != null 179 && pkStats.mChooserCounts.get(mAction) != null) { 180 chooserScore = (float) pkStats.mChooserCounts.get(mAction) 181 .getOrDefault(mContentType, 0); 182 if (mAnnotations != null) { 183 final int size = mAnnotations.length; 184 for (int i = 0; i < size; i++) { 185 chooserScore += (float) pkStats.mChooserCounts.get(mAction) 186 .getOrDefault(mAnnotations[i], 0); 187 } 188 } 189 } 190 if (DEBUG) { 191 if (mAction == null) { 192 Log.d(TAG, "Action type is null"); 193 } else { 194 Log.d(TAG, "Chooser Count of " + mAction + ":" 195 + target.name.getPackageName() + " is " 196 + Float.toString(chooserScore)); 197 } 198 } 199 resolverTarget.setChooserScore(chooserScore); 200 if (chooserScore > mostChooserScore) { 201 mostChooserScore = chooserScore; 202 } 203 } 204 } 205 206 if (DEBUG) { 207 Log.d(TAG, "compute - mostRecencyScore: " + mostRecencyScore 208 + " mostTimeSpentScore: " + mostTimeSpentScore 209 + " mostLaunchScore: " + mostLaunchScore 210 + " mostChooserScore: " + mostChooserScore); 211 } 212 213 mTargets = new ArrayList<>(mTargetsDict.values()); 214 for (ResolverTarget target : mTargets) { 215 final float recency = target.getRecencyScore() / mostRecencyScore; 216 setFeatures(target, recency * recency * RECENCY_MULTIPLIER, 217 target.getLaunchScore() / mostLaunchScore, 218 target.getTimeSpentScore() / mostTimeSpentScore, 219 target.getChooserScore() / mostChooserScore); 220 addDefaultSelectProbability(target); 221 if (DEBUG) { 222 Log.d(TAG, "Scores: " + target); 223 } 224 } 225 predictSelectProbabilities(mTargets); 226 227 mComparatorModel = buildUpdatedModel(); 228 } 229 230 @Override compare(ResolveInfo lhs, ResolveInfo rhs)231 public int compare(ResolveInfo lhs, ResolveInfo rhs) { 232 return mComparatorModel.getComparator().compare(lhs, rhs); 233 } 234 235 @Override getScore(ComponentName name)236 public float getScore(ComponentName name) { 237 return mComparatorModel.getScore(name); 238 } 239 240 // update ranking model when the connection to it is valid. 241 @Override updateModel(ComponentName componentName)242 public void updateModel(ComponentName componentName) { 243 synchronized (mLock) { 244 mComparatorModel.notifyOnTargetSelected(componentName); 245 } 246 } 247 248 // unbind the service and clear unhandled messges. 249 @Override destroy()250 public void destroy() { 251 mHandler.removeMessages(RANKER_SERVICE_RESULT); 252 mHandler.removeMessages(RANKER_RESULT_TIMEOUT); 253 if (mConnection != null) { 254 mContext.unbindService(mConnection); 255 mConnection.destroy(); 256 } 257 afterCompute(); 258 if (DEBUG) { 259 Log.d(TAG, "Unbinded Resolver Ranker."); 260 } 261 } 262 263 // connect to a ranking service. initRanker(Context context)264 private void initRanker(Context context) { 265 synchronized (mLock) { 266 if (mConnection != null && mRanker != null) { 267 if (DEBUG) { 268 Log.d(TAG, "Ranker still exists; reusing the existing one."); 269 } 270 return; 271 } 272 } 273 Intent intent = resolveRankerService(); 274 if (intent == null) { 275 return; 276 } 277 mConnectSignal = new CountDownLatch(1); 278 mConnection = new ResolverRankerServiceConnection(mConnectSignal); 279 context.bindServiceAsUser(intent, mConnection, Context.BIND_AUTO_CREATE, UserHandle.SYSTEM); 280 } 281 282 // resolve the service for ranking. resolveRankerService()283 private Intent resolveRankerService() { 284 Intent intent = new Intent(ResolverRankerService.SERVICE_INTERFACE); 285 final List<ResolveInfo> resolveInfos = mPm.queryIntentServices(intent, 0); 286 for (ResolveInfo resolveInfo : resolveInfos) { 287 if (resolveInfo == null || resolveInfo.serviceInfo == null 288 || resolveInfo.serviceInfo.applicationInfo == null) { 289 if (DEBUG) { 290 Log.d(TAG, "Failed to retrieve a ranker: " + resolveInfo); 291 } 292 continue; 293 } 294 ComponentName componentName = new ComponentName( 295 resolveInfo.serviceInfo.applicationInfo.packageName, 296 resolveInfo.serviceInfo.name); 297 try { 298 final String perm = mPm.getServiceInfo(componentName, 0).permission; 299 if (!ResolverRankerService.BIND_PERMISSION.equals(perm)) { 300 Log.w(TAG, "ResolverRankerService " + componentName + " does not require" 301 + " permission " + ResolverRankerService.BIND_PERMISSION 302 + " - this service will not be queried for " 303 + "ResolverRankerServiceResolverComparator. add android:permission=\"" 304 + ResolverRankerService.BIND_PERMISSION + "\"" 305 + " to the <service> tag for " + componentName 306 + " in the manifest."); 307 continue; 308 } 309 if (PackageManager.PERMISSION_GRANTED != mPm.checkPermission( 310 ResolverRankerService.HOLD_PERMISSION, 311 resolveInfo.serviceInfo.packageName)) { 312 Log.w(TAG, "ResolverRankerService " + componentName + " does not hold" 313 + " permission " + ResolverRankerService.HOLD_PERMISSION 314 + " - this service will not be queried for " 315 + "ResolverRankerServiceResolverComparator."); 316 continue; 317 } 318 } catch (NameNotFoundException e) { 319 Log.e(TAG, "Could not look up service " + componentName 320 + "; component name not found"); 321 continue; 322 } 323 if (DEBUG) { 324 Log.d(TAG, "Succeeded to retrieve a ranker: " + componentName); 325 } 326 mResolvedRankerName = componentName; 327 intent.setComponent(componentName); 328 return intent; 329 } 330 return null; 331 } 332 333 private class ResolverRankerServiceConnection implements ServiceConnection { 334 private final CountDownLatch mConnectSignal; 335 ResolverRankerServiceConnection(CountDownLatch connectSignal)336 ResolverRankerServiceConnection(CountDownLatch connectSignal) { 337 mConnectSignal = connectSignal; 338 } 339 340 public final IResolverRankerResult resolverRankerResult = 341 new IResolverRankerResult.Stub() { 342 @Override 343 public void sendResult(List<ResolverTarget> targets) throws RemoteException { 344 if (DEBUG) { 345 Log.d(TAG, "Sending Result back to Resolver: " + targets); 346 } 347 synchronized (mLock) { 348 final Message msg = Message.obtain(); 349 msg.what = RANKER_SERVICE_RESULT; 350 msg.obj = targets; 351 mHandler.sendMessage(msg); 352 } 353 } 354 }; 355 356 @Override onServiceConnected(ComponentName name, IBinder service)357 public void onServiceConnected(ComponentName name, IBinder service) { 358 if (DEBUG) { 359 Log.d(TAG, "onServiceConnected: " + name); 360 } 361 synchronized (mLock) { 362 mRanker = IResolverRankerService.Stub.asInterface(service); 363 mComparatorModel = buildUpdatedModel(); 364 mConnectSignal.countDown(); 365 } 366 } 367 368 @Override onServiceDisconnected(ComponentName name)369 public void onServiceDisconnected(ComponentName name) { 370 if (DEBUG) { 371 Log.d(TAG, "onServiceDisconnected: " + name); 372 } 373 synchronized (mLock) { 374 destroy(); 375 } 376 } 377 destroy()378 public void destroy() { 379 synchronized (mLock) { 380 mRanker = null; 381 mComparatorModel = buildUpdatedModel(); 382 } 383 } 384 } 385 386 @Override beforeCompute()387 void beforeCompute() { 388 super.beforeCompute(); 389 mTargetsDict.clear(); 390 mTargets = null; 391 mRankerServiceName = new ComponentName(mContext, this.getClass()); 392 mComparatorModel = buildUpdatedModel(); 393 mResolvedRankerName = null; 394 initRanker(mContext); 395 } 396 397 // predict select probabilities if ranking service is valid. predictSelectProbabilities(List<ResolverTarget> targets)398 private void predictSelectProbabilities(List<ResolverTarget> targets) { 399 if (mConnection == null) { 400 if (DEBUG) { 401 Log.d(TAG, "Has not found valid ResolverRankerService; Skip Prediction"); 402 } 403 } else { 404 try { 405 mConnectSignal.await(CONNECTION_COST_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS); 406 synchronized (mLock) { 407 if (mRanker != null) { 408 mRanker.predict(targets, mConnection.resolverRankerResult); 409 return; 410 } else { 411 if (DEBUG) { 412 Log.d(TAG, "Ranker has not been initialized; skip predict."); 413 } 414 } 415 } 416 } catch (InterruptedException e) { 417 Log.e(TAG, "Error in Wait for Service Connection."); 418 } catch (RemoteException e) { 419 Log.e(TAG, "Error in Predict: " + e); 420 } 421 } 422 afterCompute(); 423 } 424 425 // adds select prob as the default values, according to a pre-trained Logistic Regression model. addDefaultSelectProbability(ResolverTarget target)426 private void addDefaultSelectProbability(ResolverTarget target) { 427 float sum = (2.5543f * target.getLaunchScore()) 428 + (2.8412f * target.getTimeSpentScore()) 429 + (0.269f * target.getRecencyScore()) 430 + (4.2222f * target.getChooserScore()); 431 target.setSelectProbability((float) (1.0 / (1.0 + Math.exp(1.6568f - sum)))); 432 } 433 434 // sets features for each target setFeatures(ResolverTarget target, float recencyScore, float launchScore, float timeSpentScore, float chooserScore)435 private void setFeatures(ResolverTarget target, float recencyScore, float launchScore, 436 float timeSpentScore, float chooserScore) { 437 target.setRecencyScore(recencyScore); 438 target.setLaunchScore(launchScore); 439 target.setTimeSpentScore(timeSpentScore); 440 target.setChooserScore(chooserScore); 441 } 442 isPersistentProcess(ResolvedComponentInfo rci)443 static boolean isPersistentProcess(ResolvedComponentInfo rci) { 444 if (rci != null && rci.getCount() > 0) { 445 int flags = rci.getResolveInfoAt(0).activityInfo.applicationInfo.flags; 446 return (flags & ApplicationInfo.FLAG_PERSISTENT) != 0; 447 } 448 return false; 449 } 450 451 /** 452 * Re-construct a {@code ResolverRankerServiceComparatorModel} to replace the current model 453 * instance (if any) using the up-to-date {@code ResolverRankerServiceResolverComparator} ivar 454 * values. 455 * 456 * TODO: each time we replace the model instance, we're either updating the model to use 457 * adjusted data (which is appropriate), or we're providing a (late) value for one of our ivars 458 * that wasn't available the last time the model was updated. For those latter cases, we should 459 * just avoid creating the model altogether until we have all the prerequisites we'll need. Then 460 * we can probably simplify the logic in {@code ResolverRankerServiceComparatorModel} since we 461 * won't need to handle edge cases when the model data isn't fully prepared. 462 * (In some cases, these kinds of "updates" might interleave -- e.g., we might have finished 463 * initializing the first time and now want to adjust some data, but still need to wait for 464 * changes to propagate to the other ivars before rebuilding the model.) 465 */ buildUpdatedModel()466 private ResolverRankerServiceComparatorModel buildUpdatedModel() { 467 // TODO: we don't currently guarantee that the underlying target list/map won't be mutated, 468 // so the ResolverComparatorModel may provide inconsistent results. We should make immutable 469 // copies of the data (waiting for any necessary remaining data before creating the model). 470 return new ResolverRankerServiceComparatorModel( 471 mStats, 472 mTargetsDict, 473 mTargets, 474 mCollator, 475 mRanker, 476 mRankerServiceName, 477 (mAnnotations != null), 478 mPm); 479 } 480 481 /** 482 * Implementation of a {@code ResolverComparatorModel} that provides the same ranking logic as 483 * the legacy {@code ResolverRankerServiceResolverComparator}, as a refactoring step toward 484 * removing the complex legacy API. 485 */ 486 static class ResolverRankerServiceComparatorModel implements ResolverComparatorModel { 487 private final Map<String, UsageStats> mStats; // Treat as immutable. 488 private final Map<ComponentName, ResolverTarget> mTargetsDict; // Treat as immutable. 489 private final List<ResolverTarget> mTargets; // Treat as immutable. 490 private final Collator mCollator; 491 private final IResolverRankerService mRanker; 492 private final ComponentName mRankerServiceName; 493 private final boolean mAnnotationsUsed; 494 private final PackageManager mPm; 495 496 // TODO: it doesn't look like we should have to pass both targets and targetsDict, but it's 497 // not written in a way that makes it clear whether we can derive one from the other (at 498 // least in this constructor). ResolverRankerServiceComparatorModel( Map<String, UsageStats> stats, Map<ComponentName, ResolverTarget> targetsDict, List<ResolverTarget> targets, Collator collator, IResolverRankerService ranker, ComponentName rankerServiceName, boolean annotationsUsed, PackageManager pm)499 ResolverRankerServiceComparatorModel( 500 Map<String, UsageStats> stats, 501 Map<ComponentName, ResolverTarget> targetsDict, 502 List<ResolverTarget> targets, 503 Collator collator, 504 IResolverRankerService ranker, 505 ComponentName rankerServiceName, 506 boolean annotationsUsed, 507 PackageManager pm) { 508 mStats = stats; 509 mTargetsDict = targetsDict; 510 mTargets = targets; 511 mCollator = collator; 512 mRanker = ranker; 513 mRankerServiceName = rankerServiceName; 514 mAnnotationsUsed = annotationsUsed; 515 mPm = pm; 516 } 517 518 @Override getComparator()519 public Comparator<ResolveInfo> getComparator() { 520 // TODO: doCompute() doesn't seem to be concerned about null-checking mStats. Is that 521 // a bug there, or do we have a way of knowing it will be non-null under certain 522 // conditions? 523 return (lhs, rhs) -> { 524 if (mStats != null) { 525 final ResolverTarget lhsTarget = mTargetsDict.get(new ComponentName( 526 lhs.activityInfo.packageName, lhs.activityInfo.name)); 527 final ResolverTarget rhsTarget = mTargetsDict.get(new ComponentName( 528 rhs.activityInfo.packageName, rhs.activityInfo.name)); 529 530 if (lhsTarget != null && rhsTarget != null) { 531 final int selectProbabilityDiff = Float.compare( 532 rhsTarget.getSelectProbability(), lhsTarget.getSelectProbability()); 533 534 if (selectProbabilityDiff != 0) { 535 return selectProbabilityDiff > 0 ? 1 : -1; 536 } 537 } 538 } 539 540 CharSequence sa = lhs.loadLabel(mPm); 541 if (sa == null) sa = lhs.activityInfo.name; 542 CharSequence sb = rhs.loadLabel(mPm); 543 if (sb == null) sb = rhs.activityInfo.name; 544 545 return mCollator.compare(sa.toString().trim(), sb.toString().trim()); 546 }; 547 } 548 549 @Override getScore(ComponentName name)550 public float getScore(ComponentName name) { 551 final ResolverTarget target = mTargetsDict.get(name); 552 if (target != null) { 553 return target.getSelectProbability(); 554 } 555 return 0; 556 } 557 558 @Override notifyOnTargetSelected(ComponentName componentName)559 public void notifyOnTargetSelected(ComponentName componentName) { 560 if (mRanker != null) { 561 try { 562 int selectedPos = new ArrayList<ComponentName>(mTargetsDict.keySet()) 563 .indexOf(componentName); 564 if (selectedPos >= 0 && mTargets != null) { 565 final float selectedProbability = getScore(componentName); 566 int order = 0; 567 for (ResolverTarget target : mTargets) { 568 if (target.getSelectProbability() > selectedProbability) { 569 order++; 570 } 571 } 572 logMetrics(order); 573 mRanker.train(mTargets, selectedPos); 574 } else { 575 if (DEBUG) { 576 Log.d(TAG, "Selected a unknown component: " + componentName); 577 } 578 } 579 } catch (RemoteException e) { 580 Log.e(TAG, "Error in Train: " + e); 581 } 582 } else { 583 if (DEBUG) { 584 Log.d(TAG, "Ranker is null; skip updateModel."); 585 } 586 } 587 } 588 589 /** Records metrics for evaluation. */ logMetrics(int selectedPos)590 private void logMetrics(int selectedPos) { 591 if (mRankerServiceName != null) { 592 MetricsLogger metricsLogger = new MetricsLogger(); 593 LogMaker log = new LogMaker(MetricsEvent.ACTION_TARGET_SELECTED); 594 log.setComponentName(mRankerServiceName); 595 log.addTaggedData(MetricsEvent.FIELD_IS_CATEGORY_USED, mAnnotationsUsed ? 1 : 0); 596 log.addTaggedData(MetricsEvent.FIELD_RANKED_POSITION, selectedPos); 597 metricsLogger.write(log); 598 } 599 } 600 } 601 } 602