• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "cxx_api/model/model_converter_utils/multi_process.h"
18 #include <unistd.h>
19 #include <sys/wait.h>
20 #include <algorithm>
21 #include <vector>
22 #include <thread>
23 #include "mindspore/core/utils/log_adapter.h"
24 #include "cxx_api/model/model_converter_utils/shared_memory.h"
25 
26 namespace mindspore {
namespace {
// Unit helpers so the sizes and durations below read as "count * unit".
constexpr uint64_t kBytesPerMegabyte = 1ull << 20;
constexpr auto kNanosecondsPerMillisecond = 1000L * 1000L;

// Size of the shared memory segment used to exchange messages: 100 MB.
constexpr uint64_t kSharedMemorySize = 100ull * kBytesPerMegabyte;

// Sleep interval of 1 ms, in the {seconds, nanoseconds} form expected by nanosleep.
constexpr timespec kOneMillisecond = {0, 1 * kNanosecondsPerMillisecond};

// Sleep interval of 100 ms, in the {seconds, nanoseconds} form expected by nanosleep.
constexpr timespec kOneHundredMilliseconds = {0, 100 * kNanosecondsPerMillisecond};
}  // namespace
39 
40 MultiProcess::MultiProcess() = default;
41 
42 MultiProcess::~MultiProcess() = default;
43 
MainProcess(const ProcessFuncCall & parent_process,const ProcessFuncCall & child_process)44 Status MultiProcess::MainProcess(const ProcessFuncCall &parent_process, const ProcessFuncCall &child_process) {
45   MS_EXCEPTION_IF_NULL(parent_process);
46   MS_EXCEPTION_IF_NULL(child_process);
47   Status ret;
48   memory_size_ = kSharedMemorySize;  // 100 MB
49   SharedMemory shared_memory;
50   ret = shared_memory.Create(memory_size_);
51   if (ret != kSuccess) {
52     MS_LOG(ERROR) << "Create shared memory failed";
53     return ret;
54   }
55   pid_t pid = fork();
56   if (pid < 0) {
57     shared_memory.Destroy();
58     MS_LOG(ERROR) << "Fork process to convert model failed";
59     return kMEFailed;
60   }
61   ret = shared_memory.Attach();
62   if (ret != kSuccess) {
63     MS_LOG(ERROR) << "Process attach shared memory failed, pid " << pid;
64     return ret;
65   }
66   shmat_addr_ = shared_memory.GetSharedMemoryAddr();
67   if (shmat_addr_ == nullptr) {
68     MS_LOG(ERROR) << "Get shared memory failed";
69     return ret;
70   }
71   constexpr size_t kMsgStructNum = 2;
72   shmat_data_addr_ = shmat_addr_ + sizeof(MessageFlag) * kMsgStructNum;
73   shmat_data_max_size_ =
74     memory_size_ - (reinterpret_cast<uintptr_t>(shmat_data_addr_) - reinterpret_cast<uintptr_t>(shmat_addr_));
75   MS_LOG(INFO) << "Shm addr " << reinterpret_cast<uintptr_t>(shmat_addr_);
76   if (pid == 0) {
77     ChildProcess(child_process);
78     shared_memory.Detach();
79     MS_LOG(INFO) << "Model converter: child process sleep waiting for exit signal.";
80     while (1) {
81       // waiting for signal
82     }
83   } else {  // parent process
84     ret = ParentProcess(parent_process);
85     shared_memory.Detach();
86 
87     MS_LOG(INFO) << "Model converter: parent process kills child of fork.";
88     (void)kill(pid, SIGKILL);
89     constexpr uint32_t kMaxLoopCount = 5;
90     bool child_exited = false;
91     for (uint32_t i = 0; i < kMaxLoopCount; ++i) {
92       int status;
93       if (waitpid(pid, &status, WNOHANG) == pid) {
94         MS_LOG(INFO) << "Child process " << pid << " exits success.";
95         child_exited = true;
96         break;
97       }
98       (void)sleep(1);
99     }
100     if (!child_exited) {
101       MS_LOG(WARNING) << "Child process " << pid << " has been killed but waitpid failed.";
102     }
103     shared_memory.Destroy();
104   }
105   return ret;
106 }
107 
ParentProcess(const ProcessFuncCall & parent_process)108 Status MultiProcess::ParentProcess(const ProcessFuncCall &parent_process) {
109   auto parent_msg = reinterpret_cast<MessageFlag *>(shmat_addr_);
110   auto child_msg = reinterpret_cast<MessageFlag *>(shmat_addr_ + sizeof(MessageFlag));
111   send_msg_ = parent_msg;
112   receive_msg_ = child_msg;
113   std::thread heartbeat_thread(MultiProcess::HeartbeatThreadFunc, this);
114   Status ret;
115   try {
116     ret = parent_process(this);
117     if (ret != kSuccess) {
118       MS_LOG(ERROR) << "Parent process process failed";
119     }
120   } catch (const std::runtime_error &ex) {
121     MS_LOG(ERROR) << "Catch parent process runtime error: " << ex.what();
122     ret = kMEFailed;
123   }
124   stopped_ = true;
125   send_msg_->stop = 1;
126   heartbeat_thread.join();
127   return ret;
128 }
129 
ChildProcess(const ProcessFuncCall & child_process)130 void MultiProcess::ChildProcess(const ProcessFuncCall &child_process) {
131   auto parent_msg = reinterpret_cast<MessageFlag *>(shmat_addr_);
132   auto child_msg = reinterpret_cast<MessageFlag *>(shmat_addr_ + sizeof(MessageFlag));
133   send_msg_ = child_msg;
134   receive_msg_ = parent_msg;
135   std::thread heartbeat_thread(MultiProcess::HeartbeatThreadFunc, this);
136   try {
137     MS_EXCEPTION_IF_NULL(child_process);
138     auto ret = child_process(this);
139     if (ret != kSuccess) {
140       MS_LOG(ERROR) << "Child process process failed";
141     }
142   } catch (const std::runtime_error &ex) {
143     MS_LOG(ERROR) << "Catch child process runtime error: " << ex.what();
144   }
145   stopped_ = true;
146   send_msg_->stop = 1;
147   heartbeat_thread.join();
148 }
149 
SendMsg(const void * buffer,uint64_t msg_len)150 Status MultiProcess::SendMsg(const void *buffer, uint64_t msg_len) {
151   MS_EXCEPTION_IF_NULL(buffer);
152   MS_LOG(INFO) << "Start to send message to peer process, msg len " << msg_len;
153   send_msg_->msg_total_len = msg_len;
154   uint64_t cur_offset = 0;
155   while (msg_len > cur_offset) {
156     uint64_t sub_msg_len = std::min(msg_len - cur_offset, shmat_data_max_size_);
157     if (sub_msg_len == 0) {
158       MS_LOG(ERROR) << "Invalid message len " << sub_msg_len;
159       return kMEFailed;
160     }
161     auto ret =
162       memcpy_s(shmat_data_addr_, shmat_data_max_size_, static_cast<const uint8_t *>(buffer) + cur_offset, sub_msg_len);
163     if (ret != EOK) {
164       MS_LOG(ERROR) << "memcpy_s failed, ret = " << ret;
165       return kMEFailed;
166     }
167     cur_offset += sub_msg_len;
168 
169     send_msg_->msg_len = sub_msg_len;
170     send_msg_->read_finish_flag = 0;
171     send_msg_->read_ready_flag = 1;
172     MS_LOG(INFO) << "Send start " << cur_offset << ", msg len " << sub_msg_len << ", total len " << msg_len;
173     while (!send_msg_->read_finish_flag && !peer_stopped_) {
174       (void)nanosleep(&kOneMillisecond, nullptr);  // 1ms
175     }
176     if (peer_stopped_) {
177       if (!send_msg_->read_finish_flag) {
178         return kMEFailed;
179       }
180       break;
181     }
182     MS_LOG(INFO) << "Send end " << cur_offset << ", msg len " << sub_msg_len << ", total len " << msg_len;
183   }
184   MS_LOG(INFO) << "End to send message to peer process, msg len " << msg_len;
185   return kSuccess;
186 }
187 
ReceiveMsg(const CreateBufferCall & create_buffer_call) const188 Status MultiProcess::ReceiveMsg(const CreateBufferCall &create_buffer_call) const {
189   uint64_t cur_offset = 0;
190   uint8_t *msg_buffer = nullptr;
191   uint64_t msg_len = 0;
192   do {
193     MS_LOG(INFO) << "Receive start from " << cur_offset;
194     while (!receive_msg_->read_ready_flag && !peer_stopped_) {
195       (void)nanosleep(&kOneMillisecond, nullptr);  // 1ms
196     }
197     if (peer_stopped_) {
198       return kMEFailed;
199     }
200     if (msg_buffer == nullptr) {
201       msg_len = receive_msg_->msg_total_len;
202       msg_buffer = create_buffer_call(msg_len);
203     }
204     MS_EXCEPTION_IF_NULL(msg_buffer);
205     size_t dest_max = std::min(shmat_data_max_size_, msg_len - cur_offset);
206     auto ret = memcpy_s(msg_buffer + cur_offset, dest_max, shmat_data_addr_, receive_msg_->msg_len);
207     if (ret != EOK) {
208       MS_LOG(INFO) << "memcpy_s failed, ret = " << ret;
209       return kMEFailed;
210     }
211     cur_offset += receive_msg_->msg_len;
212     receive_msg_->read_ready_flag = 0;
213     receive_msg_->read_finish_flag = 1;
214     MS_LOG(INFO) << "Receive end, current length " << cur_offset << ", total length " << msg_len << std::endl;
215   } while (msg_len > cur_offset);
216   return kSuccess;
217 }
218 
HeartbeatThreadFunc(MultiProcess * multi_process)219 void MultiProcess::HeartbeatThreadFunc(MultiProcess *multi_process) { multi_process->HeartbeatThreadFuncInner(); }
220 
HeartbeatThreadFuncInner()221 void MultiProcess::HeartbeatThreadFuncInner() {
222   constexpr uint64_t kOvertime = 1024;
223   uint64_t last_beat_cnt = 0;
224   uint64_t repeat_cnt = 0;
225   while (!stopped_) {
226     if (receive_msg_->stop) {
227       peer_stopped_ = true;
228       MS_LOG(WARNING) << "Peer stopped";
229       break;
230     }
231     uint64_t heartbeat_gap = receive_msg_->heartbeat - last_beat_cnt;
232     if (heartbeat_gap > 0 && heartbeat_gap < kOvertime) {
233       last_beat_cnt = receive_msg_->heartbeat;
234       repeat_cnt = 0;
235     } else {
236       repeat_cnt++;
237       if (repeat_cnt > 30) {  // 30*100ms = 3s no reply
238         peer_stopped_ = true;
239         MS_LOG(WARNING) << "Peer stopped";
240         break;
241       }
242     }
243     send_msg_->heartbeat += 1;
244     (void)nanosleep(&kOneHundredMilliseconds, nullptr);  // sleep 100 ms
245   }
246 }
247 }  // namespace mindspore
248