1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TASK_EXECUTOR_H_ 18 #define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TASK_EXECUTOR_H_ 19 20 #include <functional> 21 #include <queue> 22 #include <mutex> 23 #include <vector> 24 #include <thread> 25 #include <condition_variable> 26 27 #include "utils/log_adapter.h" 28 #include "ps/constants.h" 29 30 namespace mindspore { 31 namespace ps { 32 namespace core { 33 /* This class can submit tasks in multiple threads 34 * example: 35 * void TestTaskExecutor() { 36 * std::cout << "Execute in one thread"; 37 * } 38 * 39 * TaskExecutor executor(10); // 10 threads 40 * executor.Submit(TestTaskExecutor, this); // Submit task 41 */ 42 class TaskExecutor { 43 public: 44 explicit TaskExecutor(size_t thread_num, size_t max_task_num = kMaxTaskNum, 45 size_t submit_timeout = kSubmitTimeOutInMs); 46 ~TaskExecutor(); 47 48 // If the number of submitted tasks is greater than the size of the queue, it will block the submission of subsequent 49 // tasks unitl timeout. 50 template <typename Fun, typename... Args> Submit(Fun && function,Args &&...args)51 bool Submit(Fun &&function, Args &&... args) { 52 auto callee = std::bind(function, args...); 53 std::function<void()> task = [callee]() -> void { callee(); }; 54 size_t index = 0; 55 for (size_t i = 0; i < submit_timeout_; i++) { 56 std::unique_lock<std::mutex> lock(mtx_); 57 if (task_num_ >= max_task_num_) { 58 lock.unlock(); 59 std::this_thread::sleep_for(std::chrono::milliseconds(kSubmitTaskIntervalInMs)); 60 index++; 61 } else { 62 break; 63 } 64 } 65 if (index >= submit_timeout_) { 66 MS_LOG(WARNING) << "Submit task failed after " << submit_timeout_ << " ms."; 67 return false; 68 } 69 std::unique_lock<std::mutex> lock(mtx_); 70 task_num_++; 71 task_queue_.push(task); 72 return true; 73 } 74 75 private: 76 bool running_; 77 78 // The number of tasks actually running 79 size_t thread_num_; 80 // The number of idle threads that can execute tasks 81 size_t idle_thread_num_; 82 83 // The timeout period of the task submission, in milliseconds. default timeout is 3000 milliseconds. 84 size_t submit_timeout_; 85 86 // The maximum number of tasks that can be submitted to the task queue, If the number of submitted tasks exceeds this 87 // max_task_num_, the Submit function will block.Until the current number of tasks is less than max task num,or 88 // timeout. 89 size_t max_task_num_; 90 // The number of currently submitted to the task queue 91 size_t task_num_; 92 93 std::thread notify_thread_; 94 std::mutex mtx_; 95 std::condition_variable cv_; 96 97 std::vector<std::thread> working_threads_; 98 std::queue<std::function<void()>> task_queue_; 99 }; 100 } // namespace core 101 } // namespace ps 102 } // namespace mindspore 103 #endif // MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TASK_EXECUTOR_H_ 104