• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TASK_STREAM_H_
18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TASK_STREAM_H_
19 
20 #include <new>
21 #include <unordered_map>
22 #include <vector>
23 #include <memory>
24 #include "runtime/base.h"
25 #include "utils/log_adapter.h"
26 
27 namespace mindspore {
28 namespace kernel {
29 class TaskStream {
30  public:
31   TaskStream() = default;
32   ~TaskStream() = default;
GetInstance()33   static std::shared_ptr<TaskStream> GetInstance() {
34     static const std::shared_ptr<TaskStream> instance = std::make_shared<TaskStream>();
35     return instance;
36   }
37 
set_gen_stream_list(const std::vector<rtStream_t> & stream_list)38   void set_gen_stream_list(const std::vector<rtStream_t> &stream_list) { gen_stream_list_ = stream_list; }
set_run_stream_list(const std::vector<rtStream_t> & stream_list)39   void set_run_stream_list(const std::vector<rtStream_t> &stream_list) { run_stream_list_ = stream_list; }
SetGenStreamIndex(uint32_t stream_id,uint32_t index)40   void SetGenStreamIndex(uint32_t stream_id, uint32_t index) { gen_stream_index_map_[stream_id] = index; }
GetGenStreamIndexMap()41   std::unordered_map<uint32_t, uint32_t> GetGenStreamIndexMap() { return gen_stream_index_map_; }
GetGenStreamIndex(uint32_t stream_id)42   uint32_t GetGenStreamIndex(uint32_t stream_id) {
43     auto iter = gen_stream_index_map_.find(stream_id);
44     if (iter == gen_stream_index_map_.end()) {
45       MS_LOG(EXCEPTION) << "Parameter stream_id not in gen_stream_index_map_, id: " << stream_id;
46     }
47     return iter->second;
48   }
gen_stream_list()49   const std::vector<rtStream_t> &gen_stream_list() const { return gen_stream_list_; }
run_stream_list()50   const std::vector<rtStream_t> &run_stream_list() const { return run_stream_list_; }
51 
52  private:
53   std::vector<rtStream_t> gen_stream_list_;
54   std::vector<rtStream_t> run_stream_list_;
55   std::unordered_map<uint32_t, uint32_t> gen_stream_index_map_;
56 };
57 }  // namespace kernel
58 }  // namespace mindspore
59 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TASK_STREAM_H_
60