1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_BUILD_CLIENT_H_
18 #define MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_BUILD_CLIENT_H_
19
20 #include <vector>
21 #include <string>
22 #include <cstring>
23 #include <cstdlib>
24 #include <memory>
25
26 #include "common/duplex_pipe.h"
27 #include "utils/log_adapter.h"
28 #include "utils/ms_context.h"
29
30 namespace mindspore {
31 namespace kernel {
// Declared here, defined elsewhere. Response() below calls it to turn an escape token (`replace`)
// inside `*dest` back into the single character `new_char`.
void ReplaceStr(std::string *dest, const std::string &replace, char new_char);

// Size in bytes of the buffer used when reading the output of a popen()'ed command.
constexpr inline static int kBufferSize = 4096;
// Fallback python executable, used when no explicit path is configured.
constexpr inline static const char *kEnv = "python";
// The TAG as prefix of real command from remote.
constexpr inline static const char *kTag = "[~]";
GetPyExe()38 static std::string GetPyExe() {
39 // get real python executable path
40 auto ms_context = MsContext::GetInstance();
41 if (ms_context == nullptr) {
42 return kEnv;
43 }
44 auto env = ms_context->get_param<std::string>(MS_CTX_PYTHON_EXE_PATH);
45 if (env.empty()) {
46 return kEnv;
47 }
48 return env;
49 }
50
51 class KernelBuildClient {
52 public:
53 // Send Finish request to server
54 constexpr inline static auto kFinish = "FINISH";
55 constexpr inline static auto kAkgStart = "AKG/START";
56 constexpr inline static auto kAkgData = "AKG/DATA";
57 constexpr inline static auto kAkgAttr = "AKG/ATTR";
58 constexpr inline static auto kAkgWait = "AKG/WAIT";
59 // Receive the response from server
60 constexpr inline static auto kAck = "ACK";
61 constexpr inline static auto kErr = "ERR";
62 constexpr inline static auto kTrue = "True";
63 constexpr inline static auto kSuccess = "Success";
64
65 // Revert \n, \r, [space].
66 constexpr inline static auto kLF = "[LF]";
67 constexpr inline static auto kCR = "[CR]";
68 constexpr inline static auto kSP = "[SP]";
69
70 virtual std::string GetEnv() = 0;
71 virtual std::string GetScript() = 0;
72
Open()73 void Open() {
74 if (!init_) {
75 // Exception's thrown if open failed
76 if (dp_->Open({GetEnv(), GetScript()}, true) != -1) {
77 dp_->SetFinalizeCallback(std::make_shared<std::function<void()>>([this]() { Close(); }));
78 init_ = true;
79 }
80 }
81 }
Close()82 void Close() {
83 if (init_) {
84 dp_->Close();
85 init_ = false;
86 }
87 }
88
89 // Send a request and fetch its response
SendRequest(std::string data)90 std::string SendRequest(std::string data) {
91 Request(data);
92 return Response();
93 }
Request(std::string req)94 void Request(std::string req) {
95 if (!init_) {
96 MS_LOG(EXCEPTION) << "Try to send request before Open()";
97 }
98 MS_LOG(DEBUG) << "\t[" << req << "]";
99 *dp_ << req;
100 }
Response()101 std::string Response() {
102 if (!init_) {
103 MS_LOG(EXCEPTION) << "Try to get response before Open()";
104 }
105 std::string res;
106 *dp_ >> res;
107 // Filter out the interference
108 if (res.empty()) {
109 MS_LOG(EXCEPTION) << "Response is empty";
110 }
111 auto start = res.find(kTag);
112 if (start == std::string::npos) {
113 MS_LOG(EXCEPTION) << "Response seems incorrect, res: " << res;
114 }
115 auto pos = start + std::strlen(kTag);
116 if (pos > res.size()) { // Safe check for codedex
117 MS_LOG(EXCEPTION) << "Response seems incorrect, res(" << res.size() << "): {" << res << "}, start: " << start;
118 }
119 res = res.substr(pos);
120 // Revert the line feed and space
121 if (res != kSuccess && res != kAck && res != kErr && res != kTrue) {
122 ReplaceStr(&res, kLF, '\n');
123 ReplaceStr(&res, kSP, ' ');
124 }
125 MS_LOG(DEBUG) << "\t[" << res << "]";
126 return res;
127 }
128
129 // Run AKG building.
130 bool AkgStart(int process_num, int wait_time);
131 bool AkgSendAttr(const std::string &attr);
132 bool AkgSendData(const std::vector<std::string> &jsons);
133 bool AkgWait();
134
135 protected:
KernelBuildClient()136 KernelBuildClient() : init_(false), dp_(std::make_shared<DuplexPipe>()) {}
137 virtual ~KernelBuildClient() = default;
138
139 private:
140 bool init_;
141 std::shared_ptr<DuplexPipe> dp_;
142 };
143
GetScriptFilePath(const std::string & cmd_env,const std::string & cmd_script,const std::string & server_script)144 static std::string GetScriptFilePath(const std::string &cmd_env, const std::string &cmd_script,
145 const std::string &server_script) {
146 auto ms_context = MsContext::GetInstance();
147 MS_EXCEPTION_IF_NULL(ms_context);
148 auto server_dir = ms_context->get_param<std::string>(MS_CTX_KERNEL_BUILD_SERVER_DIR);
149 if (!server_dir.empty()) {
150 return server_dir + server_script;
151 }
152
153 std::string cmd = cmd_env;
154 (void)cmd.append(1, ' ').append(cmd_script);
155 FILE *fpipe = popen(cmd.c_str(), "r");
156 if (fpipe == nullptr) {
157 MS_LOG(EXCEPTION) << "popen failed, errno: " << errno;
158 }
159 bool start = false;
160 std::string result;
161 char buf[kBufferSize];
162 while (std::fgets(buf, sizeof(buf), fpipe) != nullptr) {
163 auto len = std::strlen(buf);
164 if (len == 0 || len >= kBufferSize) {
165 // Safe check for codedex
166 // Should never reach here
167 MS_LOG(EXCEPTION) << "fgets() failed, len: " << len << ", errno: " << errno;
168 }
169 if (std::strncmp(buf, kTag, std::strlen(kTag)) == 0) {
170 start = true;
171 }
172 // Filter with 'kTAG' and '\n'
173 if (start) {
174 bool line_end = buf[len - 1] == '\n';
175 result.append(buf, line_end ? len - 1 : len);
176 if (line_end) {
177 break;
178 }
179 }
180 }
181 pclose(fpipe);
182 const std::string py_suffix = ".py";
183 if (result.empty() || result.rfind(py_suffix) != (result.length() - py_suffix.length())) {
184 MS_LOG(EXCEPTION) << "py file seems incorrect, result: {" << result << "}";
185 }
186 if (strlen(kTag) > result.size()) { // Safe check for codedex
187 MS_LOG(EXCEPTION) << "result size seems incorrect, result(" << result.size() << "): {" << result << "}";
188 }
189 result = result.substr(strlen(kTag));
190 MS_LOG(DEBUG) << "result: " << result;
191 return result;
192 }
193
194 class AscendKernelBuildClient : public KernelBuildClient {
195 public:
196 // Server configure
197 constexpr inline static auto kGetPathScript =
198 "-c "
199 "\""
200 "import pkgutil;"
201 "path = pkgutil"
202 ".get_loader(\\\"mindspore._extends.remote.kernel_build_server_ascend\\\")" // Server module name
203 ".get_filename();"
204 "print('[~]' + path)"
205 "\"";
206
207 constexpr inline static auto kServerScript = "kernel_build_server_ascend.py";
208
209 // Receive the response from server
210 constexpr inline static auto kFailed = "-1";
211
212 // Send building request to server
213 constexpr inline static auto kContinue = "CONTINUE"; // More transactions to be continued
214 constexpr inline static auto kTbePre = "TBE/PRE";
215 constexpr inline static auto kTbeStart = "TBE/START";
216 constexpr inline static auto kTbeWait = "TBE/WAIT";
217 constexpr inline static auto kTbeReset = "TBE/RESET";
218 constexpr inline static auto kTbeTune = "TBE/TUNE";
219 constexpr inline static auto kTbeJob = "TBE/JOB";
220
221 // Send server info. query to server
222 constexpr inline static auto kFormat = "FORMAT";
223 constexpr inline static auto kSupport = "SUPPORT";
224
Instance()225 static AscendKernelBuildClient &Instance() {
226 static AscendKernelBuildClient instance;
227 return instance;
228 }
229
GetEnv()230 std::string GetEnv() override { return GetPyExe(); }
231
GetScript()232 std::string GetScript() override {
233 auto env = GetPyExe();
234 return GetScriptFilePath(env, kGetPathScript, kServerScript);
235 }
236
237 // Before building.
238 std::string SelectFormat(const std::string &json);
239 bool CheckSupported(const std::string &json);
240
241 // Run TBE building.
242 std::string TbeSendJob(const std::string &json);
243 int TbeStart(const std::string &json, const std::string &mode);
244 bool TbeWait(int *task_id, std::string *task_result, std::string *pre_build_result);
245 void TbeReset();
246
247 AscendKernelBuildClient(const AscendKernelBuildClient &) = delete;
248 AscendKernelBuildClient &operator=(const AscendKernelBuildClient &) = delete;
249
250 AscendKernelBuildClient(AscendKernelBuildClient &&) = delete;
251 AscendKernelBuildClient &operator=(AscendKernelBuildClient &&) = delete;
252
253 private:
254 void TbePre(const std::string &mode);
AscendKernelBuildClient()255 AscendKernelBuildClient() { Open(); }
~AscendKernelBuildClient()256 ~AscendKernelBuildClient() override { Close(); }
257 };
258
259 class GpuKernelBuildClient : public KernelBuildClient {
260 public:
261 // Server configure
262 constexpr inline static auto kGetPathScript =
263 "-c "
264 "\""
265 "import pkgutil;"
266 "path = pkgutil"
267 ".get_loader(\\\"mindspore._extends.remote.kernel_build_server_gpu\\\")" // Server module name
268 ".get_filename();"
269 "print('[~]' + path)"
270 "\"";
271
272 constexpr inline static auto kServerScript = "kernel_build_server_gpu.py";
273
Instance()274 static GpuKernelBuildClient &Instance() {
275 static GpuKernelBuildClient instance;
276 return instance;
277 }
278
GetEnv()279 std::string GetEnv() override { return GetPyExe(); }
280
GetScript()281 std::string GetScript() override {
282 auto env = GetPyExe();
283 return GetScriptFilePath(env, kGetPathScript, kServerScript);
284 }
285
286 GpuKernelBuildClient(const GpuKernelBuildClient &) = delete;
287 GpuKernelBuildClient &operator=(const GpuKernelBuildClient &) = delete;
288
289 GpuKernelBuildClient(GpuKernelBuildClient &&) = delete;
290 GpuKernelBuildClient &operator=(GpuKernelBuildClient &&) = delete;
291
292 private:
GpuKernelBuildClient()293 GpuKernelBuildClient() { Open(); }
~GpuKernelBuildClient()294 ~GpuKernelBuildClient() override { Close(); }
295 };
296 } // namespace kernel
297 } // namespace mindspore
298
299 #endif // MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_BUILD_CLIENT_H_
300