1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.crashtest.core.test; 18 19 import static java.util.concurrent.TimeUnit.MILLISECONDS; 20 21 import android.annotation.SuppressLint; 22 import android.content.Context; 23 import android.content.Intent; 24 import android.util.Log; 25 26 import com.android.nn.benchmark.core.Processor; 27 import com.android.nn.crashtest.core.CrashTest; 28 import com.android.nn.crashtest.core.CrashTestCoordinator.CrashTestIntentInitializer; 29 import com.android.nn.benchmark.core.TfLiteBackend; 30 31 import java.time.Duration; 32 import java.util.ArrayList; 33 import java.util.Collections; 34 import java.util.HashSet; 35 import java.util.List; 36 import java.util.Optional; 37 import java.util.Set; 38 import java.util.concurrent.CountDownLatch; 39 import java.util.concurrent.ExecutionException; 40 import java.util.concurrent.ExecutorService; 41 import java.util.concurrent.Executors; 42 import java.util.concurrent.Future; 43 44 public class RunModelsInParallel implements CrashTest { 45 46 private static final String MODELS = "models"; 47 private static final String DURATION = "duration"; 48 private static final String THREADS = "thread_counts"; 49 private static final String TEST_NAME = "test_name"; 50 private static final String ACCELERATOR_NAME = "accelerator_name"; 51 private static final String IGNORE_UNSUPPORTED_MODELS = "ignore_unsupported_models"; 52 private static final String RUN_MODEL_COMPILATION_ONLY = "run_model_compilation_only"; 53 private static final String MEMORY_MAP_MODEL = "memory_map_model"; 54 55 private final Set<Processor> activeTests = new HashSet<>(); 56 private final List<Boolean> mTestCompletionResults = Collections.synchronizedList( 57 new ArrayList<>()); 58 private long mTestDurationMillis = 0; 59 private int mThreadCount = 0; 60 private int[] mTestList = new int[0]; 61 private String mTestName; 62 private String mAcceleratorName; 63 private boolean mIgnoreUnsupportedModels; 64 private Context mContext; 65 private boolean mRunModelCompilationOnly; 66 private ExecutorService mExecutorService = null; 67 private CountDownLatch mParallelTestComplete; 68 private ProgressListener mProgressListener; 69 private boolean mMmapModel; 70 intentInitializer(int[] models, int threadCount, Duration duration, String testName, String acceleratorName, boolean ignoreUnsupportedModels, boolean runModelCompilationOnly, boolean mmapModel)71 static public CrashTestIntentInitializer intentInitializer(int[] models, int threadCount, 72 Duration duration, String testName, String acceleratorName, 73 boolean ignoreUnsupportedModels, 74 boolean runModelCompilationOnly, boolean mmapModel) { 75 return intent -> { 76 intent.putExtra(MODELS, models); 77 intent.putExtra(DURATION, duration.toMillis()); 78 intent.putExtra(THREADS, threadCount); 79 intent.putExtra(TEST_NAME, testName); 80 intent.putExtra(ACCELERATOR_NAME, acceleratorName); 81 intent.putExtra(IGNORE_UNSUPPORTED_MODELS, ignoreUnsupportedModels); 82 intent.putExtra(RUN_MODEL_COMPILATION_ONLY, runModelCompilationOnly); 83 intent.putExtra(MEMORY_MAP_MODEL, mmapModel); 84 }; 85 } 86 87 @Override init(Context context, Intent configParams, Optional<ProgressListener> progressListener)88 public void init(Context context, Intent configParams, 89 Optional<ProgressListener> progressListener) { 90 mTestList = configParams.getIntArrayExtra(MODELS); 91 mThreadCount = configParams.getIntExtra(THREADS, 10); 92 mTestDurationMillis = configParams.getLongExtra(DURATION, 1000 * 60 * 10); 93 mTestName = configParams.getStringExtra(TEST_NAME); 94 mAcceleratorName = configParams.getStringExtra(ACCELERATOR_NAME); 95 mIgnoreUnsupportedModels = mAcceleratorName != null && configParams.getBooleanExtra( 96 IGNORE_UNSUPPORTED_MODELS, false); 97 mRunModelCompilationOnly = configParams.getBooleanExtra(RUN_MODEL_COMPILATION_ONLY, false); 98 mMmapModel = configParams.getBooleanExtra(MEMORY_MAP_MODEL, false); 99 mContext = context; 100 mProgressListener = progressListener.orElseGet(() -> (Optional<String> message) -> { 101 Log.v(CrashTest.TAG, message.orElse(".")); 102 }); 103 mExecutorService = Executors.newFixedThreadPool(mThreadCount); 104 mTestCompletionResults.clear(); 105 } 106 107 @Override call()108 public Optional<String> call() { 109 mParallelTestComplete = new CountDownLatch(mThreadCount); 110 for (int i = 0; i < mThreadCount; i++) { 111 Processor testProcessor = createSubTestRunner(mTestList, i); 112 113 activeTests.add(testProcessor); 114 mExecutorService.submit(testProcessor); 115 } 116 117 return completedSuccessfully(); 118 } 119 createSubTestRunner(final int[] testList, final int testIndex)120 private Processor createSubTestRunner(final int[] testList, final int testIndex) { 121 final Processor result = new Processor(mContext, new Processor.Callback() { 122 @SuppressLint("DefaultLocale") 123 @Override 124 public void onBenchmarkFinish(boolean ok) { 125 notifyProgress("Test '%s': Benchmark #%d completed %s", mTestName, testIndex, 126 ok ? "successfully" : "with failure"); 127 mTestCompletionResults.add(ok); 128 mParallelTestComplete.countDown(); 129 } 130 131 @Override 132 public void onStatusUpdate(int testNumber, int numTests, String modelName) { 133 } 134 }, testList); 135 result.setTfLiteBackend(TfLiteBackend.NNAPI); 136 result.setCompleteInputSet(false); 137 result.setNnApiAcceleratorName(mAcceleratorName); 138 result.setIgnoreUnsupportedModels(mIgnoreUnsupportedModels); 139 result.setRunModelCompilationOnly(mRunModelCompilationOnly); 140 result.setMmapModel(mMmapModel); 141 return result; 142 } 143 endTests()144 private void endTests() { 145 ExecutorService terminatorsThreadPool = Executors.newFixedThreadPool(activeTests.size()); 146 List<Future<?>> terminationCommands = new ArrayList<>(); 147 for (final Processor test : activeTests) { 148 // Exit will block until the thread is completed 149 terminationCommands.add(terminatorsThreadPool.submit( 150 () -> test.exitWithTimeout(Duration.ofSeconds(20).toMillis()))); 151 } 152 terminationCommands.forEach(terminationCommand -> { 153 try { 154 terminationCommand.get(); 155 } catch (ExecutionException e) { 156 Log.w(TAG, "Failure while waiting for completion of tests", e); 157 } catch (InterruptedException e) { 158 Thread.interrupted(); 159 } 160 }); 161 } 162 163 @SuppressLint("DefaultLocale") notifyProgress(String messageFormat, Object... args)164 void notifyProgress(String messageFormat, Object... args) { 165 mProgressListener.testProgress(Optional.of(String.format(messageFormat, args))); 166 } 167 168 // This method blocks until the tests complete and returns true if all tests completed 169 // successfully 170 @SuppressLint("DefaultLocale") completedSuccessfully()171 private Optional<String> completedSuccessfully() { 172 try { 173 boolean testsEnded = mParallelTestComplete.await(mTestDurationMillis, MILLISECONDS); 174 if (!testsEnded) { 175 Log.i(TAG, 176 String.format( 177 "Test '%s': Tests are not completed (they might have been " 178 + "designed to run " 179 + "indefinitely. Forcing termination.", mTestName)); 180 endTests(); 181 } 182 } catch (InterruptedException ignored) { 183 Thread.currentThread().interrupt(); 184 } 185 186 final long failedTestCount = mTestCompletionResults.stream().filter( 187 testResult -> !testResult).count(); 188 if (failedTestCount > 0) { 189 String failureMsg = String.format("Test '%s': %d out of %d test failed", mTestName, 190 failedTestCount, 191 mTestCompletionResults.size()); 192 Log.w(CrashTest.TAG, failureMsg); 193 return failure(failureMsg); 194 } else { 195 Log.i(CrashTest.TAG, 196 String.format("Test '%s': Test completed successfully", mTestName)); 197 return success(); 198 } 199 } 200 } 201