• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CORE_IR_KERNEL_INFO_DEV_H_
18 #define MINDSPORE_CORE_IR_KERNEL_INFO_DEV_H_
19 
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include "utils/info.h"
#include "utils/os.h"
26 
27 namespace mindspore {
// Named axis indices N, C, H, W with explicit values 0..3.
// NOTE(review): presumably these index an NCHW tensor layout — confirm.
// Kept as an unscoped enum so the enumerators remain usable as plain ints.
enum Axis : int {
  N = 0,
  C = 1,
  H = 2,
  W = 3,
};
34 
35 // Cache some runtime information which not be changed.
36 class RuntimeCache {
37  public:
get_prev_node_output(size_t index)38   std::pair<AnfNodePtr, size_t> get_prev_node_output(size_t index) {
39     auto it = prev_node_output_map_.find(index);
40     if (it != prev_node_output_map_.end()) {
41       MS_EXCEPTION_IF_NULL(it->second.first.lock());
42       return std::make_pair(it->second.first.lock(), it->second.second);
43     } else {
44       return std::pair<AnfNodePtr, size_t>();
45     }
46   }
47 
set_prev_node_output(size_t index,std::pair<AnfNodePtr,size_t> output)48   void set_prev_node_output(size_t index, std::pair<AnfNodePtr, size_t> output) {
49     auto pr = std::make_pair(index, output);
50     (void)prev_node_output_map_.insert(pr);
51   }
52 
GetPrevOutputs()53   std::map<size_t, std::pair<AnfNodeWeakPtr, size_t>> GetPrevOutputs() const { return prev_node_output_map_; }
54 
update_prev_node_output(size_t index,const std::pair<AnfNodePtr,size_t> & output)55   void update_prev_node_output(size_t index, const std::pair<AnfNodePtr, size_t> &output) {
56     if (prev_node_output_map_.find(index) == prev_node_output_map_.end()) {
57       MS_LOG(DEBUG) << "Index:" << index << " not in prev node map";
58       return;
59     }
60     prev_node_output_map_[index] = output;
61   }
62 
reset()63   void reset() {
64     MS_EXCEPTION_IF_CHECK_FAIL(!is_valid_, "this runtime cache is valid, can't reset!!!!");
65     prev_node_output_map_.clear();
66     device_target_.clear();
67     output_tensor_num_ = -1;
68     is_real_kernel_ = Uncached;
69   }
70 
device_target()71   std::string device_target() const { return device_target_; }
72 
set_device_target(const std::string & target)73   void set_device_target(const std::string &target) { device_target_ = target; }
is_valid()74   bool is_valid() const { return is_valid_; }
set_is_valid(bool is_vaild)75   void set_is_valid(bool is_vaild) { is_valid_ = is_vaild; }
set_output_tensor_num(const ssize_t output_tensor_num)76   void set_output_tensor_num(const ssize_t output_tensor_num) { output_tensor_num_ = output_tensor_num; }
output_tensor_num()77   ssize_t output_tensor_num() const { return output_tensor_num_; }
set_real_kernel(CacheBool b)78   void set_real_kernel(CacheBool b) { is_real_kernel_ = b; }
is_real_kernel()79   CacheBool is_real_kernel() const { return is_real_kernel_; }
80 
81  private:
82   bool is_valid_{false};
83   std::map<size_t, std::pair<AnfNodeWeakPtr, size_t>> prev_node_output_map_;
84   std::string device_target_;
85   ssize_t output_tensor_num_ = -1;
86   CacheBool is_real_kernel_ = Uncached;
87 };
88 // Interface for device kernel program information.
89 class KernelInfoDevice {
90  public:
91   class RuntimeCacheScope {
92    public:
RuntimeCacheScope(RuntimeCache & base,std::mutex & mu)93     RuntimeCacheScope(RuntimeCache &base, std::mutex &mu) : runtime_cache_(base), mu_(mu) { mu_.lock(); }
94     RuntimeCacheScope(const RuntimeCacheScope &other) = delete;
95     RuntimeCacheScope operator=(const RuntimeCacheScope &other) = delete;
~RuntimeCacheScope()96     ~RuntimeCacheScope() { mu_.unlock(); }
runtime_cache()97     RuntimeCache &runtime_cache() { return runtime_cache_; }
98 
99    private:
100     RuntimeCache &runtime_cache_;
101     std::mutex &mu_;
102   };
103   // If kernel program was built and build info is set.
104   virtual bool has_build_info() const = 0;
105 
runtime_cache()106   RuntimeCacheScope runtime_cache() { return RuntimeCacheScope(runtime_cache_, mu_); }
107 
~KernelInfoDevice()108   virtual ~KernelInfoDevice() {}
109 
110  private:
111   RuntimeCache runtime_cache_;
112   std::mutex mu_;
113 };
114 using KernelInfoDevicePtr = std::shared_ptr<KernelInfoDevice>;
115 }  // namespace mindspore
116 
117 #endif  // MINDSPORE_CORE_IR_KERNEL_INFO_DEV_H_
118