• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.android.tradefed.invoker.shard;
17 
18 import com.android.annotations.VisibleForTesting;
19 import com.android.tradefed.config.IConfiguration;
20 import com.android.tradefed.invoker.IInvocationContext;
21 import com.android.tradefed.invoker.IRescheduler;
22 import com.android.tradefed.log.LogUtil.CLog;
23 import com.android.tradefed.testtype.IBuildReceiver;
24 import com.android.tradefed.testtype.IDeviceTest;
25 import com.android.tradefed.testtype.IInvocationContextReceiver;
26 import com.android.tradefed.testtype.IMultiDeviceTest;
27 import com.android.tradefed.testtype.IRemoteTest;
28 import com.android.tradefed.testtype.IRuntimeHintProvider;
29 import com.android.tradefed.testtype.IShardableTest;
30 import com.android.tradefed.testtype.suite.ITestSuite;
31 import com.android.tradefed.testtype.suite.ModuleMerger;
32 import com.android.tradefed.util.TimeUtil;
33 
34 import java.util.ArrayList;
35 import java.util.Collection;
36 import java.util.Collections;
37 import java.util.List;
38 
39 /** Sharding strategy to create strict shards that do not report together, */
40 public class StrictShardHelper extends ShardHelper {
41 
42     /** {@inheritDoc} */
43     @Override
shardConfig( IConfiguration config, IInvocationContext context, IRescheduler rescheduler)44     public boolean shardConfig(
45             IConfiguration config, IInvocationContext context, IRescheduler rescheduler) {
46         Integer shardCount = config.getCommandOptions().getShardCount();
47         Integer shardIndex = config.getCommandOptions().getShardIndex();
48 
49         if (shardIndex == null) {
50             return super.shardConfig(config, context, rescheduler);
51         }
52         if (shardCount == null) {
53             throw new RuntimeException("shard-count is null while shard-index is " + shardIndex);
54         }
55 
56         // Split tests in place, without actually sharding.
57         List<IRemoteTest> listAllTests = getAllTests(config, shardCount, context);
58         // We cannot shuffle to get better average results
59         normalizeDistribution(listAllTests, shardCount);
60         List<IRemoteTest> splitList;
61         if (shardCount == 1) {
62             // not sharded
63             splitList = listAllTests;
64         } else {
65             splitList = splitTests(listAllTests, shardCount).get(shardIndex);
66         }
67         aggregateSuiteModules(splitList);
68         config.setTests(splitList);
69         return false;
70     }
71 
72     /**
73      * Helper to return the full list of {@link IRemoteTest} based on {@link IShardableTest} split.
74      *
75      * @param config the {@link IConfiguration} describing the invocation.
76      * @param shardCount the shard count hint to be provided to some tests.
77      * @param context the {@link IInvocationContext} of the parent invocation.
78      * @return the list of all {@link IRemoteTest}.
79      */
getAllTests( IConfiguration config, Integer shardCount, IInvocationContext context)80     private List<IRemoteTest> getAllTests(
81             IConfiguration config, Integer shardCount, IInvocationContext context) {
82         List<IRemoteTest> allTests = new ArrayList<>();
83         for (IRemoteTest test : config.getTests()) {
84             if (test instanceof IShardableTest) {
85                 // Inject current information to help with sharding
86                 if (test instanceof IBuildReceiver) {
87                     ((IBuildReceiver) test).setBuild(context.getBuildInfos().get(0));
88                 }
89                 if (test instanceof IDeviceTest) {
90                     ((IDeviceTest) test).setDevice(context.getDevices().get(0));
91                 }
92                 if (test instanceof IMultiDeviceTest) {
93                     ((IMultiDeviceTest) test).setDeviceInfos(context.getDeviceBuildMap());
94                 }
95                 if (test instanceof IInvocationContextReceiver) {
96                     ((IInvocationContextReceiver) test).setInvocationContext(context);
97                 }
98 
99                 // Handling of the ITestSuite is a special case, we do not allow pool of tests
100                 // since each shard needs to be independent.
101                 if (test instanceof ITestSuite) {
102                     ((ITestSuite) test).setShouldMakeDynamicModule(false);
103                 }
104 
105                 Collection<IRemoteTest> subTests = ((IShardableTest) test).split(shardCount);
106                 if (subTests == null) {
107                     // test did not shard so we add it as is.
108                     allTests.add(test);
109                 } else {
110                     allTests.addAll(subTests);
111                 }
112             } else {
113                 // if test is not shardable we add it as is.
114                 allTests.add(test);
115             }
116         }
117         return allTests;
118     }
119 
120     /**
121      * Split the list of tests to run however the implementation see fit. Sharding needs to be
122      * consistent. It is acceptable to return an empty list if no tests can be run in the shard.
123      *
124      * <p>Implement this in order to provide a test suite specific sharding. The default
125      * implementation attempts to balance the number of IRemoteTest per shards as much as possible
126      * as a first step, then use a minor criteria or run-hint to adjust the lists a bit more.
127      *
128      * @param fullList the initial full list of {@link IRemoteTest} containing all the tests that
129      *     need to run.
130      * @param shardCount the total number of shard that need to run.
131      * @return a list of list {@link IRemoteTest}s that have been assigned to each shard. The list
132      *     size will be the shardCount.
133      */
134     @VisibleForTesting
splitTests(List<IRemoteTest> fullList, int shardCount)135     protected List<List<IRemoteTest>> splitTests(List<IRemoteTest> fullList, int shardCount) {
136         List<List<IRemoteTest>> shards = new ArrayList<>();
137         // We are using Match.ceil to avoid the last shard having too much extra.
138         int numPerShard = (int) Math.ceil(fullList.size() / (float) shardCount);
139 
140         boolean needsCorrection = false;
141         float correctionRatio = 0f;
142         if (fullList.size() > shardCount) {
143             // In some cases because of the Math.ceil, some combination might run out of tests
144             // before the last shard, in that case we populate a correction to rebalance the tests.
145             needsCorrection = (numPerShard * (shardCount - 1)) > fullList.size();
146             correctionRatio = numPerShard - ((fullList.size() / (float) shardCount));
147         }
148         // Recalculate the number of tests per shard with the correction taken into account.
149         numPerShard = (int) Math.floor(numPerShard - correctionRatio);
150         // Based of the parameters, distribute the tests accross shards.
151         shards = balancedDistrib(fullList, shardCount, numPerShard, needsCorrection);
152         // Do last minute rebalancing
153         topBottom(shards, shardCount);
154         return shards;
155     }
156 
balancedDistrib( List<IRemoteTest> fullList, int shardCount, int numPerShard, boolean needsCorrection)157     private List<List<IRemoteTest>> balancedDistrib(
158             List<IRemoteTest> fullList, int shardCount, int numPerShard, boolean needsCorrection) {
159         List<List<IRemoteTest>> shards = new ArrayList<>();
160         List<IRemoteTest> correctionList = new ArrayList<>();
161         int correctionSize = 0;
162 
163         // Generate all the shards
164         for (int i = 0; i < shardCount; i++) {
165             List<IRemoteTest> shardList;
166             if (i >= fullList.size()) {
167                 // Return empty list when we don't have enough tests for all the shards.
168                 shardList = new ArrayList<IRemoteTest>();
169                 shards.add(shardList);
170                 continue;
171             }
172 
173             if (i == shardCount - 1) {
174                 // last shard take everything remaining except the correction:
175                 if (needsCorrection) {
176                     // We omit the size of the correction needed.
177                     correctionSize = fullList.size() - (numPerShard + (i * numPerShard));
178                     correctionList =
179                             fullList.subList(fullList.size() - correctionSize, fullList.size());
180                 }
181                 shardList = fullList.subList(i * numPerShard, fullList.size() - correctionSize);
182                 shards.add(new ArrayList<>(shardList));
183                 continue;
184             }
185             shardList = fullList.subList(i * numPerShard, numPerShard + (i * numPerShard));
186             shards.add(new ArrayList<>(shardList));
187         }
188 
189         // If we have correction omitted tests, disperse them on each shard, at this point the
190         // number of tests in correction is ensured to be bellow the number of shards.
191         for (int i = 0; i < shardCount; i++) {
192             if (i < correctionList.size()) {
193                 shards.get(i).add(correctionList.get(i));
194             } else {
195                 break;
196             }
197         }
198         return shards;
199     }
200 
201     /**
202      * Move around predictably the tests in order to have a better uniformization of the tests in
203      * each shard.
204      */
normalizeDistribution(List<IRemoteTest> listAllTests, int shardCount)205     private void normalizeDistribution(List<IRemoteTest> listAllTests, int shardCount) {
206         final int numRound = shardCount;
207         final int distance = shardCount - 1;
208         for (int i = 0; i < numRound; i++) {
209             for (int j = 0; j < listAllTests.size(); j = j + distance) {
210                 // Push the test at the end
211                 IRemoteTest push = listAllTests.remove(j);
212                 listAllTests.add(push);
213             }
214         }
215     }
216 
217     /**
218      * Special handling for suite from {@link ITestSuite}. We aggregate the tests in the same shard
219      * in order to optimize target_preparation step.
220      *
221      * @param tests the {@link List} of {@link IRemoteTest} for that shard.
222      */
aggregateSuiteModules(List<IRemoteTest> tests)223     private void aggregateSuiteModules(List<IRemoteTest> tests) {
224         List<IRemoteTest> dupList = new ArrayList<>(tests);
225         for (int i = 0; i < dupList.size(); i++) {
226             if (dupList.get(i) instanceof ITestSuite) {
227                 // We iterate the other tests to see if we can find another from the same module.
228                 for (int j = i + 1; j < dupList.size(); j++) {
229                     // If the test was not already merged
230                     if (tests.contains(dupList.get(j))) {
231                         if (dupList.get(j) instanceof ITestSuite) {
232                             if (ModuleMerger.arePartOfSameSuite(
233                                     (ITestSuite) dupList.get(i), (ITestSuite) dupList.get(j))) {
234                                 ModuleMerger.mergeSplittedITestSuite(
235                                         (ITestSuite) dupList.get(i), (ITestSuite) dupList.get(j));
236                                 tests.remove(dupList.get(j));
237                             }
238                         }
239                     }
240                 }
241             }
242         }
243     }
244 
topBottom(List<List<IRemoteTest>> allShards, int shardCount)245     private void topBottom(List<List<IRemoteTest>> allShards, int shardCount) {
246         // We only attempt this when the number of shard is pretty high
247         if (shardCount < 4) {
248             return;
249         }
250         // Generate approximate RuntimeHint for each shard
251         int index = 0;
252         List<SortShardObj> shardTimes = new ArrayList<>();
253         for (List<IRemoteTest> shard : allShards) {
254             long aggTime = 0l;
255             CLog.d("++++++++++++++++++ SHARD %s +++++++++++++++", index);
256             for (IRemoteTest test : shard) {
257                 if (test instanceof IRuntimeHintProvider) {
258                     aggTime += ((IRuntimeHintProvider) test).getRuntimeHint();
259                 }
260             }
261             CLog.d("Shard %s approximate time: %s", index, TimeUtil.formatElapsedTime(aggTime));
262             shardTimes.add(new SortShardObj(index, aggTime));
263             index++;
264             CLog.d("+++++++++++++++++++++++++++++++++++++++++++");
265         }
266 
267         Collections.sort(shardTimes);
268         if ((shardTimes.get(0).mAggTime - shardTimes.get(shardTimes.size() - 1).mAggTime)
269                 < 60 * 60 * 1000l) {
270             return;
271         }
272 
273         // take 30% top shard (10 shard = top 3 shards)
274         for (int i = 0; i < (shardCount * 0.3); i++) {
275             CLog.d(
276                     "Top shard %s is index %s with %s",
277                     i,
278                     shardTimes.get(i).mIndex,
279                     TimeUtil.formatElapsedTime(shardTimes.get(i).mAggTime));
280             int give = shardTimes.get(i).mIndex;
281             int receive = shardTimes.get(shardTimes.size() - 1 - i).mIndex;
282             CLog.d("Giving from shard %s to shard %s", give, receive);
283             for (int j = 0; j < (allShards.get(give).size() * (0.2f / (i + 1))); j++) {
284                 IRemoteTest givetest = allShards.get(give).remove(0);
285                 allShards.get(receive).add(givetest);
286             }
287         }
288     }
289 
290     /** Object holder for shard, their index and their aggregated execution time. */
291     private class SortShardObj implements Comparable<SortShardObj> {
292         public final int mIndex;
293         public final Long mAggTime;
294 
SortShardObj(int index, long aggTime)295         public SortShardObj(int index, long aggTime) {
296             mIndex = index;
297             mAggTime = aggTime;
298         }
299 
300         @Override
compareTo(SortShardObj obj)301         public int compareTo(SortShardObj obj) {
302             return obj.mAggTime.compareTo(mAggTime);
303         }
304     }
305 }
306