• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_KERNEL_REFCOUNT_H_
18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_KERNEL_REFCOUNT_H_
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
24 
25 namespace mindspore {
26 namespace memreuse {
// TODO: delete with kernel-runtime

// Ref-count policy tag held by KernelRefCount (its constructor defaults it to kStaticRefCount).
enum RefCountType {
  kDynamicRefCount = 0,
  kStaticRefCount = 1,
};

// Node category stored in KernelDef::type_ (defaults to kCommonNode).
enum NodeType {
  kCommonNode = 0,
  kCommunicationNode = 1,
};

// Category stored in KernelRefCount::type_ (defaults to kCommon).
// NOTE(review): per-value semantics are inferred from names only; confirm against the users.
enum KernelRefType {
  kCommon = 0,
  kRefNodeInput = 1,
  kRefNodeOutput = 2,
  kCommNotReuse = 3,
  kCommReuse = 4,
  kSummary = 5,
};

// Initial value of KernelRefCount::index_, i.e. "no index assigned yet"
// (presumably replaced later by SetKernelRefCountInfo; confirm in the .cc).
static constexpr int kInitIndex{-1};
32 class KernelRefCount {
33  public:
34   uint32_t stream_id_;
35   int ref_count_;
36   // used by dynamic memory pool, it will be reset by ref_count_ when one minibatch end
37   int ref_count_dynamic_use_;
38   size_t offset_;
39   size_t size_;
40   int index_;
41   KernelRefType type_;
42   // remember to reset offset
KernelRefCount()43   KernelRefCount()
44       : stream_id_(0),
45         ref_count_(0),
46         ref_count_dynamic_use_(0),
47         offset_(0),
48         size_(0),
49         index_(kInitIndex),
50         type_(kCommon),
51         reftype_(kStaticRefCount) {}
52   ~KernelRefCount() = default;
53   void SetKernelRefCountInfo(int index, size_t size, RefCountType reftype);
set_reftype(RefCountType reftype)54   void set_reftype(RefCountType reftype) { reftype_ = reftype; }
reftype()55   RefCountType reftype() const { return reftype_; }
stream_id()56   uint32_t stream_id() const { return stream_id_; }
57 
58  private:
59   RefCountType reftype_;
60 };
61 using KernelRefCountPtr = std::shared_ptr<KernelRefCount>;
62 using KernelRefCountPtrList = std::vector<KernelRefCountPtr>;
63 // the ptr of every kernel to be key
64 using KernelKey = void *;
65 using KernelMap = std::map<KernelKey, std::vector<KernelRefCountPtr>>;
66 
67 class KernelDef {
68  public:
69   KernelMap inputs_;
70   KernelMap outputs_;
71   KernelMap wk_space_;
72   NodeType type_ = kCommonNode;
73   KernelDef() = default;
74   ~KernelDef() = default;
set_input_refs(const KernelRefCountPtrList & kernelRefPtrList)75   void set_input_refs(const KernelRefCountPtrList &kernelRefPtrList) { input_refs_ = kernelRefPtrList; }
set_output_refs(const KernelRefCountPtrList & kernelRefPtrList)76   void set_output_refs(const KernelRefCountPtrList &kernelRefPtrList) { output_refs_ = kernelRefPtrList; }
input_refs()77   KernelRefCountPtrList input_refs() const { return input_refs_; }
output_refs()78   KernelRefCountPtrList output_refs() const { return output_refs_; }
79   std::vector<int> GetInputRefIndexs() const;
80   std::vector<int> GetOutputRefIndexs() const;
81   std::vector<int> GetWorkspaceRefIndexs() const;
set_stream_id(uint32_t stream_id)82   void set_stream_id(uint32_t stream_id) { stream_id_ = stream_id; }
stream_id()83   uint32_t stream_id() const { return stream_id_; }
set_kernel_name(const std::string & kernel_name)84   void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; }
kernel_name()85   std::string kernel_name() const { return kernel_name_; }
set_scope_full_name(const std::string & scop_name)86   void set_scope_full_name(const std::string &scop_name) { scop_full_name_ = scop_name; }
scope_full_name()87   std::string scope_full_name() const { return scop_full_name_; }
InsertInputKernel(const std::shared_ptr<KernelDef> & input_kernel)88   void InsertInputKernel(const std::shared_ptr<KernelDef> &input_kernel) { input_kernels_.insert(input_kernel); }
input_kernels()89   const std::set<std::shared_ptr<KernelDef>> &input_kernels() const { return input_kernels_; }
90 
91  private:
92   std::string scop_full_name_;
93   std::string kernel_name_;
94   uint32_t stream_id_{0};
95   KernelRefCountPtrList input_refs_;
96   KernelRefCountPtrList output_refs_;
97   std::set<std::shared_ptr<KernelDef>> input_kernels_;
98 };
99 using KernelDefPtr = std::shared_ptr<KernelDef>;
100 }  // namespace memreuse
101 }  // namespace mindspore
102 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_KERNEL_REFCOUNT_H_
103