• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GRAPH_KERNEL_BUILD_H_
18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GRAPH_KERNEL_BUILD_H_
19 
#include <unistd.h>

#include <cstdint>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "nlohmann/json.hpp"
#include "ir/anf.h"
#include "kernel/kernel.h"
#include "kernel/kash/kernel_pack.h"
#include "backend/common/session/kernel_build_client.h"
#include "kernel/graph_kernel/graph_kernel_json_generator.h"
31 
32 namespace mindspore {
33 namespace kernel {
34 using graphkernel::GraphKernelJsonGenerator;
35 using JsonNodePair = std::pair<GraphKernelJsonGenerator, AnfNodePtr>;
36 
37 class BACKEND_EXPORT GraphKernelBuilder {
38  public:
39   GraphKernelBuilder() = default;
40   virtual ~GraphKernelBuilder() = default;
41 
42   virtual KernelPackPtr SearchKernelCache(const std::string &kernel_name);
43   virtual KernelPackPtr InsertKernelCache(const std::string &kernel_name);
44   virtual void LoadCache();
45 
46   virtual KernelBuildClient *GetClient() = 0;
47   virtual void SetKernelMod(const KernelPackPtr &kernel_pack, const GraphKernelJsonGenerator &json_generator,
48                             const AnfNodePtr &anf_node) = 0;
49   virtual void SaveJsonInfo(const string &kernel_name, const string &kernel_json) = 0;
50   virtual bool SingleOpParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) = 0;
51   virtual bool ParallelBuild(const std::vector<JsonNodePair> &build_args) = 0;
52 
53  protected:
54   std::vector<std::string> GetKernelJsonsByHashId(const std::vector<JsonNodePair> &build_args,
55                                                   const std::set<size_t> &fetched_ids);
56   std::vector<JsonNodePair> GetNotCachedKernels(const std::vector<JsonNodePair> &build_args);
57 
58   bool InsertToCache(const std::vector<JsonNodePair> &build_args);
59   bool HandleRepeatNodes();
60 
61   std::vector<JsonNodePair> repeat_nodes_;
62   nlohmann::json build_attrs_;
63   std::string CollectBuildAttrs();
64 };
65 
66 class KernelPool {
67  public:
68   class LockMng {
69    public:
LockMng(const int32_t fd,const char * function,const uint32_t line)70     explicit LockMng(const int32_t fd, const char *function, const uint32_t line) {
71       fd_ = fd;
72       calling_position_ = std::string(function) + ":" + std::to_string(line);
73       locked_ = TryLock();
74     }
75 
~LockMng()76     virtual ~LockMng() {
77       if (locked_) {
78         Unlock();
79       }
80     }
81 
82     bool locked_{false};
83 
84    private:
85     bool TryLock() const;
86     void Unlock() const noexcept;
87 
88     int32_t fd_{-1};
89     std::string calling_position_;
90   };
91 
92   KernelPool() = default;
~KernelPool()93   virtual ~KernelPool() {
94     // Close key file
95     if (fd_ != -1) {
96       (void)close(fd_);
97     }
98   }
99 
100   int32_t Init(const std::vector<JsonNodePair> &build_args);
101   int32_t Release() const;
102   int32_t FetchKernels(std::set<size_t> *out);
103   int32_t UpdateAndWait(const std::set<size_t> &ids);
104 
105   constexpr inline static size_t kMaxKernelNum_{1000};
106 
107   // allocate memory for todo_list, doing_list, done_list
108   constexpr inline static size_t kListNum_{3};
109 
110   constexpr inline static auto kKeyName_ = "./kernel_build_tmp.key";
111 
112   constexpr inline static int32_t kToDoIdx_ = 0;
113   constexpr inline static int32_t kDoingIdx_ = 1;
114   constexpr inline static int32_t kDoneIdx_ = 2;
115 
116  private:
ListBegin(int32_t list_idx)117   inline size_t *ListBegin(int32_t list_idx) { return kernel_lists_[list_idx]; }
ListBegin(int32_t list_idx)118   inline const size_t *ListBegin(int32_t list_idx) const { return kernel_lists_[list_idx]; }
119 
ListEnd(int32_t list_idx)120   inline size_t *ListEnd(int32_t list_idx) { return kernel_lists_[list_idx] + kernel_lists_[list_idx][kMaxKernelNum_]; }
ListEnd(int32_t list_idx)121   inline const size_t *ListEnd(int32_t list_idx) const {
122     return kernel_lists_[list_idx] + kernel_lists_[list_idx][kMaxKernelNum_];
123   }
124 
ResetListSize(int32_t list_idx,size_t val)125   inline void ResetListSize(int32_t list_idx, size_t val) { kernel_lists_[list_idx][kMaxKernelNum_] = val; }
126 
IncListSize(int32_t list_idx,size_t val)127   inline void IncListSize(int32_t list_idx, size_t val) { kernel_lists_[list_idx][kMaxKernelNum_] += val; }
128 
129   void *CreateSharedMem(const std::string &path);
130   std::string GetTmpKeyPath() const;
131 
InitKernelLists(void * addr)132   inline void InitKernelLists(void *addr) {
133     kernel_lists_[kToDoIdx_] = static_cast<size_t *>(addr);
134     kernel_lists_[kDoingIdx_] = kernel_lists_[kToDoIdx_] + kMaxKernelNum_ + 1;
135     kernel_lists_[kDoneIdx_] = kernel_lists_[kDoingIdx_] + kMaxKernelNum_ + 1;
136   }
137 
138   int32_t AddKernels(const std::vector<JsonNodePair> &build_args);
139   int32_t Wait() const;
140 
141   int32_t shm_id_{-1};
142   bool is_creator_{false};
143   int32_t fd_{-1};
144 
145   // includes 3 lists: todo_list, doing_list, done_list.
146   // each list has kMaxKernelNum_ + 1 elements and, the count of elements in each list
147   // is stored in kernel_lists_[xx][kMaxKernelNum_]
148   size_t *kernel_lists_[kListNum_]{nullptr, nullptr, nullptr};
149 
150   std::set<size_t> self_kernel_ids_;
151 };
152 }  // namespace kernel
153 }  // namespace mindspore
154 
155 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GRAPH_KERNEL_BUILD_H_
156