• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_STREAM_MANAGER_H_
18 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_STREAM_MANAGER_H_
19 
20 #include <memory>
21 #include <vector>
22 #include <set>
23 #include <mutex>
24 
25 #include "acl/acl_rt.h"
26 #include "plugin/device/ascend/hal/common/ascend_utils.h"
27 #include "utils/hash_map.h"
28 
29 namespace mindspore {
30 namespace device {
31 namespace ascend {
32 class AscendStreamMng {
33  public:
34   static AscendStreamMng &GetInstance();
35 
~AscendStreamMng()36   ~AscendStreamMng() {
37 #ifdef WITH_BACKEND
38     for (auto iter = stream_call_backs_.begin(); iter != stream_call_backs_.end();) {
39       aclrtStream stream = iter->first;
40       iter++;
41       UnRegCallback(stream);
42     }
43 #endif
44   }
45 
ResetResource()46   void ResetResource() {
47     cur_stream_num_ = 0;
48     cur_event_num_ = 0;
49   }
50 
ApplyNewStream()51   uint32_t ApplyNewStream() { return cur_stream_num_++; }
52 
ApplyNewEvent()53   uint32_t ApplyNewEvent() { return cur_event_num_++; }
54 
55   aclrtEvent ApplyRtEvent();
56   aclrtEvent ApplyRtEventWithFlag(uint32_t flag);
57   uint32_t GetRtEventId(const aclrtEvent &event) const;
58   void DestroyAllRtEvents();
59 
60   void DeleteEvent();
61 
62   void DeleteStream();
63 
64   uint32_t GetCurAllocStreamId() const;
65 
cur_stream_num()66   uint32_t cur_stream_num() const { return cur_stream_num_; }
67 
cur_event_num()68   uint32_t cur_event_num() const { return cur_event_num_; }
69 
70   void CreateStream(aclrtStream *stream, int32_t priority = 0);
71   void CreateStream(size_t *stream_id, int32_t priority = 0);
72   void RegCallback(aclrtStream stream);
73   void UnRegCallback(aclrtStream stream);
74   void CreateStreamWithFlags(aclrtStream *stream, uint32_t flags, int32_t priority = 0);
75   void CreateStreamWithFlags(size_t *stream_id, uint32_t flags, int32_t priority = 0);
76   bool DestroyStream(size_t stream_id);
77   bool DestroyAllStreams();
78   aclrtStream GetStream(size_t stream_id) const;
79   bool SyncStream(size_t stream_id) const;
80   bool SyncStream(aclrtStream stream) const;
81   bool SyncAllStreams() const;
82   bool SyncNotDefaultStreams() const;
83   // Sync all streams except the streams in except_streams.
84   bool SyncExceptStreamsInList(const std::set<aclrtStream> &except_streams) const;
85   size_t QueryStreamSize() const;
86   bool QueryStream(size_t stream_id);
87   size_t GetStreamId(void *stream_ptr);
88   std::vector<uint32_t> GetStreamIds() const;
SetBusyStreamNum(uint32_t stream_num)89   void SetBusyStreamNum(uint32_t stream_num) { busy_stream_num_ = stream_num; }
GetBusyStreamNum()90   uint32_t GetBusyStreamNum() const { return busy_stream_num_; }
SetCopyStream(aclrtStream stream)91   void SetCopyStream(aclrtStream stream) { copy_stream_ = stream; }
GetCopyStream()92   aclrtStream GetCopyStream() const { return copy_stream_; }
93 
set_current_stream(size_t stream_id)94   void set_current_stream(size_t stream_id) { current_stream_id_ = stream_id; }
current_stream()95   size_t current_stream() const { return current_stream_id_; }
96 
default_stream_id()97   size_t default_stream_id() const { return default_stream_id_; }
98 
single_op_multi_stream_enable()99   bool single_op_multi_stream_enable() const { return single_op_multi_stream_enable_; }
set_single_op_multi_stream_enable(bool single_op_multi_stream_enable)100   void set_single_op_multi_stream_enable(bool single_op_multi_stream_enable) {
101     single_op_multi_stream_enable_ = single_op_multi_stream_enable;
102   }
103 
enable_callback(bool is_enable_callback)104   void enable_callback(bool is_enable_callback) { is_enable_callback_ = is_enable_callback; }
is_enable_callback()105   bool is_enable_callback() { return is_enable_callback_; }
106 
107  private:
108   // Count streams and events number in task sink scenario
109   uint32_t cur_stream_num_{0};
110   uint32_t cur_event_num_{0};
111 
112   // The max stream num on device ar a time
113   uint32_t busy_stream_num_{0};
114 
115   // Ensure the thread safety for creating and destroying stream.
116   std::mutex stream_mutex_;
117   aclrtStream copy_stream_{nullptr};
118 
119   // all gpu CUDA streams including default_stream_.
120   std::vector<void *> streams_;
121   std::vector<aclrtEvent> events_{};
122 
123   // Currently using stream id.
124   size_t current_stream_id_{0};
125 
126   // Default stream. We consider the first stream created as default stream.
127   void *default_stream_{nullptr};
128   size_t default_stream_id_{0};
129   bool single_op_multi_stream_enable_{false};
130 
131   // Flag of registering callback or not, default value is false.
132   // When multi streams are created, or gmem is enabled, this flag would change to ture.
133   bool is_enable_callback_{false};
134   // This vector used for simplify logic of tracing multi stream creates.
135   std::vector<aclrtStream> callback_cached_streams_;
136   mindspore::HashMap<aclrtStream, CallbackThreadPtr> stream_call_backs_;
137 };
138 }  // namespace ascend
139 }  // namespace device
140 }  // namespace mindspore
141 
142 #endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_STREAM_MANAGER_H_
143