• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_MP_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_MP_H_
19 
20 #include <functional>
21 #include <map>
22 #include <mutex>
23 #include <thread>
24 #include <vector>
25 
26 #ifdef ENABLE_PYTHON
27 #include "pybind11/pybind11.h"
28 #include "pybind11/stl.h"
29 namespace py = pybind11;
30 #endif
31 #include "minddata/dataset/util/status.h"
32 
33 namespace mindspore {
34 namespace dataset {
35 class PythonMultiprocessingRuntime {
36  public:
37   virtual void launch(int32_t id) = 0;
38   virtual void terminate() = 0;
39   virtual bool is_mp_enabled() = 0;
40   virtual void add_new_workers(int32_t num_new_workers) = 0;
41   virtual void remove_workers(int32_t num_removed_workers) = 0;
42   virtual std::vector<int32_t> get_pids() = 0;
43   virtual bool is_running() = 0;
~PythonMultiprocessingRuntime()44   virtual ~PythonMultiprocessingRuntime() {}
45   virtual void set_thread_to_worker(int32_t worker_id) = 0;
46   virtual Status get_thread_to_worker(int32_t *const worker_id) const = 0;
47   virtual void reset() = 0;
48 };
49 
50 #ifdef ENABLE_PYTHON
51 class PyPythonMultiprocessingRuntime : public PythonMultiprocessingRuntime {
52  public:
53   // inherit constructors
54   using PythonMultiprocessingRuntime::PythonMultiprocessingRuntime;
55   //  Trampoline (need one for each virtual function)
56   //  PYBIND11_OVERLOAD_PURE(void,                         /* Return type */
57   //                        PythonMultiprocessingRuntime,  /* Parent class */
58   //                        launch                         /* Name of function in C++ (must match Python name) */
59 
launch(int32_t id)60   void launch(int32_t id) override { PYBIND11_OVERLOAD_PURE(void, PythonMultiprocessingRuntime, launch, id); }
terminate()61   void terminate() override { PYBIND11_OVERLOAD_PURE(void, PythonMultiprocessingRuntime, terminate); }
is_mp_enabled()62   bool is_mp_enabled() override { PYBIND11_OVERLOAD_PURE(bool, PythonMultiprocessingRuntime, is_mp_enabled); }
add_new_workers(int32_t num_workers)63   void add_new_workers(int32_t num_workers) override {
64     PYBIND11_OVERLOAD_PURE(void, PythonMultiprocessingRuntime, add_new_workers, num_workers);
65   }
remove_workers(int32_t num_workers)66   void remove_workers(int32_t num_workers) override {
67     PYBIND11_OVERLOAD_PURE(void, PythonMultiprocessingRuntime, remove_workers, num_workers);
68   }
get_pids()69   std::vector<int32_t> get_pids() override {
70     PYBIND11_OVERLOAD_PURE(std::vector<int32_t>, PythonMultiprocessingRuntime, get_pids);
71   }
72 
set_thread_to_worker(int32_t worker_id)73   void set_thread_to_worker(int32_t worker_id) override {
74     std::lock_guard<std::mutex> guard(lock_);
75     threads_to_workers_[std::this_thread::get_id()] = worker_id;
76   }
77 
get_thread_to_worker(int32_t * const worker_id)78   Status get_thread_to_worker(int32_t *const worker_id) const override {
79     auto itr = threads_to_workers_.find(std::this_thread::get_id());
80     CHECK_FAIL_RETURN_UNEXPECTED(itr != threads_to_workers_.end(), "[Internal] This thread is not a worker!");
81     *worker_id = itr->second;
82     return Status::OK();
83   }
84 
reset()85   void reset() override { threads_to_workers_.clear(); }
86 
is_running()87   bool is_running() override { PYBIND11_OVERLOAD_PURE(bool, PythonMultiprocessingRuntime, is_running); }
88 
89  private:
90   std::map<std::thread::id, int32_t> threads_to_workers_{};
91   std::mutex lock_;  // used when writing into threads_to_workers_
92 };
93 #endif
94 }  // namespace dataset
95 }  // namespace mindspore
96 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_MP_H_
97