1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_H_ 18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_H_ 19 #include <map> 20 #include <memory> 21 #include <unordered_map> 22 #include <vector> 23 #include "backend/optimizer/mem_reuse/kernel_refcount.h" 24 #include "backend/session/anf_runtime_algorithm.h" 25 #include "backend/session/kernel_graph.h" 26 #include "backend/kernel_compiler/tbe/tbe_utils.h" 27 #include "utils/ms_context.h" 28 using mindspore::kernel::tbe::TbeUtils; 29 namespace mindspore { 30 namespace memreuse { 31 static constexpr int kMaxRefCount = 9999; 32 static constexpr size_t kDefaultMemAlignSize = 512; 33 static constexpr size_t kAttAlignSize = 31; 34 static constexpr int kInvalidIndex = -2; 35 36 using KernelDefPtrMaps = std::vector<mindspore::memreuse::KernelDefPtr>; 37 using KernelRefs = std::map<KernelKey, KernelRefCountPtrList>; 38 39 using KernelGraph = mindspore::session::KernelGraph; 40 41 class MemReuseUtil { 42 public: 43 KernelRefCountPtrList total_refs_list_; 44 KernelRefCountPtrList total_wk_ref_list_; MemReuseUtil()45 MemReuseUtil() : util_index_(kInitIndex), graph_(nullptr), is_all_nop_node_(false) {} ~MemReuseUtil()46 ~MemReuseUtil() { 47 if (graph_ != nullptr) { 48 graph_ = nullptr; 49 } 50 MS_LOG(INFO) << "Total Dynamic Memory Size: " << total_dy_size_; 51 MS_LOG(INFO) << "Total WorkSpace Memory Size: " << total_workspace_size_; 52 MS_LOG(INFO) << "Total Reused WorkSpace Memory Size: " << total_reuseworkspace_size_; 53 } 54 55 void SetAllInfo(const KernelGraph *graph); 56 bool InitDynamicOutputKernelRef(); 57 bool InitDynamicWorkspaceKernelRef(); 58 bool InitDynamicKernelRef(const KernelGraph *graph); 59 void SetWorkSpaceList(); 60 void SetKernelDefMap(); 61 void SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr); 62 void SetOutputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr); 63 void SetWkMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr); 64 void SetKernelDefInputs(); 65 void SetReuseRefCount(); 66 #ifndef ENABLE_SECURITY 67 void SetSummaryNodesRefCount(); 68 #endif 69 void SetRefNodesInputRefCount(); 70 // Set the reference count of graph output specially. 71 void SetGraphOutputRefCount(); 72 // Reset the dynamic used reference count by ref_count_. 73 void ResetDynamicUsedRefCount(); 74 75 KernelRefCountPtr GetRef(const AnfNodePtr &node, size_t output_idx); 76 KernelRefCountPtr GetKernelInputRef(const CNodePtr &kernel, size_t input_idx); total_refs_list()77 KernelRefCountPtrList total_refs_list() const { return total_refs_list_; } total_wk_ref_list()78 KernelRefCountPtrList total_wk_ref_list() const { return total_wk_ref_list_; } kernel_def_ptr_list()79 KernelDefPtrMaps kernel_def_ptr_list() const { return kernel_def_ptr_list_; } max_workspace_size()80 int max_workspace_size() const { return max_workspace_size_; } max_workspace_list()81 std::vector<size_t> max_workspace_list() const { return max_workspace_list_; } set_total_refs_list(const KernelRefCountPtrList & total_refs_list)82 void set_total_refs_list(const KernelRefCountPtrList &total_refs_list) { total_refs_list_ = total_refs_list; } set_kernel_def_ptr_list(const KernelDefPtrMaps & kernel_def_ptr_list)83 void set_kernel_def_ptr_list(const KernelDefPtrMaps &kernel_def_ptr_list) { 84 kernel_def_ptr_list_ = kernel_def_ptr_list; 85 } set_mem_base(uint8_t * mem_base)86 void set_mem_base(uint8_t *mem_base) { mem_base_ = mem_base; } 87 uint8_t *GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const; 88 uint8_t *GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const; is_all_nop_node()89 bool is_all_nop_node() const { return is_all_nop_node_; } 90 session::KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &node, size_t i, bool visit_nop_node); 91 92 private: 93 KernelRefs kernel_output_refs_; 94 KernelRefs kernel_workspace_refs_; 95 int util_index_; 96 const KernelGraph *graph_; 97 bool is_all_nop_node_; 98 KernelRefCountPtrList ref_list_; 99 KernelDefPtrMaps kernel_def_ptr_list_; 100 KernelRefCountPtrList last_ref_list_; 101 int max_workspace_size_ = 0; 102 std::vector<size_t> max_workspace_list_; 103 size_t total_dy_size_ = 0; 104 size_t total_workspace_size_ = 0; 105 size_t total_reuseworkspace_size_ = 0; 106 uint8_t *mem_base_{nullptr}; 107 // kernel_map_: key is the AnfNodePtr addr, value is the KernelDef 108 std::map<KernelKey, KernelDefPtr> kernel_map_; 109 110 bool enable_visit_kernel_cache_{false}; 111 112 std::unordered_map<AnfNodePtr, session::KernelWithIndex> visit_kernel_with_return_type_in0pos_cache_; 113 std::unordered_map<AnfNodePtr, session::KernelWithIndex> visit_kernel_with_return_type_in0pos_skip_nop_cache_; 114 }; 115 using MemReuseUtilPtr = std::shared_ptr<MemReuseUtil>; 116 117 enum Status { kUnused, kReused }; 118 enum MemType { kNew, kInStreamReuse, kBetweenStreamReuse, kKernelDependenceReuse }; 119 class Membuf { 120 public: 121 Membuf() = default; Membuf(Status status,size_t size,size_t offset,int index,MemType type,const KernelDefPtr & used_kernel)122 Membuf(Status status, size_t size, size_t offset, int index, MemType type, const KernelDefPtr &used_kernel) 123 : status_(status), size_(size), offset_(offset), index_(index), type_(type), used_kernel_(used_kernel) {} 124 ~Membuf() = default; 125 // Memory block status flags 126 Status status_ = kUnused; 127 size_t size_{0}; 128 size_t offset_{0}; 129 // Store the tensor index stored in this memory block at a certain moment 130 int index_{0}; 131 MemType type_{kNew}; 132 KernelDefPtr used_kernel_; 133 }; 134 using MembufPtr = std::shared_ptr<Membuf>; 135 136 } // namespace memreuse 137 } // namespace mindspore 138 139 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_H_ 140