1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_MP_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_MP_H_ 19 20 #include <functional> 21 #include <map> 22 #include <mutex> 23 #include <thread> 24 #include <vector> 25 26 #ifdef ENABLE_PYTHON 27 #include "pybind11/pybind11.h" 28 #include "pybind11/stl.h" 29 namespace py = pybind11; 30 #endif 31 #include "minddata/dataset/util/status.h" 32 33 namespace mindspore { 34 namespace dataset { 35 class PythonMultiprocessingRuntime { 36 public: 37 virtual void launch(int32_t id) = 0; 38 virtual void terminate() = 0; 39 virtual bool is_mp_enabled() = 0; 40 virtual void add_new_workers(int32_t num_new_workers) = 0; 41 virtual void remove_workers(int32_t num_removed_workers) = 0; 42 virtual std::vector<int32_t> get_pids() = 0; 43 virtual bool is_running() = 0; ~PythonMultiprocessingRuntime()44 virtual ~PythonMultiprocessingRuntime() {} 45 virtual void set_thread_to_worker(int32_t worker_id) = 0; 46 virtual Status get_thread_to_worker(int32_t *const worker_id) const = 0; 47 virtual void reset() = 0; 48 }; 49 50 #ifdef ENABLE_PYTHON 51 class PyPythonMultiprocessingRuntime : public PythonMultiprocessingRuntime { 52 public: 53 // inherit constructors 54 using PythonMultiprocessingRuntime::PythonMultiprocessingRuntime; 55 // Trampoline (need one for each virtual function) 56 // PYBIND11_OVERLOAD_PURE(void, /* Return type */ 57 // PythonMultiprocessingRuntime, /* Parent class */ 58 // launch /* Name of function in C++ (must match Python name) */ 59 launch(int32_t id)60 void launch(int32_t id) override { PYBIND11_OVERLOAD_PURE(void, PythonMultiprocessingRuntime, launch, id); } terminate()61 void terminate() override { PYBIND11_OVERLOAD_PURE(void, PythonMultiprocessingRuntime, terminate); } is_mp_enabled()62 bool is_mp_enabled() override { PYBIND11_OVERLOAD_PURE(bool, PythonMultiprocessingRuntime, is_mp_enabled); } add_new_workers(int32_t num_workers)63 void add_new_workers(int32_t num_workers) override { 64 PYBIND11_OVERLOAD_PURE(void, PythonMultiprocessingRuntime, add_new_workers, num_workers); 65 } remove_workers(int32_t num_workers)66 void remove_workers(int32_t num_workers) override { 67 PYBIND11_OVERLOAD_PURE(void, PythonMultiprocessingRuntime, remove_workers, num_workers); 68 } get_pids()69 std::vector<int32_t> get_pids() override { 70 PYBIND11_OVERLOAD_PURE(std::vector<int32_t>, PythonMultiprocessingRuntime, get_pids); 71 } 72 set_thread_to_worker(int32_t worker_id)73 void set_thread_to_worker(int32_t worker_id) override { 74 std::lock_guard<std::mutex> guard(lock_); 75 threads_to_workers_[std::this_thread::get_id()] = worker_id; 76 } 77 get_thread_to_worker(int32_t * const worker_id)78 Status get_thread_to_worker(int32_t *const worker_id) const override { 79 auto itr = threads_to_workers_.find(std::this_thread::get_id()); 80 CHECK_FAIL_RETURN_UNEXPECTED(itr != threads_to_workers_.end(), "[Internal] This thread is not a worker!"); 81 *worker_id = itr->second; 82 return Status::OK(); 83 } 84 reset()85 void reset() override { threads_to_workers_.clear(); } 86 is_running()87 bool is_running() override { PYBIND11_OVERLOAD_PURE(bool, PythonMultiprocessingRuntime, is_running); } 88 89 private: 90 std::map<std::thread::id, int32_t> threads_to_workers_{}; 91 std::mutex lock_; // used when writing into threads_to_workers_ 92 }; 93 #endif 94 } // namespace dataset 95 } // namespace mindspore 96 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_MP_H_ 97