/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_KERNEL_REFCOUNT_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_KERNEL_REFCOUNT_H_
#include <cstddef>
#include <cstdint>
#include <vector>
#include <map>
#include <string>
#include <memory>
#include <set>

namespace mindspore {
namespace memreuse {
// todo: delete with kernel-runtime
enum RefCountType { kDynamicRefCount, kStaticRefCount };
enum NodeType { kCommonNode, kCommunicationNode };
enum KernelRefType { kCommon, kRefNodeInput, kRefNodeOutput, kCommNotReuse, kCommReuse, kSummary };

// Value given to KernelRefCount::index_ by the default constructor.
static constexpr int kInitIndex = -1;

// Record of reference-count bookkeeping fields. All data members except
// reftype_ are public and are accessed directly.
class KernelRefCount {
 public:
  uint32_t stream_id_;
  int ref_count_;
  // Used by the dynamic memory pool; it is reset from ref_count_ when one minibatch ends.
  int ref_count_dynamic_use_;
  size_t offset_;
  size_t size_;
  int index_;
  KernelRefType type_;

  // Zero-initialises the counters, offset and size; index_ starts at kInitIndex,
  // type_ at kCommon and the ref-count type at kStaticRefCount.
  // remember to reset offset
  KernelRefCount()
      : stream_id_(0),
        ref_count_(0),
        ref_count_dynamic_use_(0),
        offset_(0),
        size_(0),
        index_(kInitIndex),
        type_(kCommon),
        reftype_(kStaticRefCount) {}
  ~KernelRefCount() = default;

  // Declared here, defined out of line.
  void SetKernelRefCountInfo(int index, size_t size, RefCountType reftype);

  void set_reftype(RefCountType reftype) { reftype_ = reftype; }
  RefCountType reftype() const { return reftype_; }
  uint32_t stream_id() const { return stream_id_; }

 private:
  RefCountType reftype_;
};
using KernelRefCountPtr = std::shared_ptr<KernelRefCount>;
using KernelRefCountPtrList = std::vector<KernelRefCountPtr>;
// The pointer of each kernel is used as the map key.
using KernelKey = void *;
using KernelMap = std::map<KernelKey, std::vector<KernelRefCountPtr>>;

// Per-kernel record: public KernelMap tables for inputs, outputs and workspace,
// plus privately stored names, stream id, ref lists and the set of input kernels.
class KernelDef {
 public:
  KernelMap inputs_;
  KernelMap outputs_;
  KernelMap wk_space_;
  NodeType type_ = kCommonNode;

  KernelDef() = default;
  ~KernelDef() = default;

  // The setters store a copy of the given list; the getters return a copy of the stored list.
  void set_input_refs(const KernelRefCountPtrList &kernelRefPtrList) { input_refs_ = kernelRefPtrList; }
  void set_output_refs(const KernelRefCountPtrList &kernelRefPtrList) { output_refs_ = kernelRefPtrList; }
  KernelRefCountPtrList input_refs() const { return input_refs_; }
  KernelRefCountPtrList output_refs() const { return output_refs_; }

  // Declared here, defined out of line.
  std::vector<int> GetInputRefIndexs() const;
  std::vector<int> GetOutputRefIndexs() const;
  std::vector<int> GetWorkspaceRefIndexs() const;

  void set_stream_id(uint32_t stream_id) { stream_id_ = stream_id; }
  uint32_t stream_id() const { return stream_id_; }
  void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; }
  std::string kernel_name() const { return kernel_name_; }
  void set_scope_full_name(const std::string &scop_name) { scop_full_name_ = scop_name; }
  std::string scope_full_name() const { return scop_full_name_; }

  // Adds input_kernel to the set of input kernels; std::set keeps one entry per distinct pointer.
  void InsertInputKernel(const std::shared_ptr<KernelDef> &input_kernel) { input_kernels_.insert(input_kernel); }
  const std::set<std::shared_ptr<KernelDef>> &input_kernels() const { return input_kernels_; }

 private:
  std::string scop_full_name_;
  std::string kernel_name_;
  uint32_t stream_id_{0};
  KernelRefCountPtrList input_refs_;
  KernelRefCountPtrList output_refs_;
  std::set<std::shared_ptr<KernelDef>> input_kernels_;
};
using KernelDefPtr = std::shared_ptr<KernelDef>;
}  // namespace memreuse
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_KERNEL_REFCOUNT_H_