/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/cache/cache_grpc_client.h"
#include <chrono>
namespace mindspore {
namespace dataset {
CacheClientGreeter::~CacheClientGreeter() { (void)ServiceStop(); }

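// Construct the client: set up a gRPC channel to the cache server and create the stub used to
// issue asynchronous requests.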
CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_connections)
    : num_connections_(num_connections), request_cnt_(0), hostname_(hostname), port_(port) {
  grpc::ChannelArguments args;
  // We need to bump up the message size to unlimited. The default receiving
  // message limit is 4MB which is not big enough.
  args.SetMaxReceiveMessageSize(-1);
  MS_LOG(INFO) << "Hostname: " << hostname_ << ", port: " << std::to_string(port_);
#ifdef CACHE_LOCAL_CLIENT
  // Prefer connecting locally over the unix socket.
  // Ideally we should resolve the hostname to an IP address rather than doing a string compare.
  if (hostname == "127.0.0.1") {
    std::string target = "unix://" + PortToUnixSocketPath(port);
    channel_ = grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args);
  } else {
#endif
    std::string target = hostname + ":" + std::to_string(port);
    channel_ = grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args);
#ifdef CACHE_LOCAL_CLIENT
  }
#endif
  stub_ = CacheServerGreeter::NewStub(channel_);
}

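// Attach to the server's shared memory segment when built with CACHE_LOCAL_CLIENT. The shm key is
// derived from the port number; on success *local_bypass is set to true.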
Status CacheClientGreeter::AttachToSharedMemory(bool *local_bypass) {
  *local_bypass = false;
#ifdef CACHE_LOCAL_CLIENT
  SharedMemory::shm_key_t shm_key;
  RETURN_IF_NOT_OK(PortToFtok(port_, &shm_key));
  // Attach to the shared memory
  mem_.SetPublicKey(shm_key);
  RETURN_IF_NOT_OK(mem_.Attach());
  *local_bypass = true;
#endif
  return Status::OK();
}

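// Start the service: bring up the TaskGroup and spawn one worker per connection to poll the
// completion queue for replies.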
Status CacheClientGreeter::DoServiceStart() {
  RETURN_IF_NOT_OK(vg_.ServiceStart());
  RETURN_IF_NOT_OK(DispatchWorkers(num_connections_));
  return Status::OK();
}

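// Stop the service: shut down the completion queue, interrupt and join all workers, then drain any
// requests that are still outstanding.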
Status CacheClientGreeter::DoServiceStop() {
  // Shut down the completion queue. We won't accept any more incoming requests.
  cq_.Shutdown();
  // Shut down the TaskGroup.
  vg_.interrupt_all();
  RETURN_IF_NOT_OK(vg_.join_all(Task::WaitFlag::kNonBlocking));
  // Drain the queue. We know how many requests we sent out.
  while (!req_.empty()) {
    bool success;
    void *tag;
    while (cq_.Next(&tag, &success)) {
      auto r = reinterpret_cast<CacheClientRequestTag *>(tag);
      (void)req_.erase(r->seq_no_);
    }
  }
  return Status::OK();
}

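// Send a request to the cache server asynchronously. The request is wrapped in a tag, recorded in
// the in-flight map under its sequence number, and handed to gRPC; a worker thread picks up the
// reply from the completion queue and notifies the waiter.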
Status CacheClientGreeter::HandleRequest(std::shared_ptr<BaseRequest> rq) {
  // Do any extra preparation before we send the request out.
  RETURN_IF_NOT_OK(rq->Prepare());
  auto seq_no = request_cnt_.fetch_add(1);
  auto tag = std::make_unique<CacheClientRequestTag>(std::move(rq), seq_no);
  // One minute timeout
  auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(kRequestTimeoutDeadlineInSec);
  tag->ctx_.set_deadline(deadline);
  tag->rpc_ = stub_->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, &cq_);
  tag->rpc_->StartCall();
  auto ccReqTag = tag.get();
  // Insert the tag into the in-flight map.
  {
    std::unique_lock<std::mutex> lck(mux_);
    auto r = req_.emplace(seq_no, std::move(tag));
    if (!r.second) {
      return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__);
    }
  }
  // Last step is to tag the request so the completion queue can hand it back to us.
  ccReqTag->rpc_->Finish(&ccReqTag->base_rq_->reply_, &ccReqTag->rc_, ccReqTag);
  return Status::OK();
}

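// Worker loop: poll the completion queue for finished RPCs, convert any gRPC error into a Status
// in the reply, wake up the caller, and remove the tag from the in-flight map.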
Status CacheClientGreeter::WorkerEntry() {
  TaskManager::FindMe()->Post();
  do {
    bool success;
    void *tag;
    auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(kWaitForNewEventDeadlineInSec);
    // Set a one second timeout so we can check for interrupt and do an early exit if needed.
    auto r = cq_.AsyncNext(&tag, &success, deadline);
    if (r == grpc::CompletionQueue::NextStatus::GOT_EVENT) {
      auto rq = reinterpret_cast<CacheClientRequestTag *>(tag);
      if (success) {
        auto &rc = rq->rc_;
        if (!rc.ok()) {
          auto error_code = rq->rc_.error_code();
          std::string err_msg;
          if (error_code == grpc::StatusCode::UNAVAILABLE) {
            err_msg = "Cache server with port " + std::to_string(port_) +
                      " is unreachable. Make sure the server is running. GRPC Code " + std::to_string(error_code);
          } else {
            err_msg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code);
          }
          Status remote_rc = Status(StatusCode::kMDNetWorkError, __LINE__, __FILE__, err_msg);
          Status2CacheReply(remote_rc, &rq->base_rq_->reply_);
        }
        // Notify the waiting thread.
        rq->Notify();
      }
      {
        // We can now free the memory; erasing the tag from the map releases it.
        std::unique_lock<std::mutex> lck(mux_);
        auto seq_no = rq->seq_no_;
        auto n = req_.erase(seq_no);
        CHECK_FAIL_RETURN_UNEXPECTED(n == 1, "Sequence " + std::to_string(seq_no) + " not found");
      }
    } else if (r == grpc::CompletionQueue::NextStatus::TIMEOUT) {
      // If we are interrupted, exit. Otherwise wait again.
      RETURN_IF_INTERRUPTED();
    } else {
      // Queue is drained.
      break;
    }
  } while (true);
  return Status::OK();
}

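// Spawn num_workers asynchronous tasks, each running WorkerEntry to service the completion queue.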
Status CacheClientGreeter::DispatchWorkers(int32_t num_workers) {
  auto f = std::bind(&CacheClientGreeter::WorkerEntry, this);
  for (auto i = 0; i < num_workers; ++i) {
    RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Async reply", f));
  }
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore