/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"

#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/util/system_pool.h"
#include "minddata/dataset/util/task_manager.h"

namespace mindspore {
namespace dataset {
CacheMergeOp::~CacheMergeOp() = default;
void CacheMergeOp::Print(std::ostream &out, bool show_all) const {
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal 1-liner info for this op
    out << "\n";
  } else {
    // Call the super class for displaying any common detailed info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\n\n";
  }
}

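// Constructor. The merge op runs two groups of parallel workers plus a set of cleaner
// threads, all sharing the op's output connector.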
CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
                           std::shared_ptr<CacheClient> cache_client)
    : ParallelOp(numWorkers, opConnectorSize),
      num_cleaners_(numCleaners),
      cache_client_(std::move(cache_client)),
      cache_missing_rows_(true) {}

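// Entry point of the op. Sets up the io queue, then launches the cache-hit workers,
// the cache-miss workers, and the cleaner threads.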
Status CacheMergeOp::operator()() {
  // A queue of row ids that lets the cleaners send cache-miss rows to the cache server.
  // We don't want a small queue, as that would block the parallel op workers.
  // A row id is an 8-byte integer, so a bigger queue doesn't consume a lot of memory.
  static const int32_t queue_sz = 512;
  io_que_ = std::make_unique<Queue<row_id_type>>(queue_sz);
  RETURN_IF_NOT_OK(io_que_->Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(
    num_workers_, std::bind(&CacheMergeOp::WorkerEntry, this, std::placeholders::_1), Name() + "::WorkerEntry", id()));
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_,
                                        std::bind(&CacheMergeOp::CacheMissWorkerEntry, this, std::placeholders::_1),
                                        Name() + "::CacheMissWorkerEntry", id()));
  // Dedicated threads to move TensorRows from the pool to the cache server
  for (auto i = 0; i < num_cleaners_; ++i) {
    RETURN_IF_NOT_OK(
      tree_->AllTasks()->CreateAsyncTask("Cleaner", std::bind(&CacheMergeOp::Cleaner, this), nullptr, id()));
  }
  TaskManager::FindMe()->Post();
  return Status::OK();
}

// Each parallel worker will pop from the cache-hit stream. If there is a missing TensorRow, we will wait
// until it shows up in the pool.
Status CacheMergeOp::WorkerEntry(int32_t worker_id) {
  TaskManager::FindMe()->Post();
  TensorRow new_row;
  auto child_iterator = std::make_unique<ChildIterator>(this, worker_id, kCacheHitChildIdx);
  RETURN_IF_NOT_OK(child_iterator->FetchNextTensorRow(&new_row));
  while (!new_row.eof()) {
    if (new_row.eoe()) {
      RETURN_IF_NOT_OK(EoeReceived(worker_id));
      RETURN_IF_NOT_OK(child_iterator->FetchNextTensorRow(&new_row));
    } else {
      if (new_row.empty()) {
        auto row_id = new_row.getId();
        // Block until the row shows up in the pool.
        RETURN_IF_NOT_OK(cache_miss_.PopFront(row_id, &new_row));
      }
      RETURN_IF_NOT_OK(out_connector_->Add(std::move(new_row), worker_id));
      RETURN_IF_NOT_OK(child_iterator->FetchNextTensorRow(&new_row));
    }
  }
  RETURN_IF_NOT_OK(EofReceived(worker_id));
  return Status::OK();
}

Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
  TaskManager::FindMe()->Post();
  // We will simply pop TensorRows from the stream, insert them into the pool, and
  // wake up any worker that is waiting on a missing TensorRow.
  // If we see an eoe, ignore it. For eof, we exit.
  // Before we start, cache the schema at the server. Pick one of the workers to
  // do it. The schema should have been set up at prepare time.
  if (workerId == 0) {
    RETURN_IF_NOT_OK(cache_client_->CacheSchema(column_name_id_map()));
  }
  TensorRow new_row;
  auto child_iterator = std::make_unique<ChildIterator>(this, workerId, kCacheMissChildIdx);
  RETURN_IF_NOT_OK(child_iterator->FetchNextTensorRow(&new_row));
  while (!new_row.eof()) {
    if (new_row.eoe()) {
      // Ignore it.
      MS_LOG(DEBUG) << "Ignore eoe";
      // However, we need to flush any leftover rows from the async write buffer. Any
      // error we get here just stops caching; the pipeline will continue.
      Status rc = cache_client_->FlushAsyncWriteBuffer();
      if (rc.IsError()) {
        cache_missing_rows_ = false;
        if (rc == StatusCode::kMDOutOfMemory || rc == StatusCode::kMDNoSpace) {
          cache_client_->ServerRunningOutOfResources();
        } else {
          MS_LOG(INFO) << "Async row flushing not successful: " << rc.ToString();
        }
      }
    } else {
      row_id_type row_id = new_row.getId();
      if (row_id < 0) {
        std::string errMsg = "Expect a non-negative row id, but got: " + std::to_string(row_id);
        RETURN_STATUS_UNEXPECTED(errMsg);
      }
      if (cache_missing_rows_) {
        // Technically, the number of times this row shows up in the cache-miss stream
        // equals the number of P() calls. However, the cleaner wants it too, so we need
        // an extra copy.
        TensorRowCacheRequest *rq;
        RETURN_IF_NOT_OK(GetRq(row_id, &rq));
        if (rq->GetState() == TensorRowCacheRequest::State::kEmpty) {
          // We will send the request async. Most errors we simply
          // ignore and continue.
          Status rc = rq->AsyncSendCacheRequest(cache_client_, new_row);
          if (rc.IsOk()) {
            RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id));
          } else if (rc == StatusCode::kMDOutOfMemory || rc == StatusCode::kMDNoSpace) {
            cache_missing_rows_ = false;
            cache_client_->ServerRunningOutOfResources();
          }
        }
      }
      RETURN_IF_NOT_OK(cache_miss_.Add(row_id, std::move(new_row)));
    }
    RETURN_IF_NOT_OK(child_iterator->FetchNextTensorRow(&new_row));
  }
  return Status::OK();
}

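// Cleaner threads: pop row ids from the io queue and push the corresponding cached rows
// to the cache server. A negative row id is the signal to quit.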
Status CacheMergeOp::Cleaner() {
  TaskManager::FindMe()->Post();
  while (true) {
    row_id_type row_id;
    RETURN_IF_NOT_OK(io_que_->PopFront(&row_id));
    if (row_id < 0) {
      break;
    }
    // Locate the cache request
    TensorRowCacheRequest *rq;
    RETURN_IF_NOT_OK(GetRq(row_id, &rq));
    // If already flushed, move on to the next one.
    if (rq->GetState() == TensorRowCacheRequest::State::kClean) {
      continue;
    }
    Status rc = rq->CheckCacheResult();
    if (rc.IsError()) {
      // If interrupted, time to quit.
      if (rc == StatusCode::kMDInterrupted) {
        return Status::OK();
      } else if (rc == StatusCode::kMDOutOfMemory || rc == StatusCode::kMDNoSpace) {
        // The server is hitting some limit and we will turn off caching from now on.
        cache_missing_rows_ = false;
        cache_client_->ServerRunningOutOfResources();
      } else {
        MS_LOG(INFO) << "Cache row not successful: " << rc.ToString();
        // A bad rc should not bring down the pipeline. We simply continue and
        // change the state back to empty. We don't need a CAS from CLEAN back to EMPTY.
        rq->SetState(TensorRowCacheRequest::State::kEmpty);
      }
    }
  }
  return Status::OK();
}

// Run any common code from the super class first before adding our own specific logic
Status CacheMergeOp::PrepareOperator() {
  CHECK_FAIL_RETURN_UNEXPECTED(
    child_.size() == kNumChildren,
    "Incorrect number of children of CacheMergeOp, required num is 2, but got:" + std::to_string(child_.size()));
  RETURN_IF_NOT_OK(DatasetOp::PrepareOperator());
  // Get the computed checksum from all ops in the cache-miss branch
  uint32_t cache_crc = DatasetOp::GenerateCRC(child_[kCacheMissChildIdx]);
  // This is a mappable cache op, so the row ids already exist and the server does not
  // need to generate them.
  // Construct the cache
  const bool generate_ids = false;
  Status rc = cache_client_->CreateCache(cache_crc, generate_ids);
  if (rc.StatusCode() == StatusCode::kMDDuplicateKey) {
    // We are told the cache has been created already.
    MS_LOG(INFO) << "Cache created already";
    rc = Status::OK();
  }
  RETURN_IF_NOT_OK(rc);
  return Status::OK();
}

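// The merge op emits rows in the same shape as its cache-miss child, so the column map
// is simply inherited from that child.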
Status CacheMergeOp::ComputeColMap() {
  CHECK_FAIL_RETURN_UNEXPECTED(child_[kCacheMissChildIdx] != nullptr, "Invalid data, cache miss stream is empty.");
  if (column_name_id_map().empty()) {
    column_name_id_map_ = child_[kCacheMissChildIdx]->column_name_id_map();
  }
  CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map().empty(),
                               "Invalid data, column_name_id_map of CacheMergeOp is empty.");
  return Status::OK();
}

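// Base-class override for handling cases when an eoe is received.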
Status CacheMergeOp::EoeReceived(int32_t worker_id) {
  // Send the eoe up.
  MS_LOG(DEBUG) << "Cache merge sending eoe";
  return out_connector_->SendEOE(worker_id);
}

// Base-class override for handling cases when an eof is received.
Status CacheMergeOp::EofReceived(int32_t worker_id) {
  // Send the eof up.
  MS_LOG(DEBUG) << "Cache merge sending eof";
  return out_connector_->SendEOF(worker_id);
}

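// Look up the in-flight cache request for the given row id, creating one on first use.
// A lock is needed because the cache-miss workers and the cleaners can race on the
// same row id.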
Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowCacheRequest **out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  std::unique_lock<std::mutex> lock(mux_);
  auto it = io_request_.find(row_id);
  if (it != io_request_.end()) {
    *out = it->second.GetMutablePointer();
  } else {
    // We will create a new one.
    auto alloc = SystemPool::GetAllocator<TensorRowCacheRequest>();
    auto r = io_request_.emplace(row_id, MemGuard<TensorRowCacheRequest, Allocator<TensorRowCacheRequest>>(alloc));
    if (r.second) {
      auto &mem = r.first->second;
      RETURN_IF_NOT_OK(mem.allocate(1));
      *out = mem.GetMutablePointer();
    } else {
      RETURN_STATUS_UNEXPECTED("Invalid data, map insert failed.");
    }
  }
  return Status::OK();
}

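// State machine of a TensorRowCacheRequest, as implemented below:
//   kEmpty -> kDirty   when a cache-miss worker sends the row to the server
//   kDirty -> kClean   when the write completes (or when the cleaner picks it up)
//   kDirty -> kEmpty   if the send fails, so a later pass may retry
// Transitions out of kEmpty and kDirty are guarded by compare-and-swap on st_.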
Status CacheMergeOp::TensorRowCacheRequest::AsyncSendCacheRequest(const std::shared_ptr<CacheClient> &cc,
                                                                  const TensorRow &row) {
  auto expected = State::kEmpty;
  if (st_.compare_exchange_strong(expected, State::kDirty)) {
    // We will do a deep copy but write directly into the CacheRequest protobuf or shared memory
    Status rc = cc->AsyncWriteRow(row);
    if (rc.StatusCode() == StatusCode::kMDNotImplementedYet) {
      cleaner_copy_ = std::make_shared<CacheRowRequest>(cc.get());
      rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row);
      if (rc.IsOk()) {
        // Send the request async. The cleaner will check the return code.
        rc = cc->PushRequest(cleaner_copy_);
      }
    } else if (rc.IsOk()) {
      // Set the state to clean even though it still sits in the cache client async buffer.
      // The cleaner will then ignore it once the state is clean.
      st_ = State::kClean;
    }
    if (rc.IsError()) {
      // Clean up the shared pointer and reset the state back to empty
      cleaner_copy_.reset();
      st_ = State::kEmpty;
    }
    return rc;
  }
  return Status::OK();
}

Status CacheMergeOp::TensorRowCacheRequest::CheckCacheResult() {
  auto expected = State::kDirty;
  if (st_.compare_exchange_strong(expected, State::kClean)) {
    // Success or not, we will release the memory.
    // We simply move it out of the structure and let it go out of scope.
    auto cache_request = std::move(cleaner_copy_);
    RETURN_IF_NOT_OK(cache_request->Wait());
    return Status::OK();
  }
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore