• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_H_
18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_H_
19 #include <map>
20 #include <memory>
21 #include <unordered_map>
22 #include <vector>
23 #include "backend/optimizer/mem_reuse/kernel_refcount.h"
24 #include "backend/session/anf_runtime_algorithm.h"
25 #include "backend/session/kernel_graph.h"
26 #include "backend/kernel_compiler/tbe/tbe_utils.h"
27 #include "utils/ms_context.h"
28 using mindspore::kernel::tbe::TbeUtils;
29 namespace mindspore {
30 namespace memreuse {
// Namespace-scope tuning constants. `static constexpr` gives each including translation unit its own copy.
// NOTE(review): the meanings below are inferred from the names only - confirm against the users of each constant.
// Presumably a sentinel reference count for buffers that must never be released for reuse.
static constexpr int kMaxRefCount = 9999;
// Presumably the default alignment, in bytes, applied to memory blocks.
static constexpr size_t kDefaultMemAlignSize = 512;
// Presumably extra bytes added when aligning; the exact role of the value 31 is not visible here.
static constexpr size_t kAttAlignSize = 31;
// Presumably marks an index that refers to no valid entry.
static constexpr int kInvalidIndex = -2;
35 
36 using KernelDefPtrMaps = std::vector<mindspore::memreuse::KernelDefPtr>;
37 using KernelRefs = std::map<KernelKey, KernelRefCountPtrList>;
38 
39 using KernelGraph = mindspore::session::KernelGraph;
40 
41 class MemReuseUtil {
42  public:
43   KernelRefCountPtrList total_refs_list_;
44   KernelRefCountPtrList total_wk_ref_list_;
MemReuseUtil()45   MemReuseUtil() : util_index_(kInitIndex), graph_(nullptr), is_all_nop_node_(false) {}
~MemReuseUtil()46   ~MemReuseUtil() {
47     if (graph_ != nullptr) {
48       graph_ = nullptr;
49     }
50     MS_LOG(INFO) << "Total Dynamic Memory Size:  " << total_dy_size_;
51     MS_LOG(INFO) << "Total WorkSpace Memory Size: " << total_workspace_size_;
52     MS_LOG(INFO) << "Total Reused WorkSpace Memory Size: " << total_reuseworkspace_size_;
53   }
54 
55   void SetAllInfo(const KernelGraph *graph);
56   bool InitDynamicOutputKernelRef();
57   bool InitDynamicWorkspaceKernelRef();
58   bool InitDynamicKernelRef(const KernelGraph *graph);
59   void SetWorkSpaceList();
60   void SetKernelDefMap();
61   void SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr);
62   void SetOutputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr);
63   void SetWkMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr);
64   void SetKernelDefInputs();
65   void SetReuseRefCount();
66 #ifndef ENABLE_SECURITY
67   void SetSummaryNodesRefCount();
68 #endif
69   void SetRefNodesInputRefCount();
70   // Set the reference count of graph output specially.
71   void SetGraphOutputRefCount();
72   // Reset the dynamic used reference count by ref_count_.
73   void ResetDynamicUsedRefCount();
74 
75   KernelRefCountPtr GetRef(const AnfNodePtr &node, size_t output_idx);
76   KernelRefCountPtr GetKernelInputRef(const CNodePtr &kernel, size_t input_idx);
total_refs_list()77   KernelRefCountPtrList total_refs_list() const { return total_refs_list_; }
total_wk_ref_list()78   KernelRefCountPtrList total_wk_ref_list() const { return total_wk_ref_list_; }
kernel_def_ptr_list()79   KernelDefPtrMaps kernel_def_ptr_list() const { return kernel_def_ptr_list_; }
max_workspace_size()80   int max_workspace_size() const { return max_workspace_size_; }
max_workspace_list()81   std::vector<size_t> max_workspace_list() const { return max_workspace_list_; }
set_total_refs_list(const KernelRefCountPtrList & total_refs_list)82   void set_total_refs_list(const KernelRefCountPtrList &total_refs_list) { total_refs_list_ = total_refs_list; }
set_kernel_def_ptr_list(const KernelDefPtrMaps & kernel_def_ptr_list)83   void set_kernel_def_ptr_list(const KernelDefPtrMaps &kernel_def_ptr_list) {
84     kernel_def_ptr_list_ = kernel_def_ptr_list;
85   }
set_mem_base(uint8_t * mem_base)86   void set_mem_base(uint8_t *mem_base) { mem_base_ = mem_base; }
87   uint8_t *GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const;
88   uint8_t *GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const;
is_all_nop_node()89   bool is_all_nop_node() const { return is_all_nop_node_; }
90   session::KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &node, size_t i, bool visit_nop_node);
91 
92  private:
93   KernelRefs kernel_output_refs_;
94   KernelRefs kernel_workspace_refs_;
95   int util_index_;
96   const KernelGraph *graph_;
97   bool is_all_nop_node_;
98   KernelRefCountPtrList ref_list_;
99   KernelDefPtrMaps kernel_def_ptr_list_;
100   KernelRefCountPtrList last_ref_list_;
101   int max_workspace_size_ = 0;
102   std::vector<size_t> max_workspace_list_;
103   size_t total_dy_size_ = 0;
104   size_t total_workspace_size_ = 0;
105   size_t total_reuseworkspace_size_ = 0;
106   uint8_t *mem_base_{nullptr};
107   // kernel_map_: key is the AnfNodePtr addr, value is the KernelDef
108   std::map<KernelKey, KernelDefPtr> kernel_map_;
109 
110   bool enable_visit_kernel_cache_{false};
111 
112   std::unordered_map<AnfNodePtr, session::KernelWithIndex> visit_kernel_with_return_type_in0pos_cache_;
113   std::unordered_map<AnfNodePtr, session::KernelWithIndex> visit_kernel_with_return_type_in0pos_skip_nop_cache_;
114 };
115 using MemReuseUtilPtr = std::shared_ptr<MemReuseUtil>;
116 
// Status flag of a memory block. Unscoped, so kUnused (0) and kReused (1) are visible directly in the
// enclosing namespace.
enum Status { kUnused, kReused };
// Category of a memory block, numbered 0..3 in declaration order.
// NOTE(review): the names suggest how the block was obtained (newly allocated, reused within a stream, reused
// between streams, reused through a kernel dependence) - confirm against the allocator code.
enum MemType { kNew, kInStreamReuse, kBetweenStreamReuse, kKernelDependenceReuse };
119 class Membuf {
120  public:
121   Membuf() = default;
Membuf(Status status,size_t size,size_t offset,int index,MemType type,const KernelDefPtr & used_kernel)122   Membuf(Status status, size_t size, size_t offset, int index, MemType type, const KernelDefPtr &used_kernel)
123       : status_(status), size_(size), offset_(offset), index_(index), type_(type), used_kernel_(used_kernel) {}
124   ~Membuf() = default;
125   // Memory block status flags
126   Status status_ = kUnused;
127   size_t size_{0};
128   size_t offset_{0};
129   // Store the tensor index stored in this memory block at a certain moment
130   int index_{0};
131   MemType type_{kNew};
132   KernelDefPtr used_kernel_;
133 };
134 using MembufPtr = std::shared_ptr<Membuf>;
135 
136 }  // namespace memreuse
137 }  // namespace mindspore
138 
139 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_H_
140