• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.android.tradefed.invoker.shard;
17 
18 import com.android.annotations.VisibleForTesting;
19 import com.android.tradefed.config.IConfiguration;
20 import com.android.tradefed.error.HarnessRuntimeException;
21 import com.android.tradefed.invoker.IRescheduler;
22 import com.android.tradefed.invoker.TestInformation;
23 import com.android.tradefed.log.ITestLogger;
24 import com.android.tradefed.log.LogUtil.CLog;
25 import com.android.tradefed.result.ITestLoggerReceiver;
26 import com.android.tradefed.result.error.InfraErrorIdentifier;
27 import com.android.tradefed.testtype.IBuildReceiver;
28 import com.android.tradefed.testtype.IDeviceTest;
29 import com.android.tradefed.testtype.IInvocationContextReceiver;
30 import com.android.tradefed.testtype.IRemoteTest;
31 import com.android.tradefed.testtype.IRuntimeHintProvider;
32 import com.android.tradefed.testtype.IShardableTest;
33 import com.android.tradefed.testtype.suite.ITestSuite;
34 import com.android.tradefed.testtype.suite.ModuleMerger;
35 import com.android.tradefed.util.TimeUtil;
36 
37 import java.util.ArrayList;
38 import java.util.Collection;
39 import java.util.Collections;
40 import java.util.Comparator;
41 import java.util.List;
42 import java.util.regex.Matcher;
43 import java.util.regex.Pattern;
44 
/** Sharding strategy to create strict shards that do not report together. */
46 public class StrictShardHelper extends ShardHelper {
47 
48     /** {@inheritDoc} */
49     @Override
shardConfig( IConfiguration config, TestInformation testInfo, IRescheduler rescheduler, ITestLogger logger)50     public boolean shardConfig(
51             IConfiguration config,
52             TestInformation testInfo,
53             IRescheduler rescheduler,
54             ITestLogger logger) {
55         if (config.getCommandOptions().shouldRemoteDynamicShard()) {
56             return shardConfigDynamic(config, testInfo, rescheduler, logger);
57         } else {
58             return shardConfigInternal(config, testInfo, rescheduler, logger);
59         }
60     }
61 
62     @VisibleForTesting
shardConfigDynamic( IConfiguration config, TestInformation testInfo, IRescheduler rescheduler, ITestLogger logger)63     protected boolean shardConfigDynamic(
64             IConfiguration config,
65             TestInformation testInfo,
66             IRescheduler rescheduler,
67             ITestLogger logger) {
68         // attempt dynamic sharding
69         // may call #shardConfigInternal itself if preconditions are not met
70         DynamicShardHelper helper = new DynamicShardHelper();
71         return helper.shardConfig(config, testInfo, rescheduler, logger);
72     }
73 
shardConfigInternal( IConfiguration config, TestInformation testInfo, IRescheduler rescheduler, ITestLogger logger)74     protected boolean shardConfigInternal(
75             IConfiguration config,
76             TestInformation testInfo,
77             IRescheduler rescheduler,
78             ITestLogger logger) {
79         Integer shardCount = config.getCommandOptions().getShardCount();
80         Integer shardIndex = config.getCommandOptions().getShardIndex();
81         boolean optimizeMainline = config.getCommandOptions().getOptimizeMainlineTest();
82 
83         if (shardIndex == null) {
84             return super.shardConfig(config, testInfo, rescheduler, logger);
85         }
86         if (shardCount == null) {
87             throw new RuntimeException("shard-count is null while shard-index is " + shardIndex);
88         }
89         // No sharding needed if shard-count=1
90         if (shardCount == 1) {
91             return false;
92         }
93 
94         // Split tests in place, without actually sharding.
95         List<IRemoteTest> listAllTests = getAllTests(config, shardCount, testInfo, logger);
96         // We cannot shuffle to get better average results
97         normalizeDistribution(listAllTests, shardCount);
98         List<IRemoteTest> splitList;
99         if (shardCount == 1) {
100             // not sharded
101             splitList = listAllTests;
102         } else {
103             splitList =
104                     splitTests(
105                                     listAllTests,
106                                     shardCount,
107                                     config.getCommandOptions().shouldUseEvenModuleSharding())
108                             .get(shardIndex);
109         }
110         aggregateSuiteModules(splitList);
111         if (optimizeMainline) {
112             CLog.i("Reordering the test modules list for index: %s", shardIndex);
113             reorderTestModules(splitList);
114         }
115         config.setTests(splitList);
116         return false;
117     }
118 
119     /**
120      * Helper to re order the list full list of {@link IRemoteTest} for mainline.
121      *
122      * @param tests the {@link IRemoteTest} containing all the tests that need to run.
123      */
reorderTestModules(List<IRemoteTest> tests)124     private void reorderTestModules(List<IRemoteTest> tests) {
125         Collections.sort(
126                 tests,
127                 new Comparator<IRemoteTest>() {
128                     @Override
129                     public int compare(IRemoteTest o1, IRemoteTest o2) {
130                         String moduleId1 = ((ITestSuite) o1).getDirectModule().getId();
131                         String moduleId2 = ((ITestSuite) o2).getDirectModule().getId();
132                         return getMainlineId(moduleId1).compareTo(getMainlineId(moduleId2));
133                     }
134                 });
135     }
136 
137     /**
138      * Returns the parameterized mainline modules' name defined in the square brackets.
139      *
140      * @param id The module's name.
141      * @throws RuntimeException if the module name doesn't match the pattern for mainline modules.
142      */
getMainlineId(String id)143     private String getMainlineId(String id) {
144         // Pattern used to identify the parameterized mainline modules defined in the square
145         // brackets.
146         Pattern parameterizedMainlineRegex = Pattern.compile("\\[(.*(\\.apk|.apex|.apks))\\]$");
147         Matcher m = parameterizedMainlineRegex.matcher(id);
148         if (m.find()) {
149             return m.group(1);
150         }
151         throw new HarnessRuntimeException(
152                 String.format(
153                         "Module: %s doesn't match the pattern for mainline modules. The "
154                                 + "pattern should end with apk/apex/apks.",
155                         id),
156                 InfraErrorIdentifier.OPTION_CONFIGURATION_ERROR);
157     }
158 
159     /**
160      * Helper to return the full list of {@link IRemoteTest} based on {@link IShardableTest} split.
161      *
162      * @param config the {@link IConfiguration} describing the invocation.
163      * @param shardCount the shard count hint to be provided to some tests.
164      * @param testInfo the {@link TestInformation} of the parent invocation.
165      * @return the list of all {@link IRemoteTest}.
166      */
getAllTests( IConfiguration config, Integer shardCount, TestInformation testInfo, ITestLogger logger)167     private List<IRemoteTest> getAllTests(
168             IConfiguration config,
169             Integer shardCount,
170             TestInformation testInfo,
171             ITestLogger logger) {
172         List<IRemoteTest> allTests = new ArrayList<>();
173         for (IRemoteTest test : config.getTests()) {
174             if (test instanceof IShardableTest) {
175                 // Inject current information to help with sharding
176                 if (test instanceof IBuildReceiver) {
177                     ((IBuildReceiver) test).setBuild(testInfo.getBuildInfo());
178                 }
179                 if (test instanceof IDeviceTest) {
180                     ((IDeviceTest) test).setDevice(testInfo.getDevice());
181                 }
182                 if (test instanceof IInvocationContextReceiver) {
183                     ((IInvocationContextReceiver) test).setInvocationContext(testInfo.getContext());
184                 }
185                 if (test instanceof ITestLoggerReceiver) {
186                     ((ITestLoggerReceiver) test).setTestLogger(logger);
187                 }
188 
189                 // Handling of the ITestSuite is a special case, we do not allow pool of tests
190                 // since each shard needs to be independent.
191                 if (test instanceof ITestSuite) {
192                     ((ITestSuite) test).setShouldMakeDynamicModule(false);
193                 }
194 
195                 Collection<IRemoteTest> subTests =
196                         ((IShardableTest) test).split(shardCount, testInfo);
197                 if (subTests == null) {
198                     // test did not shard so we add it as is.
199                     allTests.add(test);
200                 } else {
201                     allTests.addAll(subTests);
202                 }
203             } else {
204                 // if test is not shardable we add it as is.
205                 allTests.add(test);
206             }
207         }
208         return allTests;
209     }
210 
211     /**
212      * Split the list of tests to run however the implementation see fit. Sharding needs to be
213      * consistent. It is acceptable to return an empty list if no tests can be run in the shard.
214      *
215      * <p>Implement this in order to provide a test suite specific sharding. The default
216      * implementation attempts to balance the number of IRemoteTest per shards as much as possible
217      * as a first step, then use a minor criteria or run-hint to adjust the lists a bit more.
218      *
219      * @param fullList the initial full list of {@link IRemoteTest} containing all the tests that
220      *     need to run.
221      * @param shardCount the total number of shard that need to run.
222      * @param useEvenModuleSharding whether to use a strategy that evenly distributes number of
223      *     modules across shards
224      * @return a list of list {@link IRemoteTest}s that have been assigned to each shard. The list
225      *     size will be the shardCount.
226      */
227     @VisibleForTesting
splitTests( List<IRemoteTest> fullList, int shardCount, boolean useEvenModuleSharding)228     protected List<List<IRemoteTest>> splitTests(
229             List<IRemoteTest> fullList, int shardCount, boolean useEvenModuleSharding) {
230         List<List<IRemoteTest>> shards;
231         if (useEvenModuleSharding) {
232             CLog.d("Using the sharding strategy to distribute number of modules more evenly.");
233             shards = shardList(fullList, shardCount);
234         } else {
235             shards = new ArrayList<>();
236             // We are using Match.ceil to avoid the last shard having too much extra.
237             int numPerShard = (int) Math.ceil(fullList.size() / (float) shardCount);
238 
239             boolean needsCorrection = false;
240             float correctionRatio = 0f;
241             if (fullList.size() > shardCount) {
242                 // In some cases because of the Math.ceil, some combination might run out of tests
243                 // before the last shard, in that case we populate a correction to rebalance the
244                 // tests.
245                 needsCorrection = (numPerShard * (shardCount - 1)) > fullList.size();
246                 correctionRatio = numPerShard - (fullList.size() / (float) shardCount);
247             }
248             // Recalculate the number of tests per shard with the correction taken into account.
249             numPerShard = (int) Math.floor(numPerShard - correctionRatio);
250             // Based of the parameters, distribute the tests across shards.
251             shards = balancedDistrib(fullList, shardCount, numPerShard, needsCorrection);
252         }
253         // Do last minute rebalancing
254         topBottom(shards, shardCount);
255         return shards;
256     }
257 
balancedDistrib( List<IRemoteTest> fullList, int shardCount, int numPerShard, boolean needsCorrection)258     private List<List<IRemoteTest>> balancedDistrib(
259             List<IRemoteTest> fullList, int shardCount, int numPerShard, boolean needsCorrection) {
260         List<List<IRemoteTest>> shards = new ArrayList<>();
261         List<IRemoteTest> correctionList = new ArrayList<>();
262         int correctionSize = 0;
263 
264         // Generate all the shards
265         for (int i = 0; i < shardCount; i++) {
266             List<IRemoteTest> shardList;
267             if (i >= fullList.size()) {
268                 // Return empty list when we don't have enough tests for all the shards.
269                 shardList = new ArrayList<IRemoteTest>();
270                 shards.add(shardList);
271                 continue;
272             }
273 
274             if (i == shardCount - 1) {
275                 // last shard take everything remaining except the correction:
276                 if (needsCorrection) {
277                     // We omit the size of the correction needed.
278                     correctionSize = fullList.size() - (numPerShard + (i * numPerShard));
279                     correctionList =
280                             fullList.subList(fullList.size() - correctionSize, fullList.size());
281                 }
282                 shardList = fullList.subList(i * numPerShard, fullList.size() - correctionSize);
283                 shards.add(new ArrayList<>(shardList));
284                 continue;
285             }
286             shardList = fullList.subList(i * numPerShard, numPerShard + (i * numPerShard));
287             shards.add(new ArrayList<>(shardList));
288         }
289 
290         // If we have correction omitted tests, disperse them on each shard, at this point the
291         // number of tests in correction is ensured to be bellow the number of shards.
292         for (int i = 0; i < shardCount; i++) {
293             if (i < correctionList.size()) {
294                 shards.get(i).add(correctionList.get(i));
295             } else {
296                 break;
297             }
298         }
299         return shards;
300     }
301 
302     @VisibleForTesting
shardList(List<T> fullList, int shardCount)303     static <T> List<List<T>> shardList(List<T> fullList, int shardCount) {
304         int totalSize = fullList.size();
305         int smallShardSize = totalSize / shardCount;
306         int bigShardSize = smallShardSize + 1;
307         int bigShardCount = totalSize % shardCount;
308 
309         // Correctness:
310         // sum(shard sizes)
311         // == smallShardSize * smallShardCount + bigShardSize * bigShardCount
312         // == smallShardSize * (shardCount - bigShardCount) + bigShardSize * bigShardCount
313         // == smallShardSize * (shardCount - bigShardCount) + (smallShardSize + 1) * bigShardCount
314         // == smallShardSize * (shardCount - bigShardCount + bigShardCount) + bigShardCount
315         // == smallShardSize * shardCount + bigShardCount
316         // == floor(totalSize / shardCount) * shardCount + remainder(totalSize / shardCount)
317         // == totalSize
318 
319         List<List<T>> shards = new ArrayList<>();
320         int i = 0;
321         for (; i < bigShardCount * bigShardSize; i += bigShardSize) {
322             shards.add(fullList.subList(i, i + bigShardSize));
323         }
324         for (; i < totalSize; i += smallShardSize) {
325             shards.add(fullList.subList(i, i + smallShardSize));
326         }
327         while (shards.size() < shardCount) {
328             shards.add(new ArrayList<>());
329         }
330         return shards;
331     }
332 
333     /**
334      * Move around predictably the tests in order to have a better uniformization of the tests in
335      * each shard.
336      */
normalizeDistribution(List<IRemoteTest> listAllTests, int shardCount)337     private void normalizeDistribution(List<IRemoteTest> listAllTests, int shardCount) {
338         final int numRound = shardCount;
339         final int distance = shardCount - 1;
340         for (int i = 0; i < numRound; i++) {
341             for (int j = 0; j < listAllTests.size(); j = j + distance) {
342                 // Push the test at the end
343                 IRemoteTest push = listAllTests.remove(j);
344                 listAllTests.add(push);
345             }
346         }
347     }
348 
349     /**
350      * Special handling for suite from {@link ITestSuite}. We aggregate the tests in the same shard
351      * in order to optimize target_preparation step.
352      *
353      * @param tests the {@link List} of {@link IRemoteTest} for that shard.
354      */
aggregateSuiteModules(List<IRemoteTest> tests)355     private void aggregateSuiteModules(List<IRemoteTest> tests) {
356         List<IRemoteTest> dupList = new ArrayList<>(tests);
357         for (int i = 0; i < dupList.size(); i++) {
358             if (dupList.get(i) instanceof ITestSuite) {
359                 // We iterate the other tests to see if we can find another from the same module.
360                 for (int j = i + 1; j < dupList.size(); j++) {
361                     // If the test was not already merged
362                     if (tests.contains(dupList.get(j))) {
363                         if (dupList.get(j) instanceof ITestSuite) {
364                             if (ModuleMerger.arePartOfSameSuite(
365                                     (ITestSuite) dupList.get(i), (ITestSuite) dupList.get(j))) {
366                                 ModuleMerger.mergeSplittedITestSuite(
367                                         (ITestSuite) dupList.get(i), (ITestSuite) dupList.get(j));
368                                 tests.remove(dupList.get(j));
369                             }
370                         }
371                     }
372                 }
373             }
374         }
375     }
376 
topBottom(List<List<IRemoteTest>> allShards, int shardCount)377     private void topBottom(List<List<IRemoteTest>> allShards, int shardCount) {
378         // Generate approximate RuntimeHint for each shard
379         int index = 0;
380         List<SortShardObj> shardTimes = new ArrayList<>();
381         for (List<IRemoteTest> shard : allShards) {
382             long aggTime = 0L;
383             CLog.d("++++++++++++++++++ SHARD %s +++++++++++++++", index);
384             for (IRemoteTest test : shard) {
385                 if (test instanceof IRuntimeHintProvider) {
386                     aggTime += ((IRuntimeHintProvider) test).getRuntimeHint();
387                 }
388             }
389             CLog.d("Shard %s approximate time: %s", index, TimeUtil.formatElapsedTime(aggTime));
390             shardTimes.add(new SortShardObj(index, aggTime));
391             index++;
392             CLog.d("+++++++++++++++++++++++++++++++++++++++++++");
393         }
394         // We only attempt this when the number of shard is pretty high
395         if (shardCount < 4) {
396             return;
397         }
398         Collections.sort(shardTimes);
399         if ((shardTimes.get(0).mAggTime - shardTimes.get(shardTimes.size() - 1).mAggTime)
400                 < 60 * 60 * 1000L) {
401             return;
402         }
403 
404         // take 30% top shard (10 shard = top 3 shards)
405         for (int i = 0; i < (shardCount * 0.3); i++) {
406             CLog.d(
407                     "Top shard %s is index %s with %s",
408                     i,
409                     shardTimes.get(i).mIndex,
410                     TimeUtil.formatElapsedTime(shardTimes.get(i).mAggTime));
411             int give = shardTimes.get(i).mIndex;
412             int receive = shardTimes.get(shardTimes.size() - 1 - i).mIndex;
413             CLog.d("Giving from shard %s to shard %s", give, receive);
414             for (int j = 0; j < (allShards.get(give).size() * (0.2f / (i + 1))); j++) {
415                 IRemoteTest givetest = allShards.get(give).remove(0);
416                 allShards.get(receive).add(givetest);
417             }
418         }
419     }
420 
421     /** Object holder for shard, their index and their aggregated execution time. */
422     private class SortShardObj implements Comparable<SortShardObj> {
423         public final int mIndex;
424         public final Long mAggTime;
425 
SortShardObj(int index, long aggTime)426         public SortShardObj(int index, long aggTime) {
427             mIndex = index;
428             mAggTime = aggTime;
429         }
430 
431         @Override
compareTo(SortShardObj obj)432         public int compareTo(SortShardObj obj) {
433             return obj.mAggTime.compareTo(mAggTime);
434         }
435     }
436 }
437