// Copyright (C) 2017 The Android Open Source Project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include #include "base/Compiler.h" #include "base/Optional.h" #include "base/System.h" #include "base/WorkerThread.h" // // ThreadPool - a simple collection of worker threads to process enqueued // items on multiple cores. // // To create a thread pool supply a processing function and an optional number // of threads to use (default is number of CPU cores). // Thread pool distributes the work in simple round robin manner over all its // workers - this means individual items should be simple and take similar time // to process. // // Usage is very similar to one of WorkerThread, with difference being in the // number of worker threads used and in existence of explicit done() method: // // struct WorkItem { int number; }; // // ThreadPool tp([](WorkItem&& item) { std::cout << item.num; }); // CHECK(tp.start()) << "Failed to start the thread pool"; // tp.enqueue({1}); // tp.enqueue({2}); // tp.enqueue({3}); // tp.enqueue({4}); // tp.enqueue({5}); // tp.done(); // tp.join(); // // Make sure that the processing function won't block worker threads - thread // pool has no way of detecting it and may potentially get all workers to block, // resulting in a hanging application. // namespace android { namespace base { using ThreadPoolWorkerId = uint32_t; template class ThreadPool { DISALLOW_COPY_AND_ASSIGN(ThreadPool); public: using Item = ItemT; using WorkerId = ThreadPoolWorkerId; using Processor = std::function; private: struct Command { Item mItem; WorkerId mWorkerId; Command(Item&& item, WorkerId workerId) : mItem(std::move(item)), mWorkerId(workerId) {} DISALLOW_COPY_AND_ASSIGN(Command); Command(Command&&) = default; }; using Worker = WorkerThread>; public: // Fn is the type of the processor, it can either have 2 parameters: 1 for the Item, 1 for the // WorkerId, or have only 1 Item parameter. template || std::is_invocable_v>> ThreadPool(int threads, Fn&& processor) : mProcessor() { if constexpr (std::is_invocable_v) { mProcessor = std::move(processor); } else if constexpr (std::is_invocable_v) { using namespace std::placeholders; mProcessor = std::bind(std::move(processor), _1); } if (threads < 1) { threads = android::base::getCpuCoreCount(); } mWorkers = std::vector>(threads); for (auto& workerPtr : mWorkers) { workerPtr.emplace([this](Optional&& commandOpt) { if (!commandOpt) { return Worker::Result::Stop; } Command command = std::move(commandOpt.value()); mProcessor(std::move(command.mItem), command.mWorkerId); return Worker::Result::Continue; }); } } explicit ThreadPool(Processor&& processor) : ThreadPool(0, std::move(processor)) {} ~ThreadPool() { done(); join(); } bool start() { for (auto& workerPtr : mWorkers) { if (workerPtr->start()) { ++mValidWorkersCount; } else { workerPtr.clear(); } } return mValidWorkersCount > 0; } void done() { for (auto& workerPtr : mWorkers) { if (workerPtr) { workerPtr->enqueue(kNullopt); } } } void join() { for (auto& workerPtr : mWorkers) { if (workerPtr) { workerPtr->join(); } } mWorkers.clear(); mValidWorkersCount = 0; } void enqueue(Item&& item) { for (;;) { int currentIndex = mNextWorkerIndex.fetch_add(1, std::memory_order_relaxed); int workerIndex = currentIndex % mWorkers.size(); auto& workerPtr = mWorkers[workerIndex]; if (workerPtr) { Command command(std::forward(item), workerIndex); workerPtr->enqueue(std::move(command)); break; } } } // The itemFactory will be called multiple times to generate one item for each worker thread. template >> void broadcast(Fn&& itemFactory) { int i = 0; for (auto& workerOpt : mWorkers) { if (!workerOpt) continue; Command command(std::move(itemFactory()), i); workerOpt->enqueue(std::move(command)); ++i; } } void waitAllItems() { if (0 == mValidWorkersCount) return; for (auto& workerOpt : mWorkers) { if (!workerOpt) continue; workerOpt->waitQueuedItems(); } } int numWorkers() const { return mValidWorkersCount; } private: Processor mProcessor; std::vector> mWorkers; std::atomic mNextWorkerIndex{0}; int mValidWorkersCount{0}; }; } // namespace base } // namespace android