1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "cxx_api/model/model_converter_utils/multi_process.h"
#include <signal.h>
#include <sys/wait.h>
#include <time.h>
#include <unistd.h>
#include <algorithm>
#include <exception>
#include <stdexcept>
#include <thread>
#include <vector>
#include "mindspore/core/utils/log_adapter.h"
#include "cxx_api/model/model_converter_utils/shared_memory.h"
25
26 namespace mindspore {
namespace {
// Size of the shared-memory segment created by MainProcess(): 100 << 20 bytes (100 MB).
constexpr uint64_t kSharedMemorySize = uint64_t{100} << 20;

// 0 s + 1,000,000 ns = 1 ms; poll interval while waiting on a shared-memory message flag.
constexpr timespec kOneMillisecond{0, 1000L * 1000L};

// 0 s + 100,000,000 ns = 100 ms; sleep interval of the heartbeat loop.
constexpr timespec kOneHundredMilliseconds{0, 100L * 1000L * 1000L};
}  // namespace
39
// Defaulted: no explicit setup or teardown happens here. The shared-memory segment is a local of
// MainProcess(), which creates and destroys it itself.
MultiProcess::MultiProcess() = default;

MultiProcess::~MultiProcess() = default;
43
MainProcess(const ProcessFuncCall & parent_process,const ProcessFuncCall & child_process)44 Status MultiProcess::MainProcess(const ProcessFuncCall &parent_process, const ProcessFuncCall &child_process) {
45 MS_EXCEPTION_IF_NULL(parent_process);
46 MS_EXCEPTION_IF_NULL(child_process);
47 Status ret;
48 memory_size_ = kSharedMemorySize; // 100 MB
49 SharedMemory shared_memory;
50 ret = shared_memory.Create(memory_size_);
51 if (ret != kSuccess) {
52 MS_LOG(ERROR) << "Create shared memory failed";
53 return ret;
54 }
55 pid_t pid = fork();
56 if (pid < 0) {
57 shared_memory.Destroy();
58 MS_LOG(ERROR) << "Fork process to convert model failed";
59 return kMEFailed;
60 }
61 ret = shared_memory.Attach();
62 if (ret != kSuccess) {
63 MS_LOG(ERROR) << "Process attach shared memory failed, pid " << pid;
64 return ret;
65 }
66 shmat_addr_ = shared_memory.GetSharedMemoryAddr();
67 if (shmat_addr_ == nullptr) {
68 MS_LOG(ERROR) << "Get shared memory failed";
69 return ret;
70 }
71 constexpr size_t kMsgStructNum = 2;
72 shmat_data_addr_ = shmat_addr_ + sizeof(MessageFlag) * kMsgStructNum;
73 shmat_data_max_size_ =
74 memory_size_ - (reinterpret_cast<uintptr_t>(shmat_data_addr_) - reinterpret_cast<uintptr_t>(shmat_addr_));
75 MS_LOG(INFO) << "Shm addr " << reinterpret_cast<uintptr_t>(shmat_addr_);
76 if (pid == 0) {
77 ChildProcess(child_process);
78 shared_memory.Detach();
79 MS_LOG(INFO) << "Model converter: child process sleep waiting for exit signal.";
80 while (1) {
81 // waiting for signal
82 }
83 } else { // parent process
84 ret = ParentProcess(parent_process);
85 shared_memory.Detach();
86
87 MS_LOG(INFO) << "Model converter: parent process kills child of fork.";
88 (void)kill(pid, SIGKILL);
89 constexpr uint32_t kMaxLoopCount = 5;
90 bool child_exited = false;
91 for (uint32_t i = 0; i < kMaxLoopCount; ++i) {
92 int status;
93 if (waitpid(pid, &status, WNOHANG) == pid) {
94 MS_LOG(INFO) << "Child process " << pid << " exits success.";
95 child_exited = true;
96 break;
97 }
98 (void)sleep(1);
99 }
100 if (!child_exited) {
101 MS_LOG(WARNING) << "Child process " << pid << " has been killed but waitpid failed.";
102 }
103 shared_memory.Destroy();
104 }
105 return ret;
106 }
107
ParentProcess(const ProcessFuncCall & parent_process)108 Status MultiProcess::ParentProcess(const ProcessFuncCall &parent_process) {
109 auto parent_msg = reinterpret_cast<MessageFlag *>(shmat_addr_);
110 auto child_msg = reinterpret_cast<MessageFlag *>(shmat_addr_ + sizeof(MessageFlag));
111 send_msg_ = parent_msg;
112 receive_msg_ = child_msg;
113 std::thread heartbeat_thread(MultiProcess::HeartbeatThreadFunc, this);
114 Status ret;
115 try {
116 ret = parent_process(this);
117 if (ret != kSuccess) {
118 MS_LOG(ERROR) << "Parent process process failed";
119 }
120 } catch (const std::runtime_error &ex) {
121 MS_LOG(ERROR) << "Catch parent process runtime error: " << ex.what();
122 ret = kMEFailed;
123 }
124 stopped_ = true;
125 send_msg_->stop = 1;
126 heartbeat_thread.join();
127 return ret;
128 }
129
ChildProcess(const ProcessFuncCall & child_process)130 void MultiProcess::ChildProcess(const ProcessFuncCall &child_process) {
131 auto parent_msg = reinterpret_cast<MessageFlag *>(shmat_addr_);
132 auto child_msg = reinterpret_cast<MessageFlag *>(shmat_addr_ + sizeof(MessageFlag));
133 send_msg_ = child_msg;
134 receive_msg_ = parent_msg;
135 std::thread heartbeat_thread(MultiProcess::HeartbeatThreadFunc, this);
136 try {
137 MS_EXCEPTION_IF_NULL(child_process);
138 auto ret = child_process(this);
139 if (ret != kSuccess) {
140 MS_LOG(ERROR) << "Child process process failed";
141 }
142 } catch (const std::runtime_error &ex) {
143 MS_LOG(ERROR) << "Catch child process runtime error: " << ex.what();
144 }
145 stopped_ = true;
146 send_msg_->stop = 1;
147 heartbeat_thread.join();
148 }
149
SendMsg(const void * buffer,uint64_t msg_len)150 Status MultiProcess::SendMsg(const void *buffer, uint64_t msg_len) {
151 MS_EXCEPTION_IF_NULL(buffer);
152 MS_LOG(INFO) << "Start to send message to peer process, msg len " << msg_len;
153 send_msg_->msg_total_len = msg_len;
154 uint64_t cur_offset = 0;
155 while (msg_len > cur_offset) {
156 uint64_t sub_msg_len = std::min(msg_len - cur_offset, shmat_data_max_size_);
157 if (sub_msg_len == 0) {
158 MS_LOG(ERROR) << "Invalid message len " << sub_msg_len;
159 return kMEFailed;
160 }
161 auto ret =
162 memcpy_s(shmat_data_addr_, shmat_data_max_size_, static_cast<const uint8_t *>(buffer) + cur_offset, sub_msg_len);
163 if (ret != EOK) {
164 MS_LOG(ERROR) << "memcpy_s failed, ret = " << ret;
165 return kMEFailed;
166 }
167 cur_offset += sub_msg_len;
168
169 send_msg_->msg_len = sub_msg_len;
170 send_msg_->read_finish_flag = 0;
171 send_msg_->read_ready_flag = 1;
172 MS_LOG(INFO) << "Send start " << cur_offset << ", msg len " << sub_msg_len << ", total len " << msg_len;
173 while (!send_msg_->read_finish_flag && !peer_stopped_) {
174 (void)nanosleep(&kOneMillisecond, nullptr); // 1ms
175 }
176 if (peer_stopped_) {
177 if (!send_msg_->read_finish_flag) {
178 return kMEFailed;
179 }
180 break;
181 }
182 MS_LOG(INFO) << "Send end " << cur_offset << ", msg len " << sub_msg_len << ", total len " << msg_len;
183 }
184 MS_LOG(INFO) << "End to send message to peer process, msg len " << msg_len;
185 return kSuccess;
186 }
187
ReceiveMsg(const CreateBufferCall & create_buffer_call) const188 Status MultiProcess::ReceiveMsg(const CreateBufferCall &create_buffer_call) const {
189 uint64_t cur_offset = 0;
190 uint8_t *msg_buffer = nullptr;
191 uint64_t msg_len = 0;
192 do {
193 MS_LOG(INFO) << "Receive start from " << cur_offset;
194 while (!receive_msg_->read_ready_flag && !peer_stopped_) {
195 (void)nanosleep(&kOneMillisecond, nullptr); // 1ms
196 }
197 if (peer_stopped_) {
198 return kMEFailed;
199 }
200 if (msg_buffer == nullptr) {
201 msg_len = receive_msg_->msg_total_len;
202 msg_buffer = create_buffer_call(msg_len);
203 }
204 MS_EXCEPTION_IF_NULL(msg_buffer);
205 size_t dest_max = std::min(shmat_data_max_size_, msg_len - cur_offset);
206 auto ret = memcpy_s(msg_buffer + cur_offset, dest_max, shmat_data_addr_, receive_msg_->msg_len);
207 if (ret != EOK) {
208 MS_LOG(INFO) << "memcpy_s failed, ret = " << ret;
209 return kMEFailed;
210 }
211 cur_offset += receive_msg_->msg_len;
212 receive_msg_->read_ready_flag = 0;
213 receive_msg_->read_finish_flag = 1;
214 MS_LOG(INFO) << "Receive end, current length " << cur_offset << ", total length " << msg_len << std::endl;
215 } while (msg_len > cur_offset);
216 return kSuccess;
217 }
218
HeartbeatThreadFunc(MultiProcess * multi_process)219 void MultiProcess::HeartbeatThreadFunc(MultiProcess *multi_process) { multi_process->HeartbeatThreadFuncInner(); }
220
HeartbeatThreadFuncInner()221 void MultiProcess::HeartbeatThreadFuncInner() {
222 constexpr uint64_t kOvertime = 1024;
223 uint64_t last_beat_cnt = 0;
224 uint64_t repeat_cnt = 0;
225 while (!stopped_) {
226 if (receive_msg_->stop) {
227 peer_stopped_ = true;
228 MS_LOG(WARNING) << "Peer stopped";
229 break;
230 }
231 uint64_t heartbeat_gap = receive_msg_->heartbeat - last_beat_cnt;
232 if (heartbeat_gap > 0 && heartbeat_gap < kOvertime) {
233 last_beat_cnt = receive_msg_->heartbeat;
234 repeat_cnt = 0;
235 } else {
236 repeat_cnt++;
237 if (repeat_cnt > 30) { // 30*100ms = 3s no reply
238 peer_stopped_ = true;
239 MS_LOG(WARNING) << "Peer stopped";
240 break;
241 }
242 }
243 send_msg_->heartbeat += 1;
244 (void)nanosleep(&kOneHundredMilliseconds, nullptr); // sleep 100 ms
245 }
246 }
247 } // namespace mindspore
248