• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "debug/debugger/grpc_client.h"
17 
18 #include <thread>
19 #include <vector>
20 #include "utils/log_adapter.h"
21 
22 using debugger::Chunk;
23 using debugger::EventListener;
24 using debugger::EventReply;
25 using debugger::EventReply_Status_FAILED;
26 using debugger::GraphProto;
27 using debugger::Heartbeat;
28 using debugger::Metadata;
29 using debugger::TensorBase;
30 using debugger::TensorProto;
31 using debugger::TensorSummary;
32 using debugger::WatchpointHit;
33 
34 namespace mindspore {
GrpcClient(const std::string & host,const std::string & port)35 GrpcClient::GrpcClient(const std::string &host, const std::string &port) : stub_(nullptr) { Init(host, port); }
36 
Init(const std::string & host,const std::string & port)37 void GrpcClient::Init(const std::string &host, const std::string &port) {
38   std::string target_str = host + ":" + port;
39   MS_LOG(INFO) << "GrpcClient connecting to: " << target_str;
40 
41   std::shared_ptr<grpc::Channel> channel = grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials());
42   stub_ = EventListener::NewStub(channel);
43 }
44 
Reset()45 void GrpcClient::Reset() { stub_ = nullptr; }
46 
WaitForCommand(const Metadata & metadata)47 EventReply GrpcClient::WaitForCommand(const Metadata &metadata) {
48   EventReply reply;
49   grpc::ClientContext context;
50   grpc::Status status = stub_->WaitCMD(&context, metadata, &reply);
51   if (!status.ok()) {
52     MS_LOG(ERROR) << "RPC failed: WaitForCommand";
53     MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
54     reply.set_status(EventReply_Status_FAILED);
55   }
56   return reply;
57 }
58 
SendMetadata(const Metadata & metadata)59 EventReply GrpcClient::SendMetadata(const Metadata &metadata) {
60   EventReply reply;
61   grpc::ClientContext context;
62   grpc::Status status = stub_->SendMetadata(&context, metadata, &reply);
63   if (!status.ok()) {
64     MS_LOG(ERROR) << "RPC failed: SendMetadata";
65     MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
66     reply.set_status(EventReply_Status_FAILED);
67   }
68   return reply;
69 }
70 
ChunkString(std::string str,int graph_size)71 std::vector<std::string> GrpcClient::ChunkString(std::string str, int graph_size) {
72   std::vector<std::string> buf;
73   constexpr auto l_chunk_size = 1024 * 1024 * 3;
74   int size_iter = 0;
75   while (size_iter < graph_size) {
76     int chunk_size = l_chunk_size;
77     if (graph_size - size_iter < l_chunk_size) {
78       chunk_size = graph_size - size_iter;
79     }
80     std::string buffer;
81     buffer.resize(chunk_size);
82     auto err = memcpy_s(reinterpret_cast<char *>(buffer.data()), chunk_size, str.data() + size_iter, chunk_size);
83     if (err != 0) {
84       MS_LOG(EXCEPTION) << "memcpy_s failed. errorno is: " << err;
85     }
86     buf.push_back(buffer);
87     if (size_iter > INT_MAX - l_chunk_size) {
88       MS_EXCEPTION(ValueError) << size_iter << " + " << l_chunk_size << "would lead to integer overflow!";
89     }
90     size_iter += l_chunk_size;
91   }
92   return buf;
93 }
94 
SendGraph(const GraphProto & graph)95 EventReply GrpcClient::SendGraph(const GraphProto &graph) {
96   EventReply reply;
97   grpc::ClientContext context;
98   Chunk chunk;
99 
100   std::unique_ptr<grpc::ClientWriter<Chunk> > writer(stub_->SendGraph(&context, &reply));
101   std::string str = graph.SerializeAsString();
102   int graph_size = graph.ByteSize();
103   auto buf = ChunkString(str, graph_size);
104 
105   for (unsigned int i = 0; i < buf.size(); i++) {
106     MS_LOG(INFO) << "RPC:sending the " << i << "chunk in graph";
107     chunk.set_buffer(buf[i]);
108     if (!writer->Write(chunk)) {
109       break;
110     }
111     std::this_thread::sleep_for(std::chrono::milliseconds(1));
112   }
113   writer->WritesDone();
114   grpc::Status status = writer->Finish();
115   if (!status.ok()) {
116     MS_LOG(ERROR) << "RPC failed: SendGraph";
117     MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
118     reply.set_status(EventReply_Status_FAILED);
119   }
120   return reply;
121 }
122 
SendMultiGraphs(const std::list<Chunk> & chunks)123 EventReply GrpcClient::SendMultiGraphs(const std::list<Chunk> &chunks) {
124   EventReply reply;
125   grpc::ClientContext context;
126 
127   std::unique_ptr<grpc::ClientWriter<Chunk> > writer(stub_->SendMultiGraphs(&context, &reply));
128   for (const auto &chunk : chunks) {
129     if (!writer->Write(chunk)) {
130       break;
131     }
132     std::this_thread::sleep_for(std::chrono::milliseconds(1));
133   }
134   writer->WritesDone();
135   grpc::Status status = writer->Finish();
136   if (!status.ok()) {
137     MS_LOG(ERROR) << "RPC failed: SendMultigraphs";
138     MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
139     reply.set_status(EventReply_Status_FAILED);
140   }
141   return reply;
142 }
143 
SendTensors(const std::list<TensorProto> & tensors)144 EventReply GrpcClient::SendTensors(const std::list<TensorProto> &tensors) {
145   EventReply reply;
146   grpc::ClientContext context;
147 
148   std::unique_ptr<grpc::ClientWriter<TensorProto> > writer(stub_->SendTensors(&context, &reply));
149   for (const auto &tensor : tensors) {
150     if (!writer->Write(tensor)) {
151       break;
152     }
153     std::this_thread::sleep_for(std::chrono::milliseconds(1));
154   }
155   writer->WritesDone();
156   grpc::Status status = writer->Finish();
157   if (!status.ok()) {
158     MS_LOG(ERROR) << "RPC failed: SendTensors";
159     MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
160     reply.set_status(EventReply_Status_FAILED);
161   }
162   return reply;
163 }
164 
SendWatchpointHits(const std::list<WatchpointHit> & watchpoints)165 EventReply GrpcClient::SendWatchpointHits(const std::list<WatchpointHit> &watchpoints) {
166   EventReply reply;
167   grpc::ClientContext context;
168 
169   std::unique_ptr<grpc::ClientWriter<WatchpointHit> > writer(stub_->SendWatchpointHits(&context, &reply));
170   for (const auto &watchpoint : watchpoints) {
171     if (!writer->Write(watchpoint)) {
172       break;
173     }
174     std::this_thread::sleep_for(std::chrono::milliseconds(1));
175   }
176   writer->WritesDone();
177   grpc::Status status = writer->Finish();
178   if (!status.ok()) {
179     MS_LOG(ERROR) << "RPC failed: SendWatchpointHits";
180     MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
181     reply.set_status(EventReply_Status_FAILED);
182   }
183   return reply;
184 }
185 
SendHeartbeat(const Heartbeat & heartbeat)186 EventReply GrpcClient::SendHeartbeat(const Heartbeat &heartbeat) {
187   EventReply reply;
188   grpc::ClientContext context;
189 
190   grpc::Status status = stub_->SendHeartbeat(&context, heartbeat, &reply);
191   if (!status.ok()) {
192     MS_LOG(ERROR) << "RPC failed: SendHeartbeat";
193     MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
194     reply.set_status(EventReply_Status_FAILED);
195   }
196   return reply;
197 }
198 
SendTensorBase(const std::list<TensorBase> & tensor_base_list)199 EventReply GrpcClient::SendTensorBase(const std::list<TensorBase> &tensor_base_list) {
200   EventReply reply;
201   grpc::ClientContext context;
202 
203   std::unique_ptr<grpc::ClientWriter<TensorBase> > writer(stub_->SendTensorBase(&context, &reply));
204 
205   for (const auto &tensor_base : tensor_base_list) {
206     if (!writer->Write(tensor_base)) {
207       break;
208     }
209     std::this_thread::sleep_for(std::chrono::milliseconds(1));
210   }
211   writer->WritesDone();
212   grpc::Status status = writer->Finish();
213   if (!status.ok()) {
214     MS_LOG(ERROR) << "RPC failed: SendTensorBase";
215     MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
216     reply.set_status(EventReply_Status_FAILED);
217   }
218   return reply;
219 }
220 
SendTensorStats(const std::list<TensorSummary> & tensor_summary_list)221 EventReply GrpcClient::SendTensorStats(const std::list<TensorSummary> &tensor_summary_list) {
222   EventReply reply;
223   grpc::ClientContext context;
224 
225   std::unique_ptr<grpc::ClientWriter<TensorSummary> > writer(stub_->SendTensorStats(&context, &reply));
226 
227   for (const auto &tensor_summary : tensor_summary_list) {
228     if (!writer->Write(tensor_summary)) {
229       break;
230     }
231     std::this_thread::sleep_for(std::chrono::milliseconds(1));
232   }
233   writer->WritesDone();
234   grpc::Status status = writer->Finish();
235   if (!status.ok()) {
236     MS_LOG(ERROR) << "RPC failed: SendTensorStats";
237     MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
238     reply.set_status(EventReply_Status_FAILED);
239   }
240   return reply;
241 }
242 }  // namespace mindspore
243