• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_KERNEL_REFCOUNT_H_
18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_KERNEL_REFCOUNT_H_
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
24 
25 namespace mindspore {
26 namespace memreuse {
// Reference-count flavour recorded in KernelRefCount::reftype_ (initially kStaticRefCount).
enum RefCountType { kDynamicRefCount, kStaticRefCount };
// Node category recorded in KernelDef::type_ (initially kCommonNode).
enum NodeType { kCommonNode, kCommunicationNode };
// Classification recorded in KernelRefCount::type_ (initially kCommon).
enum KernelRefType { kCommon, kRefNodeInput, kRefNodeOutput, kCommNotReuse, kCommReuse, kSummary };
// Value of KernelRefCount::index_ on a freshly constructed object.
static constexpr int kInitIndex = -1;
31 class KernelRefCount {
32  public:
33   uint32_t stream_id_;
34   int ref_count_;
35   // used by dynamic memory pool, it will be reseted by ref_count_ when one minibatch end
36   int ref_count_dynamic_use_;
37   size_t offset_;
38   size_t size_;
39   int index_;
40   KernelRefType type_;
41   // remember to reset offset
KernelRefCount()42   KernelRefCount()
43       : stream_id_(0),
44         ref_count_(0),
45         ref_count_dynamic_use_(0),
46         offset_(0),
47         size_(0),
48         index_(kInitIndex),
49         type_(kCommon),
50         reftype_(kStaticRefCount) {}
51   ~KernelRefCount() = default;
52   void SetKernelRefCountInfo(int index, size_t size, RefCountType reftype);
set_reftype(RefCountType reftype)53   void set_reftype(RefCountType reftype) { reftype_ = reftype; }
reftype()54   RefCountType reftype() const { return reftype_; }
stream_id()55   uint32_t stream_id() const { return stream_id_; }
56 
57  private:
58   RefCountType reftype_;
59 };
60 using KernelRefCountPtr = std::shared_ptr<KernelRefCount>;
61 using KernelRefCountPtrList = std::vector<KernelRefCountPtr>;
62 // the ptr of every kernel to be key
63 using KernelKey = void *;
64 using KernelMap = std::map<KernelKey, std::vector<KernelRefCountPtr>>;
65 
66 class KernelDef {
67  public:
68   KernelMap inputs_;
69   KernelMap outputs_;
70   KernelMap wk_space_;
71   NodeType type_ = kCommonNode;
72   KernelDef() = default;
73   ~KernelDef() = default;
set_input_refs(const KernelRefCountPtrList & kernelRefPtrList)74   void set_input_refs(const KernelRefCountPtrList &kernelRefPtrList) { input_refs_ = kernelRefPtrList; }
set_output_refs(const KernelRefCountPtrList & kernelRefPtrList)75   void set_output_refs(const KernelRefCountPtrList &kernelRefPtrList) { output_refs_ = kernelRefPtrList; }
input_refs()76   KernelRefCountPtrList input_refs() const { return input_refs_; }
output_refs()77   KernelRefCountPtrList output_refs() const { return output_refs_; }
78   std::vector<int> GetInputRefIndexs() const;
79   std::vector<int> GetOutputRefIndexs() const;
80   std::vector<int> GetWorkspaceRefIndexs() const;
set_stream_id(uint32_t stream_id)81   void set_stream_id(uint32_t stream_id) { stream_id_ = stream_id; }
stream_id()82   uint32_t stream_id() const { return stream_id_; }
set_kernel_name(const std::string & kernel_name)83   void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; }
kernel_name()84   std::string kernel_name() const { return kernel_name_; }
set_scope_full_name(const std::string & scop_name)85   void set_scope_full_name(const std::string &scop_name) { scop_full_name_ = scop_name; }
scope_full_name()86   std::string scope_full_name() const { return scop_full_name_; }
InsertInputKernel(const std::shared_ptr<KernelDef> & input_kernel)87   void InsertInputKernel(const std::shared_ptr<KernelDef> &input_kernel) { input_kernels_.insert(input_kernel); }
input_kernels()88   const std::set<std::shared_ptr<KernelDef>> &input_kernels() { return input_kernels_; }
89 
90  private:
91   std::string scop_full_name_;
92   std::string kernel_name_;
93   uint32_t stream_id_{0};
94   KernelRefCountPtrList input_refs_;
95   KernelRefCountPtrList output_refs_;
96   std::set<std::shared_ptr<KernelDef>> input_kernels_;
97 };
98 using KernelDefPtr = std::shared_ptr<KernelDef>;
99 }  // namespace memreuse
100 }  // namespace mindspore
101 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_KERNEL_REFCOUNT_H_
102