1 /** 2 * Copyright 2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_MULTI_STREAM_CONTROLLER_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_MULTI_STREAM_CONTROLLER_H_ 19 20 #include <atomic> 21 #include <functional> 22 #include <list> 23 #include <memory> 24 #include <mutex> 25 #include <unordered_map> 26 #include <utility> 27 #include <vector> 28 29 #include "utils/log_adapter.h" 30 #include "include/backend/mem_reuse/mem_dynamic_allocator.h" 31 #include "include/backend/visible.h" 32 #include "runtime/hardware/device_context.h" 33 34 namespace mindspore { 35 namespace device { 36 template <typename T> 37 struct AtomicWrapper { AtomicWrapperAtomicWrapper38 AtomicWrapper() : value_(0L) {} AtomicWrapperAtomicWrapper39 explicit AtomicWrapper(const std::atomic<T> &value) : value_(value.load()) {} AtomicWrapperAtomicWrapper40 AtomicWrapper(const AtomicWrapper &other) : value_(other.value_.load()) {} 41 AtomicWrapper &operator=(const AtomicWrapper &other) { value_.store(other.value_.load()); } 42 43 std::atomic<T> value_; 44 }; 45 46 class BACKEND_EXPORT TaskIdOnStreamManager { 47 public: 48 TaskIdOnStreamManager() = default; 49 50 void Resize(uint32_t stream_size); 51 52 int64_t Query(uint32_t user_stream_id, uint32_t memory_stream_id); 53 54 bool Update(int64_t task_id_on_stream, uint32_t user_stream_id, uint32_t memory_stream_id); 55 56 int64_t Launch(uint32_t stream_id); 57 58 int64_t Get(uint32_t stream_id); 59 60 private: 61 std::mutex mutex_; 62 bool initialized_{false}; 63 uint32_t initialize_size_{0}; 64 std::vector<AtomicWrapper<int64_t>> generator_; 65 std::vector<std::vector<int64_t>> status_; 66 }; 67 68 // Event pool recycled with ref count, pool will reuse event when cannot create more events. 69 class BACKEND_EXPORT EventPool { 70 public: 71 explicit EventPool(std::function<DeviceEventPtr(void)> event_creator); 72 ~EventPool(); 73 74 EventPool() = delete; 75 EventPool(const EventPool &) = delete; 76 EventPool &operator=(const EventPool &) = delete; 77 78 // Get event from pool, event was wrapper by shared_ptr. 79 DeviceEventPtr Get(); 80 81 private: 82 std::mutex mutex_; 83 bool expired_{false}; 84 // Pool will just create event before reach core size, use half of size limits as core size. 85 size_t core_size_{32768}; 86 size_t size_{0}; 87 std::function<DeviceEventPtr(void)> event_creator_; 88 std::list<DeviceEvent *> events_; 89 // cached_events_ hold shared ptr of event, since device res manager return a smart pointer. 90 std::list<DeviceEventPtr> cached_events_; 91 }; 92 using EventPoolPtr = std::shared_ptr<EventPool>; 93 94 class MultiStreamController; 95 using MultiStreamControllerPtr = std::shared_ptr<MultiStreamController>; 96 97 class BACKEND_EXPORT MultiStreamController { 98 public: 99 MultiStreamController() = default; 100 MultiStreamController(const MultiStreamController &) = delete; 101 MultiStreamController &operator=(const MultiStreamController &) = delete; 102 ~MultiStreamController() = default; 103 104 static MultiStreamControllerPtr &GetInstance(); 105 106 void Refresh(const DeviceContext *device_context); 107 bool UpdateTaskIdOnStream(const DeviceContext *device_context, int64_t task_id_on_stream, uint32_t user_stream_id, 108 uint32_t memory_stream_id); 109 int64_t QueryTaskIdOnStream(const DeviceContext *device_context, uint32_t user_stream_id, uint32_t memory_stream_id); 110 int64_t LaunchTaskIdOnStream(const DeviceContext *device_context, uint32_t stream_id); 111 int64_t GetTaskIdOnStream(const DeviceContext *device_context, uint32_t stream_id); 112 113 std::mutex &GetStreamMutex(const DeviceContext *device_context, size_t stream_id); 114 115 // memory_stream_addresses pair : memory_stream_id, address. 116 bool RecordEvent(const DeviceContext *device_context, int64_t task_id_on_stream, uint32_t user_stream_id, 117 const std::vector<std::pair<uint32_t, DeviceMemPtr>> &memory_stream_addresses); 118 bool WaitEvent(const DeviceContext *device_context, int64_t task_id_on_stream, uint32_t user_stream_id, 119 uint32_t memory_stream_id); 120 bool WaitEvent(const DeviceContext *device_context, int64_t task_id_on_stream, uint32_t user_stream_id); 121 bool DispatchRecordWaitEvent(const DeviceContext *device_context, uint32_t user_stream_id, uint32_t memory_stream_id); 122 123 bool SyncStream(const DeviceContext *device_context, size_t stream_id); 124 bool SyncAllStreams(const DeviceContext *device_context); 125 bool SyncNotDefaultStreams(const DeviceContext *device_context); 126 127 private: 128 std::unordered_map<const DeviceContext *, TaskIdOnStreamManager> task_id_on_stream_manager_; 129 std::unordered_map<const DeviceContext *, std::unordered_map<uint32_t, std::mutex>> stream_mutexes_; 130 std::unordered_map<const DeviceContext *, EventPoolPtr> event_pools_; 131 }; 132 } // namespace device 133 } // namespace mindspore 134 #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_MULTI_STREAM_CONTROLLER_H_ 135