/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_KERNEL_REFCOUNT_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_KERNEL_REFCOUNT_H_
// <cstddef>/<cstdint> provide size_t/uint32_t used below; previously they were only
// reachable through transitive includes.
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

namespace mindspore {
namespace memreuse {
// How a KernelRefCount's count is managed; KernelRefCount defaults to kStaticRefCount.
enum RefCountType { kDynamicRefCount, kStaticRefCount };
// Category of a KernelDef node; KernelDef::type_ defaults to kCommonNode.
enum NodeType { kCommonNode, kCommunicationNode };
// Category of a KernelRefCount; KernelRefCount::type_ defaults to kCommon.
enum KernelRefType { kCommon, kRefNodeInput, kRefNodeOutput, kCommNotReuse, kCommReuse, kSummary };
// Value of KernelRefCount::index_ before an index has been assigned.
static constexpr int kInitIndex = -1;

/**
 * Reference-count record for one memory block tracked by the memory-reuse pass.
 *
 * Every numeric field starts at 0, index_ starts at kInitIndex, type_ at kCommon and
 * the reference-count type at kStaticRefCount.
 */
class KernelRefCount {
 public:
  // Stream id of the block (see stream_id()).
  uint32_t stream_id_;
  // Reference count of the block.
  int ref_count_;
  // Used by the dynamic memory pool; it is reset from ref_count_ when one mini-batch ends.
  int ref_count_dynamic_use_;
  // NOTE(review): presumably the block's offset inside the reuse pool — confirm against the allocator.
  size_t offset_;
  // Size of the block (presumably in bytes — TODO confirm).
  size_t size_;
  // Index of the block; kInitIndex until assigned.
  int index_;
  // Category of the block.
  KernelRefType type_;
  // Remember to reset offset_ when this record is reused.
  KernelRefCount()
      : stream_id_(0),
        ref_count_(0),
        ref_count_dynamic_use_(0),
        offset_(0),
        size_(0),
        index_(kInitIndex),
        type_(kCommon),
        reftype_(kStaticRefCount) {}
  ~KernelRefCount() = default;
  // Defined out of line; presumably records index, size and reftype — confirm in the .cc file.
  void SetKernelRefCountInfo(int index, size_t size, RefCountType reftype);
  // Sets how this record's count is managed.
  void set_reftype(RefCountType reftype) { reftype_ = reftype; }
  // Returns how this record's count is managed.
  RefCountType reftype() const { return reftype_; }
  // Returns stream_id_.
  uint32_t stream_id() const { return stream_id_; }

 private:
  RefCountType reftype_;
};
using KernelRefCountPtr = std::shared_ptr<KernelRefCount>;
using KernelRefCountPtrList = std::vector<KernelRefCountPtr>;
// The address of each kernel is used as the map key.
using KernelKey = void *;
using KernelMap = std::map<KernelKey, std::vector<KernelRefCountPtr>>;

/**
 * Description of one kernel for the memory-reuse pass: its input/output/workspace
 * reference records, stream id, names, and the set of kernels feeding it.
 */
class KernelDef {
 public:
  // Input reference records, keyed by kernel address.
  KernelMap inputs_;
  // Output reference records, keyed by kernel address.
  KernelMap outputs_;
  // Workspace reference records, keyed by kernel address.
  KernelMap wk_space_;
  // Node category; kCommonNode unless changed.
  NodeType type_ = kCommonNode;
  KernelDef() = default;
  ~KernelDef() = default;
  // Replaces the list of input reference records (copied).
  void set_input_refs(const KernelRefCountPtrList &kernelRefPtrList) { input_refs_ = kernelRefPtrList; }
  // Replaces the list of output reference records (copied).
  void set_output_refs(const KernelRefCountPtrList &kernelRefPtrList) { output_refs_ = kernelRefPtrList; }
  // Returns a copy of the input reference records.
  KernelRefCountPtrList input_refs() const { return input_refs_; }
  // Returns a copy of the output reference records.
  KernelRefCountPtrList output_refs() const { return output_refs_; }
  // Defined out of line; presumably collect the index_ values of the input/output/workspace records.
  std::vector<int> GetInputRefIndexs() const;
  std::vector<int> GetOutputRefIndexs() const;
  std::vector<int> GetWorkspaceRefIndexs() const;
  // Stream id of the kernel; 0 unless set.
  void set_stream_id(uint32_t stream_id) { stream_id_ = stream_id; }
  uint32_t stream_id() const { return stream_id_; }
  // Kernel name.
  void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; }
  std::string kernel_name() const { return kernel_name_; }
  // Full scope name of the kernel.
  void set_scope_full_name(const std::string &scope_name) { scop_full_name_ = scope_name; }
  std::string scope_full_name() const { return scop_full_name_; }
  // Records a kernel feeding this one; inserting the same pointer twice stores it once (std::set).
  void InsertInputKernel(const std::shared_ptr<KernelDef> &input_kernel) { input_kernels_.insert(input_kernel); }
  // Kernels recorded via InsertInputKernel. Const-qualified so it is usable on const KernelDef objects.
  const std::set<std::shared_ptr<KernelDef>> &input_kernels() const { return input_kernels_; }

 private:
  std::string scop_full_name_;
  std::string kernel_name_;
  uint32_t stream_id_{0};
  KernelRefCountPtrList input_refs_;
  KernelRefCountPtrList output_refs_;
  std::set<std::shared_ptr<KernelDef>> input_kernels_;
};
using KernelDefPtr = std::shared_ptr<KernelDef>;
}  // namespace memreuse
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_KERNEL_REFCOUNT_H_