/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_IR_KERNEL_INFO_DEV_H_
#define MINDSPORE_CORE_IR_KERNEL_INFO_DEV_H_

#include <memory>
#include <map>
#include <mutex>
#include <utility>
#include <string>
#include "utils/info.h"
#include "utils/os.h"

namespace mindspore {
enum Axis : int {
  N = 0,
  C,
  H,
  W,
};

// Caches runtime information that will not change.
class RuntimeCache {
 public:
  std::pair<AnfNodePtr, size_t> get_prev_node_output(size_t index) {
    auto it = prev_node_output_map_.find(index);
    if (it != prev_node_output_map_.end()) {
      auto node = it->second.first.lock();
      MS_EXCEPTION_IF_NULL(node);
      return std::make_pair(node, it->second.second);
    } else {
      return std::pair<AnfNodePtr, size_t>();
    }
  }

  void set_prev_node_output(size_t index, std::pair<AnfNodePtr, size_t> output) {
    auto pr = std::make_pair(index, output);
    (void)prev_node_output_map_.insert(pr);
  }

  std::map<size_t, std::pair<AnfNodeWeakPtr, size_t>> GetPrevOutputs() const { return prev_node_output_map_; }

  void update_prev_node_output(size_t index, const std::pair<AnfNodePtr, size_t> &output) {
    if (prev_node_output_map_.find(index) == prev_node_output_map_.end()) {
      MS_LOG(DEBUG) << "Index:" << index << " not in prev node map";
      return;
    }
    prev_node_output_map_[index] = output;
  }

  void reset() {
    MS_EXCEPTION_IF_CHECK_FAIL(!is_valid_, "The runtime cache is still valid and cannot be reset!");
    prev_node_output_map_.clear();
    device_target_.clear();
    output_tensor_num_ = -1;
    is_real_kernel_ = Uncached;
  }

  std::string device_target() const { return device_target_; }
  void set_device_target(const std::string &target) { device_target_ = target; }
  bool is_valid() const { return is_valid_; }
  void set_is_valid(bool is_valid) { is_valid_ = is_valid; }
  void set_output_tensor_num(const ssize_t output_tensor_num) { output_tensor_num_ = output_tensor_num; }
  ssize_t output_tensor_num() const { return output_tensor_num_; }
  void set_real_kernel(CacheBool b) { is_real_kernel_ = b; }
  CacheBool is_real_kernel() const { return is_real_kernel_; }

 private:
  bool is_valid_{false};
  std::map<size_t, std::pair<AnfNodeWeakPtr, size_t>> prev_node_output_map_;
  std::string device_target_;
  ssize_t output_tensor_num_ = -1;
  CacheBool is_real_kernel_ = Uncached;
};
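
// Illustrative sketch (not part of this header): a typical read-through use of the
// prev-node-output cache. `cache` is assumed to be a RuntimeCache reference obtained
// through KernelInfoDevice::runtime_cache() below, and ResolvePrevNodeOutput is a
// hypothetical helper that walks the graph on a cache miss.
//
//   std::pair<AnfNodePtr, size_t> prev = cache.get_prev_node_output(input_index);
//   if (prev.first == nullptr) {
//     // Cache miss: get_prev_node_output() returned a default-constructed pair.
//     prev = ResolvePrevNodeOutput(node, input_index);
//     cache.set_prev_node_output(input_index, prev);
//   }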
// Interface for device kernel program information.
class KernelInfoDevice {
 public:
  class RuntimeCacheScope {
   public:
    RuntimeCacheScope(RuntimeCache &base, std::mutex &mu) : runtime_cache_(base), mu_(mu) { mu_.lock(); }
    RuntimeCacheScope(const RuntimeCacheScope &other) = delete;
    RuntimeCacheScope operator=(const RuntimeCacheScope &other) = delete;
    ~RuntimeCacheScope() { mu_.unlock(); }
    RuntimeCache &runtime_cache() { return runtime_cache_; }

   private:
    RuntimeCache &runtime_cache_;
    std::mutex &mu_;
  };

  // Whether the kernel program has been built and the build info has been set.
  virtual bool has_build_info() const = 0;

  RuntimeCacheScope runtime_cache() { return RuntimeCacheScope(runtime_cache_, mu_); }

  virtual ~KernelInfoDevice() {}

 private:
  RuntimeCache runtime_cache_;
  std::mutex mu_;
};
using KernelInfoDevicePtr = std::shared_ptr<KernelInfoDevice>;
}  // namespace mindspore

#endif  // MINDSPORE_CORE_IR_KERNEL_INFO_DEV_H_
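
// Illustrative sketch (not part of this header): RuntimeCacheScope is an RAII guard,
// so the per-node mutex is held exactly for the lifetime of the object returned by
// KernelInfoDevice::runtime_cache(). `kernel_info` is a KernelInfoDevicePtr assumed
// to be held by the caller, and "CPU" is just an example target string.
//
//   {
//     auto scope = kernel_info->runtime_cache();  // constructor locks this node's mutex
//     scope.runtime_cache().set_device_target("CPU");
//     scope.runtime_cache().set_is_valid(true);
//   }  // scope destroyed here; the mutex is released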