1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 package com.android.tradefed.invoker.shard; 17 18 import com.android.annotations.VisibleForTesting; 19 import com.android.tradefed.config.IConfiguration; 20 import com.android.tradefed.error.HarnessRuntimeException; 21 import com.android.tradefed.invoker.IRescheduler; 22 import com.android.tradefed.invoker.TestInformation; 23 import com.android.tradefed.log.ITestLogger; 24 import com.android.tradefed.log.LogUtil.CLog; 25 import com.android.tradefed.result.ITestLoggerReceiver; 26 import com.android.tradefed.result.error.InfraErrorIdentifier; 27 import com.android.tradefed.testtype.IBuildReceiver; 28 import com.android.tradefed.testtype.IDeviceTest; 29 import com.android.tradefed.testtype.IInvocationContextReceiver; 30 import com.android.tradefed.testtype.IRemoteTest; 31 import com.android.tradefed.testtype.IRuntimeHintProvider; 32 import com.android.tradefed.testtype.IShardableTest; 33 import com.android.tradefed.testtype.suite.ITestSuite; 34 import com.android.tradefed.testtype.suite.ModuleMerger; 35 import com.android.tradefed.util.TimeUtil; 36 37 import java.util.ArrayList; 38 import java.util.Collection; 39 import java.util.Collections; 40 import java.util.Comparator; 41 import java.util.List; 42 import java.util.regex.Matcher; 43 import java.util.regex.Pattern; 44 45 /** Sharding strategy to create strict shards that do not report together, */ 46 public class StrictShardHelper extends ShardHelper { 47 48 /** {@inheritDoc} */ 49 @Override shardConfig( IConfiguration config, TestInformation testInfo, IRescheduler rescheduler, ITestLogger logger)50 public boolean shardConfig( 51 IConfiguration config, 52 TestInformation testInfo, 53 IRescheduler rescheduler, 54 ITestLogger logger) { 55 if (config.getCommandOptions().shouldRemoteDynamicShard()) { 56 return shardConfigDynamic(config, testInfo, rescheduler, logger); 57 } else { 58 return shardConfigInternal(config, testInfo, rescheduler, logger); 59 } 60 } 61 62 @VisibleForTesting shardConfigDynamic( IConfiguration config, TestInformation testInfo, IRescheduler rescheduler, ITestLogger logger)63 protected boolean shardConfigDynamic( 64 IConfiguration config, 65 TestInformation testInfo, 66 IRescheduler rescheduler, 67 ITestLogger logger) { 68 // attempt dynamic sharding 69 // may call #shardConfigInternal itself if preconditions are not met 70 DynamicShardHelper helper = new DynamicShardHelper(); 71 return helper.shardConfig(config, testInfo, rescheduler, logger); 72 } 73 shardConfigInternal( IConfiguration config, TestInformation testInfo, IRescheduler rescheduler, ITestLogger logger)74 protected boolean shardConfigInternal( 75 IConfiguration config, 76 TestInformation testInfo, 77 IRescheduler rescheduler, 78 ITestLogger logger) { 79 Integer shardCount = config.getCommandOptions().getShardCount(); 80 Integer shardIndex = config.getCommandOptions().getShardIndex(); 81 boolean optimizeMainline = config.getCommandOptions().getOptimizeMainlineTest(); 82 83 if (shardIndex == null) { 84 return super.shardConfig(config, testInfo, rescheduler, logger); 85 } 86 if (shardCount == null) { 87 throw new RuntimeException("shard-count is null while shard-index is " + shardIndex); 88 } 89 // No sharding needed if shard-count=1 90 if (shardCount == 1) { 91 return false; 92 } 93 94 // Split tests in place, without actually sharding. 95 List<IRemoteTest> listAllTests = getAllTests(config, shardCount, testInfo, logger); 96 // We cannot shuffle to get better average results 97 normalizeDistribution(listAllTests, shardCount); 98 List<IRemoteTest> splitList; 99 if (shardCount == 1) { 100 // not sharded 101 splitList = listAllTests; 102 } else { 103 splitList = 104 splitTests( 105 listAllTests, 106 shardCount, 107 config.getCommandOptions().shouldUseEvenModuleSharding()) 108 .get(shardIndex); 109 } 110 aggregateSuiteModules(splitList); 111 if (optimizeMainline) { 112 CLog.i("Reordering the test modules list for index: %s", shardIndex); 113 reorderTestModules(splitList); 114 } 115 config.setTests(splitList); 116 return false; 117 } 118 119 /** 120 * Helper to re order the list full list of {@link IRemoteTest} for mainline. 121 * 122 * @param tests the {@link IRemoteTest} containing all the tests that need to run. 123 */ reorderTestModules(List<IRemoteTest> tests)124 private void reorderTestModules(List<IRemoteTest> tests) { 125 Collections.sort( 126 tests, 127 new Comparator<IRemoteTest>() { 128 @Override 129 public int compare(IRemoteTest o1, IRemoteTest o2) { 130 String moduleId1 = ((ITestSuite) o1).getDirectModule().getId(); 131 String moduleId2 = ((ITestSuite) o2).getDirectModule().getId(); 132 return getMainlineId(moduleId1).compareTo(getMainlineId(moduleId2)); 133 } 134 }); 135 } 136 137 /** 138 * Returns the parameterized mainline modules' name defined in the square brackets. 139 * 140 * @param id The module's name. 141 * @throws RuntimeException if the module name doesn't match the pattern for mainline modules. 142 */ getMainlineId(String id)143 private String getMainlineId(String id) { 144 // Pattern used to identify the parameterized mainline modules defined in the square 145 // brackets. 146 Pattern parameterizedMainlineRegex = Pattern.compile("\\[(.*(\\.apk|.apex|.apks))\\]$"); 147 Matcher m = parameterizedMainlineRegex.matcher(id); 148 if (m.find()) { 149 return m.group(1); 150 } 151 throw new HarnessRuntimeException( 152 String.format( 153 "Module: %s doesn't match the pattern for mainline modules. The " 154 + "pattern should end with apk/apex/apks.", 155 id), 156 InfraErrorIdentifier.OPTION_CONFIGURATION_ERROR); 157 } 158 159 /** 160 * Helper to return the full list of {@link IRemoteTest} based on {@link IShardableTest} split. 161 * 162 * @param config the {@link IConfiguration} describing the invocation. 163 * @param shardCount the shard count hint to be provided to some tests. 164 * @param testInfo the {@link TestInformation} of the parent invocation. 165 * @return the list of all {@link IRemoteTest}. 166 */ getAllTests( IConfiguration config, Integer shardCount, TestInformation testInfo, ITestLogger logger)167 private List<IRemoteTest> getAllTests( 168 IConfiguration config, 169 Integer shardCount, 170 TestInformation testInfo, 171 ITestLogger logger) { 172 List<IRemoteTest> allTests = new ArrayList<>(); 173 for (IRemoteTest test : config.getTests()) { 174 if (test instanceof IShardableTest) { 175 // Inject current information to help with sharding 176 if (test instanceof IBuildReceiver) { 177 ((IBuildReceiver) test).setBuild(testInfo.getBuildInfo()); 178 } 179 if (test instanceof IDeviceTest) { 180 ((IDeviceTest) test).setDevice(testInfo.getDevice()); 181 } 182 if (test instanceof IInvocationContextReceiver) { 183 ((IInvocationContextReceiver) test).setInvocationContext(testInfo.getContext()); 184 } 185 if (test instanceof ITestLoggerReceiver) { 186 ((ITestLoggerReceiver) test).setTestLogger(logger); 187 } 188 189 // Handling of the ITestSuite is a special case, we do not allow pool of tests 190 // since each shard needs to be independent. 191 if (test instanceof ITestSuite) { 192 ((ITestSuite) test).setShouldMakeDynamicModule(false); 193 } 194 195 Collection<IRemoteTest> subTests = 196 ((IShardableTest) test).split(shardCount, testInfo); 197 if (subTests == null) { 198 // test did not shard so we add it as is. 199 allTests.add(test); 200 } else { 201 allTests.addAll(subTests); 202 } 203 } else { 204 // if test is not shardable we add it as is. 205 allTests.add(test); 206 } 207 } 208 return allTests; 209 } 210 211 /** 212 * Split the list of tests to run however the implementation see fit. Sharding needs to be 213 * consistent. It is acceptable to return an empty list if no tests can be run in the shard. 214 * 215 * <p>Implement this in order to provide a test suite specific sharding. The default 216 * implementation attempts to balance the number of IRemoteTest per shards as much as possible 217 * as a first step, then use a minor criteria or run-hint to adjust the lists a bit more. 218 * 219 * @param fullList the initial full list of {@link IRemoteTest} containing all the tests that 220 * need to run. 221 * @param shardCount the total number of shard that need to run. 222 * @param useEvenModuleSharding whether to use a strategy that evenly distributes number of 223 * modules across shards 224 * @return a list of list {@link IRemoteTest}s that have been assigned to each shard. The list 225 * size will be the shardCount. 226 */ 227 @VisibleForTesting splitTests( List<IRemoteTest> fullList, int shardCount, boolean useEvenModuleSharding)228 protected List<List<IRemoteTest>> splitTests( 229 List<IRemoteTest> fullList, int shardCount, boolean useEvenModuleSharding) { 230 List<List<IRemoteTest>> shards; 231 if (useEvenModuleSharding) { 232 CLog.d("Using the sharding strategy to distribute number of modules more evenly."); 233 shards = shardList(fullList, shardCount); 234 } else { 235 shards = new ArrayList<>(); 236 // We are using Match.ceil to avoid the last shard having too much extra. 237 int numPerShard = (int) Math.ceil(fullList.size() / (float) shardCount); 238 239 boolean needsCorrection = false; 240 float correctionRatio = 0f; 241 if (fullList.size() > shardCount) { 242 // In some cases because of the Math.ceil, some combination might run out of tests 243 // before the last shard, in that case we populate a correction to rebalance the 244 // tests. 245 needsCorrection = (numPerShard * (shardCount - 1)) > fullList.size(); 246 correctionRatio = numPerShard - (fullList.size() / (float) shardCount); 247 } 248 // Recalculate the number of tests per shard with the correction taken into account. 249 numPerShard = (int) Math.floor(numPerShard - correctionRatio); 250 // Based of the parameters, distribute the tests across shards. 251 shards = balancedDistrib(fullList, shardCount, numPerShard, needsCorrection); 252 } 253 // Do last minute rebalancing 254 topBottom(shards, shardCount); 255 return shards; 256 } 257 balancedDistrib( List<IRemoteTest> fullList, int shardCount, int numPerShard, boolean needsCorrection)258 private List<List<IRemoteTest>> balancedDistrib( 259 List<IRemoteTest> fullList, int shardCount, int numPerShard, boolean needsCorrection) { 260 List<List<IRemoteTest>> shards = new ArrayList<>(); 261 List<IRemoteTest> correctionList = new ArrayList<>(); 262 int correctionSize = 0; 263 264 // Generate all the shards 265 for (int i = 0; i < shardCount; i++) { 266 List<IRemoteTest> shardList; 267 if (i >= fullList.size()) { 268 // Return empty list when we don't have enough tests for all the shards. 269 shardList = new ArrayList<IRemoteTest>(); 270 shards.add(shardList); 271 continue; 272 } 273 274 if (i == shardCount - 1) { 275 // last shard take everything remaining except the correction: 276 if (needsCorrection) { 277 // We omit the size of the correction needed. 278 correctionSize = fullList.size() - (numPerShard + (i * numPerShard)); 279 correctionList = 280 fullList.subList(fullList.size() - correctionSize, fullList.size()); 281 } 282 shardList = fullList.subList(i * numPerShard, fullList.size() - correctionSize); 283 shards.add(new ArrayList<>(shardList)); 284 continue; 285 } 286 shardList = fullList.subList(i * numPerShard, numPerShard + (i * numPerShard)); 287 shards.add(new ArrayList<>(shardList)); 288 } 289 290 // If we have correction omitted tests, disperse them on each shard, at this point the 291 // number of tests in correction is ensured to be bellow the number of shards. 292 for (int i = 0; i < shardCount; i++) { 293 if (i < correctionList.size()) { 294 shards.get(i).add(correctionList.get(i)); 295 } else { 296 break; 297 } 298 } 299 return shards; 300 } 301 302 @VisibleForTesting shardList(List<T> fullList, int shardCount)303 static <T> List<List<T>> shardList(List<T> fullList, int shardCount) { 304 int totalSize = fullList.size(); 305 int smallShardSize = totalSize / shardCount; 306 int bigShardSize = smallShardSize + 1; 307 int bigShardCount = totalSize % shardCount; 308 309 // Correctness: 310 // sum(shard sizes) 311 // == smallShardSize * smallShardCount + bigShardSize * bigShardCount 312 // == smallShardSize * (shardCount - bigShardCount) + bigShardSize * bigShardCount 313 // == smallShardSize * (shardCount - bigShardCount) + (smallShardSize + 1) * bigShardCount 314 // == smallShardSize * (shardCount - bigShardCount + bigShardCount) + bigShardCount 315 // == smallShardSize * shardCount + bigShardCount 316 // == floor(totalSize / shardCount) * shardCount + remainder(totalSize / shardCount) 317 // == totalSize 318 319 List<List<T>> shards = new ArrayList<>(); 320 int i = 0; 321 for (; i < bigShardCount * bigShardSize; i += bigShardSize) { 322 shards.add(fullList.subList(i, i + bigShardSize)); 323 } 324 for (; i < totalSize; i += smallShardSize) { 325 shards.add(fullList.subList(i, i + smallShardSize)); 326 } 327 while (shards.size() < shardCount) { 328 shards.add(new ArrayList<>()); 329 } 330 return shards; 331 } 332 333 /** 334 * Move around predictably the tests in order to have a better uniformization of the tests in 335 * each shard. 336 */ normalizeDistribution(List<IRemoteTest> listAllTests, int shardCount)337 private void normalizeDistribution(List<IRemoteTest> listAllTests, int shardCount) { 338 final int numRound = shardCount; 339 final int distance = shardCount - 1; 340 for (int i = 0; i < numRound; i++) { 341 for (int j = 0; j < listAllTests.size(); j = j + distance) { 342 // Push the test at the end 343 IRemoteTest push = listAllTests.remove(j); 344 listAllTests.add(push); 345 } 346 } 347 } 348 349 /** 350 * Special handling for suite from {@link ITestSuite}. We aggregate the tests in the same shard 351 * in order to optimize target_preparation step. 352 * 353 * @param tests the {@link List} of {@link IRemoteTest} for that shard. 354 */ aggregateSuiteModules(List<IRemoteTest> tests)355 private void aggregateSuiteModules(List<IRemoteTest> tests) { 356 List<IRemoteTest> dupList = new ArrayList<>(tests); 357 for (int i = 0; i < dupList.size(); i++) { 358 if (dupList.get(i) instanceof ITestSuite) { 359 // We iterate the other tests to see if we can find another from the same module. 360 for (int j = i + 1; j < dupList.size(); j++) { 361 // If the test was not already merged 362 if (tests.contains(dupList.get(j))) { 363 if (dupList.get(j) instanceof ITestSuite) { 364 if (ModuleMerger.arePartOfSameSuite( 365 (ITestSuite) dupList.get(i), (ITestSuite) dupList.get(j))) { 366 ModuleMerger.mergeSplittedITestSuite( 367 (ITestSuite) dupList.get(i), (ITestSuite) dupList.get(j)); 368 tests.remove(dupList.get(j)); 369 } 370 } 371 } 372 } 373 } 374 } 375 } 376 topBottom(List<List<IRemoteTest>> allShards, int shardCount)377 private void topBottom(List<List<IRemoteTest>> allShards, int shardCount) { 378 // Generate approximate RuntimeHint for each shard 379 int index = 0; 380 List<SortShardObj> shardTimes = new ArrayList<>(); 381 for (List<IRemoteTest> shard : allShards) { 382 long aggTime = 0L; 383 CLog.d("++++++++++++++++++ SHARD %s +++++++++++++++", index); 384 for (IRemoteTest test : shard) { 385 if (test instanceof IRuntimeHintProvider) { 386 aggTime += ((IRuntimeHintProvider) test).getRuntimeHint(); 387 } 388 } 389 CLog.d("Shard %s approximate time: %s", index, TimeUtil.formatElapsedTime(aggTime)); 390 shardTimes.add(new SortShardObj(index, aggTime)); 391 index++; 392 CLog.d("+++++++++++++++++++++++++++++++++++++++++++"); 393 } 394 // We only attempt this when the number of shard is pretty high 395 if (shardCount < 4) { 396 return; 397 } 398 Collections.sort(shardTimes); 399 if ((shardTimes.get(0).mAggTime - shardTimes.get(shardTimes.size() - 1).mAggTime) 400 < 60 * 60 * 1000L) { 401 return; 402 } 403 404 // take 30% top shard (10 shard = top 3 shards) 405 for (int i = 0; i < (shardCount * 0.3); i++) { 406 CLog.d( 407 "Top shard %s is index %s with %s", 408 i, 409 shardTimes.get(i).mIndex, 410 TimeUtil.formatElapsedTime(shardTimes.get(i).mAggTime)); 411 int give = shardTimes.get(i).mIndex; 412 int receive = shardTimes.get(shardTimes.size() - 1 - i).mIndex; 413 CLog.d("Giving from shard %s to shard %s", give, receive); 414 for (int j = 0; j < (allShards.get(give).size() * (0.2f / (i + 1))); j++) { 415 IRemoteTest givetest = allShards.get(give).remove(0); 416 allShards.get(receive).add(givetest); 417 } 418 } 419 } 420 421 /** Object holder for shard, their index and their aggregated execution time. */ 422 private class SortShardObj implements Comparable<SortShardObj> { 423 public final int mIndex; 424 public final Long mAggTime; 425 SortShardObj(int index, long aggTime)426 public SortShardObj(int index, long aggTime) { 427 mIndex = index; 428 mAggTime = aggTime; 429 } 430 431 @Override compareTo(SortShardObj obj)432 public int compareTo(SortShardObj obj) { 433 return obj.mAggTime.compareTo(mAggTime); 434 } 435 } 436 } 437