• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_CHECKER_H_
18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_CHECKER_H_
#include <functional>
#include <iosfwd>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "mindspore/core/ir/anf.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/mem_reuse/mem_reuse.h"
#include "backend/kernel_compiler/common_utils.h"
29 namespace mindspore {
30 namespace memreuse {
// Separator character '/' used for splitting names; presumably consumed by
// MemReuseChecker::GetSplitName when splitting scope names — confirm in the
// implementation file.
constexpr char kSplitC{'/'};
32 class MemReuseChecker {
33  public:
34   static MemReuseChecker &GetInstance();
35   MemReuseChecker(const MemReuseChecker &) = delete;
36   MemReuseChecker &operator=(const MemReuseChecker &) = delete;
37   void CheckSignalOps(const CNodePtr &c_node);
38   void CheckWorkSpace(const std::vector<size_t> &max_list);
39   void CheckOutRef(const KernelRefs &kernel_refs, const CNodePtr &c_node, size_t output_idx);
40   bool CheckGraphOutputAssigned(const session::KernelGraph *graph);
41   void CheckMemReuseIR(const KernelRefCountPtrList &total_refs_list, const KernelDefPtrMaps &kernel_def_ptr_list,
42                        const KernelGraph *graph);
43   int64_t CalculOriStatic(const KernelGraph *graph) const;
44   int64_t CalculOriInput(const KernelGraph *graph) const;
45   int64_t CalculOriValue(const KernelGraph *graph) const;
46   int64_t CalculOriDy(const KernelGraph *graph) const;
47   int64_t CalculOriWk(const KernelGraph *graph) const;
48   std::string GetSplitName(const std::string &scope_name) const;
49   int GetTensorIdx(const void *in) const;
50   void SetMembuInfos(const KernelDef *op_def, const std::vector<MembufPtr> &membuf_ptr_list);
51   void SetTesnorFromAndToInfo(const KernelDef *op_def);
52   void ExportMemOpIr(const KernelDef *def, std::ofstream &ofs, int def_idx);
53   void ExportNormalOpIr(const std::vector<CNodePtr> &cnodes);
54   void ExportNormalTensorIR(std::ofstream &ofs);
55   void CheckNormalIR(const session::KernelGraph *graph);
56   void ExportMembufInfoIR();
57   void ExportEachMembufInfo(std::ofstream &ofs);
58   void SetAddNewMembuInfos(const KernelDef *op_def, const std::vector<MembufPtr> &membuf_ptr_list, size_t op_idx);
59   void ExportAddNewMmebufIR();
set_kernel_front_map(const std::map<KernelDefPtr,std::set<KernelDefPtr>> & kernel_front_map)60   void set_kernel_front_map(const std::map<KernelDefPtr, std::set<KernelDefPtr>> &kernel_front_map) {
61     kernel_front_map_ = kernel_front_map;
62   }
63   void ExportKernelDependence();
64 
65  private:
66   MemReuseChecker() = default;
~MemReuseChecker()67   ~MemReuseChecker() {}
68   bool IsAddNewMembuf_ = false;
69   size_t total_re_wkspe_size_checker_{0};
70   std::vector<std::vector<MembufPtr>> membuf_all_infos_;
71   std::vector<const void *> nor_output_tensors_;
72   std::vector<size_t> nor_tensor_sizes_;
73   std::vector<const void *> nor_input_tensors_;
74   std::map<const void *, size_t> ptr_idx_;
75   std::map<const void *, size_t> ptr_refs_;
76   std::map<void *, std::vector<const void *>> node_ins_;
77   std::map<void *, std::vector<const void *>> node_ous_;
78   std::vector<std::vector<MembufPtr>> add_new_mem_infos_;
79   std::vector<std::string> add_new_names_;
80   std::vector<size_t> add_new_op_indxs_;
81   std::vector<uint32_t> add_new_stream_ids_;
82   std::vector<std::string> all_split_names_;
83   std::map<int, std::vector<string>> tensor_from_;
84   std::map<int, std::vector<string>> tensor_to_;
85   std::map<KernelDefPtr, std::set<KernelDefPtr>> kernel_front_map_;
86   int64_t total_ori_static_size_ = 0;
87   int64_t total_ori_input_size_ = 0;
88   int64_t total_ori_value_size_ = 0;
89   int64_t total_ori_dy_size_ = 0;
90   int64_t total_ori_wkspace_size_ = 0;
91 };
92 }  // namespace memreuse
93 }  // namespace mindspore
94 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_CHECKER_H_
95