1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_CHECKER_H_ 18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_CHECKER_H_ 19 #include <map> 20 #include <set> 21 #include <vector> 22 #include <string> 23 #include <memory> 24 #include <functional> 25 #include "mindspore/core/ir/anf.h" 26 #include "backend/session/anf_runtime_algorithm.h" 27 #include "backend/optimizer/mem_reuse/mem_reuse.h" 28 #include "backend/kernel_compiler/common_utils.h" 29 namespace mindspore { 30 namespace memreuse { 31 constexpr auto kSplitC = '/'; 32 class MemReuseChecker { 33 public: 34 static MemReuseChecker &GetInstance(); 35 MemReuseChecker(const MemReuseChecker &) = delete; 36 MemReuseChecker &operator=(const MemReuseChecker &) = delete; 37 void CheckSignalOps(const CNodePtr &c_node); 38 void CheckWorkSpace(const std::vector<size_t> &max_list); 39 void CheckOutRef(const KernelRefs &kernel_refs, const CNodePtr &c_node, size_t output_idx); 40 bool CheckGraphOutputAssigned(const session::KernelGraph *graph); 41 void CheckMemReuseIR(const KernelRefCountPtrList &total_refs_list, const KernelDefPtrMaps &kernel_def_ptr_list, 42 const KernelGraph *graph); 43 int64_t CalculOriStatic(const KernelGraph *graph) const; 44 int64_t CalculOriInput(const KernelGraph *graph) const; 45 int64_t CalculOriValue(const KernelGraph *graph) const; 46 int64_t CalculOriDy(const KernelGraph *graph) const; 47 int64_t CalculOriWk(const KernelGraph *graph) const; 48 std::string GetSplitName(const std::string &scope_name) const; 49 int GetTensorIdx(const void *in) const; 50 void SetMembuInfos(const KernelDef *op_def, const std::vector<MembufPtr> &membuf_ptr_list); 51 void SetTesnorFromAndToInfo(const KernelDef *op_def); 52 void ExportMemOpIr(const KernelDef *def, std::ofstream &ofs, int def_idx); 53 void ExportNormalOpIr(const std::vector<CNodePtr> &cnodes); 54 void ExportNormalTensorIR(std::ofstream &ofs); 55 void CheckNormalIR(const session::KernelGraph *graph); 56 void ExportMembufInfoIR(); 57 void ExportEachMembufInfo(std::ofstream &ofs); 58 void SetAddNewMembuInfos(const KernelDef *op_def, const std::vector<MembufPtr> &membuf_ptr_list, size_t op_idx); 59 void ExportAddNewMmebufIR(); set_kernel_front_map(const std::map<KernelDefPtr,std::set<KernelDefPtr>> & kernel_front_map)60 void set_kernel_front_map(const std::map<KernelDefPtr, std::set<KernelDefPtr>> &kernel_front_map) { 61 kernel_front_map_ = kernel_front_map; 62 } 63 void ExportKernelDependence(); 64 65 private: 66 MemReuseChecker() = default; ~MemReuseChecker()67 ~MemReuseChecker() {} 68 bool IsAddNewMembuf_ = false; 69 size_t total_re_wkspe_size_checker_{0}; 70 std::vector<std::vector<MembufPtr>> membuf_all_infos_; 71 std::vector<const void *> nor_output_tensors_; 72 std::vector<size_t> nor_tensor_sizes_; 73 std::vector<const void *> nor_input_tensors_; 74 std::map<const void *, size_t> ptr_idx_; 75 std::map<const void *, size_t> ptr_refs_; 76 std::map<void *, std::vector<const void *>> node_ins_; 77 std::map<void *, std::vector<const void *>> node_ous_; 78 std::vector<std::vector<MembufPtr>> add_new_mem_infos_; 79 std::vector<std::string> add_new_names_; 80 std::vector<size_t> add_new_op_indxs_; 81 std::vector<uint32_t> add_new_stream_ids_; 82 std::vector<std::string> all_split_names_; 83 std::map<int, std::vector<string>> tensor_from_; 84 std::map<int, std::vector<string>> tensor_to_; 85 std::map<KernelDefPtr, std::set<KernelDefPtr>> kernel_front_map_; 86 int64_t total_ori_static_size_ = 0; 87 int64_t total_ori_input_size_ = 0; 88 int64_t total_ori_value_size_ = 0; 89 int64_t total_ori_dy_size_ = 0; 90 int64_t total_ori_wkspace_size_ = 0; 91 }; 92 } // namespace memreuse 93 } // namespace mindspore 94 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_CHECKER_H_ 95