• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "backend/session/kernel_build_client.h"
#include <memory>
#include <stdexcept>
#include <string>
19 
20 namespace mindspore {
21 namespace kernel {
// Tracks whether TbePre() has already initialised the TBE environment.
// TbeStart() runs TbePre() and then sets this flag while it is false; TbeReset()
// clears it so the next TbeStart() repeats the initialisation.
// The flag is shared by every AscendKernelBuildClient in this translation unit
// and is not synchronized.
// NOTE(review): confirm TbeStart/TbeReset are never called concurrently.
// `static` alone gives internal linkage; the previous `inline` was redundant.
static bool init_flag = false;
// Replaces every non-overlapping occurrence of `replace` in `*dest` with the
// single character `new_char`, scanning left to right.
// A null `dest` or an empty `replace` pattern leaves the input untouched. The
// empty-pattern guard is required because find("") matches at every position,
// which would otherwise keep inserting `new_char` forever.
void ReplaceStr(std::string *dest, const std::string &replace, char new_char) {
  if (dest == nullptr || replace.empty()) {
    return;
  }
  std::string::size_type start = 0;
  while ((start = dest->find(replace, start)) != std::string::npos) {
    dest->replace(start, replace.size(), 1, new_char);
    // Skip past the single character just written so it is never re-matched.
    ++start;
  }
}
30 
AkgStart(int process_num,int wait_time)31 bool KernelBuildClient::AkgStart(int process_num, int wait_time) {
32   // Start compiling..
33   auto res = SendRequest(kAkgStart);
34   if (res != kAck) {
35     MS_LOG(ERROR) << "AKG/START failed, res: " << res;
36     return false;
37   }
38   std::string process_num_str = std::to_string(process_num);
39   res = SendRequest(process_num_str);
40   if (res != kAck) {
41     MS_LOG(ERROR) << "AKG/START(process_num) responds failed, res: " << res;
42     return false;
43   }
44   std::string wait_time_str = std::to_string(wait_time);
45   res = SendRequest(wait_time_str);
46   if (res != kAck) {
47     MS_LOG(ERROR) << "AKG/START(wait_time) responds failed, res: " << res;
48     return false;
49   }
50   return true;
51 }
52 
AkgSendAttr(const std::string & attr)53 bool KernelBuildClient::AkgSendAttr(const std::string &attr) {
54   auto res = SendRequest(kAkgAttr);
55   if (res != kAck) {
56     MS_LOG(ERROR) << "AKG/ATTR failed, res: " << res;
57     return false;
58   }
59   res = SendRequest(attr);
60   if (res != kAck) {
61     MS_LOG(ERROR) << "AKG/ATTR.. responds failed, res: " << res << ", when sending [" << attr << "]";
62     return false;
63   }
64   return true;
65 }
66 
AkgSendData(const std::vector<std::string> & jsons)67 bool KernelBuildClient::AkgSendData(const std::vector<std::string> &jsons) {
68   auto res = SendRequest(kAkgData);
69   if (res != kAck) {
70     MS_LOG(ERROR) << "AKG/DATA failed, res: " << res;
71     return false;
72   }
73   for (auto &json : jsons) {
74     res = SendRequest(json);
75     if (res != kAck) {
76       MS_LOG(ERROR) << "AKG/DATA.. responds failed, res: " << res << ", when sending [" << json << "]";
77       return false;
78     }
79   }
80   return true;
81 }
82 
83 // Fetch the result of AKG compiling.
AkgWait()84 bool KernelBuildClient::AkgWait() {
85   auto res = SendRequest(kAkgWait);
86   if (res != kTrue) {
87     MS_LOG(ERROR) << "AKG/WAIT failed, res: " << res;
88     return false;
89   }
90   return true;
91 }
92 
TbePre(const std::string & mode)93 void AscendKernelBuildClient::TbePre(const std::string &mode) {
94   auto res = SendRequest(kTbePre);
95   if (res.find(kSuccess) == std::string::npos) {
96     MS_LOG(EXCEPTION) << "PRE failed, res: " << res;
97   }
98   MS_LOG(INFO) << "Pre " << res;
99   // init env for auto tune
100   res = SendRequest(kTbeTune);
101   if (res != kAck) {
102     MS_LOG(EXCEPTION) << "Send tune single failed, res: " << res;
103   }
104   res = SendRequest(mode);
105   if (res != kSuccess) {
106     MS_LOG(EXCEPTION) << "PRE failed, res: " << res;
107   }
108 }
109 
TbeStart(const std::string & json,const std::string & mode)110 int AscendKernelBuildClient::TbeStart(const std::string &json, const std::string &mode) {
111   if (!init_flag) {
112     TbePre(mode);
113     init_flag = true;
114   }
115   // Start compiling..
116   auto res = SendRequest(kTbeStart);
117   if (res != kAck) {
118     MS_LOG(ERROR) << "START failed, res: " << res;
119     return -1;
120   }
121   // Send the json data.
122   res = SendRequest(json);
123   if (res == kFailed) {
124     MS_LOG(ERROR) << "TBE/START responds failed, res: " << res;
125     return -1;
126   }
127   // Return task id.
128   return std::stoi(res);
129 }
130 
TbeSendJob(const std::string & json)131 std::string AscendKernelBuildClient::TbeSendJob(const std::string &json) {
132   auto res = SendRequest(kTbeJob);
133   if (res != kAck) {
134     MS_LOG(ERROR) << "Send TBE job failed, res: " << res;
135     return "";
136   }
137   // Send the json data.
138   res = SendRequest(json);
139   if (res == kFailed) {
140     MS_LOG(ERROR) << "Send TBE job json failed, res: " << res;
141     return "";
142   }
143   return res;
144 }
145 
TbeWait(int * task_id,std::string * task_result,std::string * pre_build_result)146 bool AscendKernelBuildClient::TbeWait(int *task_id, std::string *task_result, std::string *pre_build_result) {
147   // Start waiting..
148   auto res = SendRequest(kTbeWait);
149   if (res != kAck) {
150     MS_LOG(ERROR) << "TBE/WAIT failed, res: " << res;
151     return false;
152   }
153   // Request task id.
154   *task_id = std::stoi(SendRequest(kContinue));
155   // Request task result.
156   *task_result = SendRequest(kContinue);
157   // Request prebuild result.
158   *pre_build_result = SendRequest(kContinue);
159   return true;
160 }
161 
TbeReset()162 void AscendKernelBuildClient::TbeReset() {
163   // Start compiling..
164   init_flag = false;
165   auto res = SendRequest(kTbeReset);
166   if (res != kAck) {
167     MS_LOG(EXCEPTION) << "TBE/RESET response is: " << res;
168   }
169 }
170 
SelectFormat(const std::string & json)171 std::string AscendKernelBuildClient::SelectFormat(const std::string &json) {
172   // Start compiling..
173   auto res = SendRequest(kFormat);
174   if (res != kAck) {
175     MS_LOG(ERROR) << "FORMAT failed, res: " << res;
176     return "";
177   }
178   // Send the json data.
179   res = SendRequest(json);
180   if (res == kErr) {
181     MS_LOG(ERROR) << "FORMAT responds failed, res: " << res;
182     return "";
183   }
184   return res;
185 }
186 
CheckSupported(const std::string & json)187 bool AscendKernelBuildClient::CheckSupported(const std::string &json) {
188   // Checking support..
189   auto res = SendRequest(kSupport);
190   if (res != kAck) {
191     MS_LOG(ERROR) << "SUPPORT failed, res: " << res;
192     return false;
193   }
194   // Send the json data.
195   res = SendRequest(json);
196   if (res != kTrue) {
197     MS_LOG(INFO) << "SUPPORT responds failed, res: " << res;
198     return false;
199   }
200   return true;
201 }
202 }  // namespace kernel
203 }  // namespace mindspore
204