/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/cache/cache_grpc_client.h"
#include <chrono>
namespace mindspore {
namespace dataset {
CacheClientGreeter::~CacheClientGreeter() { (void)ServiceStop(); }

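// Constructor: create the gRPC channel (or a local unix-socket channel when built with
// CACHE_LOCAL_CLIENT and connecting to 127.0.0.1) and the stub used for async calls.
// A minimal usage sketch, illustrative only; the port and connection count are placeholder
// values, and ServiceStart() is assumed to come from the Service base class (as suggested
// by DoServiceStart()/ServiceStop() in this file):
//   CacheClientGreeter comm("127.0.0.1", 50052, /*num_connections=*/2);
//   RETURN_IF_NOT_OK(comm.ServiceStart());
//   RETURN_IF_NOT_OK(comm.HandleRequest(my_request));  // my_request: std::shared_ptr<BaseRequest>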
CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_connections)
    : num_connections_(num_connections), request_cnt_(0), hostname_(hostname), port_(port) {
  grpc::ChannelArguments args;
  // Bump the maximum receive message size up to unlimited. The default limit of
  // 4 MB is not big enough.
  args.SetMaxReceiveMessageSize(-1);
  MS_LOG(INFO) << "Hostname: " << hostname_ << ", port: " << std::to_string(port_);
#ifdef CACHE_LOCAL_CLIENT
  // Prefer connecting locally through the unix socket.
  // Ideally we would resolve the hostname to an ip address rather than do a string compare.
  if (hostname == "127.0.0.1") {
    std::string target = "unix://" + PortToUnixSocketPath(port);
    channel_ = grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args);
  } else {
#endif
    std::string target = hostname + ":" + std::to_string(port);
    channel_ = grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args);
#ifdef CACHE_LOCAL_CLIENT
  }
#endif
  stub_ = CacheServerGreeter::NewStub(channel_);
}

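// Attach to the cache server's shared memory segment when built with CACHE_LOCAL_CLIENT.
// On success *local_bypass is set to true so callers know they can bypass the gRPC channel
// for data transfer; otherwise it stays false and everything goes through gRPC.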
Status CacheClientGreeter::AttachToSharedMemory(bool *local_bypass) {
  *local_bypass = false;
#ifdef CACHE_LOCAL_CLIENT
  SharedMemory::shm_key_t shm_key;
  RETURN_IF_NOT_OK(PortToFtok(port_, &shm_key));
  // Attach to the shared memory
  mem_.SetPublicKey(shm_key);
  RETURN_IF_NOT_OK(mem_.Attach());
  *local_bypass = true;
#endif
  return Status::OK();
}

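// Start the internal task group and spawn one async-reply worker per connection.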
Status CacheClientGreeter::DoServiceStart() {
  RETURN_IF_NOT_OK(vg_.ServiceStart());
  RETURN_IF_NOT_OK(DispatchWorkers(num_connections_));
  return Status::OK();
}

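// Shutdown sequence: stop accepting new completion-queue events, interrupt and join the
// worker tasks, then drain whatever tags are still sitting in the completion queue so the
// matching entries can be removed from req_.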
Status CacheClientGreeter::DoServiceStop() {
  // Shut down the completion queue. We don't accept any new requests.
  cq_.Shutdown();
  // Shut down the TaskGroup.
  vg_.interrupt_all();
  RETURN_IF_NOT_OK(vg_.join_all(Task::WaitFlag::kNonBlocking));
  // Drain the queue. We know how many requests we sent out.
  while (!req_.empty()) {
    bool success;
    void *tag;
    while (cq_.Next(&tag, &success)) {
      auto r = reinterpret_cast<CacheClientRequestTag *>(tag);
      (void)req_.erase(r->seq_no_);
    }
  }
  return Status::OK();
}

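// Send one request to the cache server asynchronously. Each request gets a monotonically
// increasing sequence number and is wrapped in a CacheClientRequestTag that is kept in req_
// (protected by mux_) until a worker sees its completion on the queue.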
Status CacheClientGreeter::HandleRequest(std::shared_ptr<BaseRequest> rq) {
  // Do any extra preparation the request needs before we send it.
  RETURN_IF_NOT_OK(rq->Prepare());
  auto seq_no = request_cnt_.fetch_add(1);
  auto tag = std::make_unique<CacheClientRequestTag>(std::move(rq), seq_no);
  // Set the rpc deadline to kRequestTimeoutDeadlineInSec seconds from now.
  auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(kRequestTimeoutDeadlineInSec);
  tag->ctx_.set_deadline(deadline);
  tag->rpc_ = stub_->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, &cq_);
  tag->rpc_->StartCall();
  auto ccReqTag = tag.get();
  // Insert it into the map so the worker threads can find it when the reply comes back.
  {
    std::unique_lock<std::mutex> lck(mux_);
    auto r = req_.emplace(seq_no, std::move(tag));
    if (!r.second) {
      return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__);
    }
  }
  // Last step is to tag the request so the completion queue hands the tag back to us.
  ccReqTag->rpc_->Finish(&ccReqTag->base_rq_->reply_, &ccReqTag->rc_, ccReqTag);
  return Status::OK();
}

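// Worker loop: poll the completion queue with a deadline, translate gRPC errors into a
// Status stored in the reply, wake up the thread waiting on the request, and then drop the
// request tag from req_. The loop exits when the queue reports shutdown.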
Status CacheClientGreeter::WorkerEntry() {
  TaskManager::FindMe()->Post();
  do {
    bool success;
    void *tag;
    auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(kWaitForNewEventDeadlineInSec);
    // Wait with a deadline so we can periodically check for interrupt and exit early if needed.
    auto r = cq_.AsyncNext(&tag, &success, deadline);
    if (r == grpc::CompletionQueue::NextStatus::GOT_EVENT) {
      auto rq = reinterpret_cast<CacheClientRequestTag *>(tag);
      if (success) {
        auto &rc = rq->rc_;
        if (!rc.ok()) {
          auto error_code = rq->rc_.error_code();
          std::string err_msg;
          if (error_code == grpc::StatusCode::UNAVAILABLE) {
            err_msg = "Cache server with port " + std::to_string(port_) +
                      " is unreachable. Make sure the server is running. GRPC Code " + std::to_string(error_code);
          } else {
            err_msg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code);
          }
          Status remote_rc = Status(StatusCode::kMDNetWorkError, __LINE__, __FILE__, err_msg);
          Status2CacheReply(remote_rc, &rq->base_rq_->reply_);
        }
        // Notify the waiting thread.
        rq->Notify();
      }
      {
        // We can now free the memory of the request tag.
        std::unique_lock<std::mutex> lck(mux_);
        auto seq_no = rq->seq_no_;
        auto n = req_.erase(seq_no);
        CHECK_FAIL_RETURN_UNEXPECTED(n == 1, "Sequence " + std::to_string(seq_no) + " not found");
      }
    } else if (r == grpc::CompletionQueue::NextStatus::TIMEOUT) {
      // If we are interrupted, exit. Otherwise wait again.
      RETURN_IF_INTERRUPTED();
    } else {
      // The completion queue has been shut down and drained.
      break;
    }
  } while (true);
  return Status::OK();
}

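// Spawn num_workers asynchronous tasks, all running WorkerEntry(), on the task group.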
Status CacheClientGreeter::DispatchWorkers(int32_t num_workers) {
  auto f = std::bind(&CacheClientGreeter::WorkerEntry, this);
  for (auto i = 0; i < num_workers; ++i) {
    RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Async reply", f));
  }
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore