/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/datasetops/cache_base_op.h"

#include "minddata/dataset/engine/execution_tree.h"

namespace mindspore {
namespace dataset {
// A print method typically used for debugging
void CacheBase::Print(std::ostream &out, bool show_all) const {
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    ParallelOp::Print(out, show_all);
    out << "\n";
  } else {
    // Call the super class for displaying any common detailed info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nCache client:\n" << *cache_client_ << "\n\n";
  }
}
// Overrides base class reset method. When an operator does a reset, it cleans up any state
// info from its previous execution and then initializes itself so that it can be executed
// again.
Status CacheBase::Reset() {
  if (sampler_ != nullptr) {
    RETURN_IF_NOT_OK(sampler_->ResetSampler());
  }
  // Wake up the workers to get them going again in a new epoch
  MS_LOG(DEBUG) << Name() << " performing a self-reset.";
  return Status::OK();
}
CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<CacheClient> cache_client,
                     std::shared_ptr<SamplerRT> sampler)
    : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
      row_cnt_(0),
      num_cache_miss_(0),
      cache_client_(std::move(cache_client)),
      prefetch_size_(1),
      num_prefetchers_(num_workers_) {
  // Adjust the prefetch size based on the number of workers.
  auto prefetch_sz_per_thread = cache_client_->GetPrefetchSize() / num_prefetchers_;
  if (prefetch_size_ < prefetch_sz_per_thread) {
    prefetch_size_ = prefetch_sz_per_thread;
    MS_LOG(DEBUG) << "Per worker prefetch size : " << prefetch_size_;
  }
  io_block_queues_.Init(num_workers, op_connector_size);
  prefetch_queues_.Init(num_prefetchers_, op_connector_size);
  // We can cause deadlock if this internal Connector size is too small.
  keys_miss_ = std::make_unique<Connector<std::vector<row_id_type>>>(num_prefetchers_, 1, connector_capacity_);
}
// Common function to fetch samples from the sampler and send them through the io_block_queues to
// the parallel workers
Status CacheBase::FetchSamplesToWorkers() {
  int64_t buf_cnt = 0;
  int64_t wait_cnt = 0;
  int64_t prefetch_cnt = 0;
  // Kick off several threads which will prefetch prefetch_size_ rows in advance.
  RETURN_UNEXPECTED_IF_NULL(tree_);
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_prefetchers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1), Name()));
  auto send_to_que = [](QueueList<std::unique_ptr<IOBlock>> &qList, int32_t worker_id,
                        std::vector<row_id_type> &keys) -> Status {
    auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
    RETURN_IF_NOT_OK(qList[worker_id]->Add(std::move(blk)));
    return Status::OK();
  };
  // Instead of sending sampler ids to the WorkerEntry, we send them to the Prefetcher, which will redirect them
  // to the WorkerEntry.
  do {
    if (AllowCacheMiss() && wait_cnt > 0 && wait_cnt % GetOpNumRepeatsPerEpoch() == 0) {
      MS_LOG(INFO) << "Epoch: " << op_current_epochs_ << " Cache Miss : " << num_cache_miss_
                   << " Total number of rows : " << row_cnt_;
    }
    num_cache_miss_ = 0;
    row_cnt_ = 0;
    ++wait_cnt;
    std::vector<row_id_type> keys;
    keys.reserve(1);
    std::vector<row_id_type> prefetch_keys;
    prefetch_keys.reserve(prefetch_size_);
    TensorRow sample_row;
    RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row));
    while (!sample_row.eoe()) {
      std::shared_ptr<Tensor> sample_ids = sample_row[0];
      for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
        ++row_cnt_;
        prefetch_keys.push_back(*itr);
        // Batch enough rows for performance reasons.
        if (row_cnt_ % prefetch_size_ == 0) {
          RETURN_IF_NOT_OK(send_to_que(prefetch_queues_, prefetch_cnt++ % num_prefetchers_, prefetch_keys));
          // Now we tell the WorkerEntry to wait for them to come back.
          for (auto row_id : prefetch_keys) {
            keys.push_back(row_id);
            RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys));
            keys.clear();
          }
          prefetch_keys.clear();
        }
      }
      RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row));
    }
    // Deal with any partial keys left.
    if (!prefetch_keys.empty()) {
      RETURN_IF_NOT_OK(send_to_que(prefetch_queues_, prefetch_cnt++ % num_prefetchers_, prefetch_keys));
      for (auto row_id : prefetch_keys) {
        keys.push_back(row_id);
        RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys));
        keys.clear();
      }
    }
    if (!keys.empty()) {
      RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys));
    }
    // Send the eoe.
    RETURN_IF_NOT_OK(
      io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
    RETURN_IF_NOT_OK(prefetch_queues_[(prefetch_cnt++) % num_prefetchers_]->Add(
      std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
    // If repeating but not yet at the last repeat, wait for reset.
    if (!IsLastIteration()) {
      MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << wait_cnt << " Buffer sent " << buf_cnt;
    } else {
      // We can break out from the loop.
      break;
    }
    if (epoch_sync_flag_) {
      // If epoch_sync_flag_ is set, then the master thread sleeps until all the worker threads have finished
      // their job for the current epoch.
      RETURN_IF_NOT_OK(WaitForWorkers());
    }
    // If not the last repeat, self-reset and go to loop again.
    if (!IsLastIteration()) RETURN_IF_NOT_OK(Reset());
    UpdateRepeatAndEpochCounter();
  } while (true);
  // Flow the eof before exit.
  RETURN_IF_NOT_OK(
    io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
  // Shutdown threads
  for (int32_t i = 0; i < num_workers_; i++) {
    RETURN_IF_NOT_OK(
      io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
  }
  // Dump the last epoch result (approximately) without waiting for the worker threads to come back.
  if (AllowCacheMiss()) {
    MS_LOG(INFO) << "Epoch: " << wait_cnt / GetOpNumRepeatsPerEpoch() << " Cache Miss : " << num_cache_miss_
                 << " Total number of rows : " << row_cnt_;
  }
  return Status::OK();
}

Status CacheBase::FetchFromCache(int32_t worker_id) {
  std::unique_ptr<IOBlock> blk;
  do {
    RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk));
    if (blk->wait()) {
      // A sync io_block is a signal that the master thread wants us to pause and sync with the other workers.
      // The last worker to reach this sync point should reset the counter and wake up the master thread.
      if (++num_workers_paused_ == num_workers_) {
        wait_for_workers_post_.Set();
      }
    } else if (blk->eof()) {
      RETURN_IF_NOT_OK(out_connector_->SendEOF(worker_id));
    } else if (blk->eoe()) {
      RETURN_IF_NOT_OK(out_connector_->SendEOE(worker_id));
    } else {
      std::vector<int64_t> keys;
      RETURN_IF_NOT_OK(blk->GetKeys(&keys));
      if (keys.empty()) {
        // An empty key is a quit signal for workers.
        break;
      }
      for (auto row_id : keys) {
        TensorRow row;
        // Block until the row shows up in the pool.
        RETURN_IF_NOT_OK(GetPrefetchRow(row_id, &row));
        if (row.empty()) {
          if (AllowCacheMiss()) {
            ++num_cache_miss_;
          } else {
            std::string errMsg = "Row id " + std::to_string(row_id) + " not found.";
            RETURN_STATUS_UNEXPECTED(errMsg);
          }
        }
        RETURN_IF_NOT_OK(out_connector_->Add(std::move(row), worker_id));
      }
    }
  } while (true);
  return Status::OK();
}

Status CacheBase::RegisterResources() {
  RETURN_UNEXPECTED_IF_NULL(tree_);
  RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks()));
  return Status::OK();
}

CacheBase::~CacheBase() = default;

Status CacheBase::UpdateColumnMapFromCache() {
  Status rc;
  // Get the schema from the server. It may not be there yet, so tolerate the error.
  if (column_name_id_map_.empty()) {
    rc = cache_client_->FetchSchema(&column_name_id_map_);
    if (rc == Status(StatusCode::kMDFileNotExist)) {
      MS_LOG(DEBUG) << "Schema not in the server yet.";
      rc = Status::OK();
    }
  }
  return rc;
}

Status CacheBase::GetPrefetchRow(row_id_type row_id, TensorRow *out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  CHECK_FAIL_RETURN_UNEXPECTED(row_id >= 0, "Expect a non-negative row id, but got: " + std::to_string(row_id));
  RETURN_IF_NOT_OK(prefetch_.PopFront(row_id, out));
  return Status::OK();
}

Status CacheBase::PrefetchRows(const std::vector<row_id_type> &keys, std::vector<row_id_type> *cache_miss) {
  RETURN_UNEXPECTED_IF_NULL(cache_miss);
  std::vector<row_id_type> prefetch_keys;
  prefetch_keys.reserve(keys.size());

  // Filter out all those keys that we are unlikely to find at the server.
  for (auto row_id : keys) {
    if (cache_client_->KeyIsCacheMiss(row_id)) {
      // Just put an empty row in the cache.
      TensorRow row;
      row.setId(row_id);
      RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
      cache_miss->push_back(row_id);
    } else {
      prefetch_keys.push_back(row_id);
    }
  }
  // Early exit if nothing to fetch
  if (prefetch_keys.empty()) {
    return Status::OK();
  }
  // Get the rows from the server
  TensorTable ttbl;
  RETURN_IF_NOT_OK(cache_client_->GetRows(prefetch_keys, &ttbl));
  auto row_it = ttbl.begin();
  for (auto row_id : prefetch_keys) {
    auto &row = *row_it;
    if (row.empty()) {
      cache_miss->push_back(row_id);
    }
    // Put the prefetched row into the pool and wake up any WorkerEntry waiting for the row.
    RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
    ++row_it;
  }
  return Status::OK();
}

Status CacheBase::Prefetcher(int32_t worker_id) {
  TaskManager::FindMe()->Post();
  std::vector<row_id_type> prefetch_keys;
  prefetch_keys.reserve(prefetch_size_);
  std::vector<row_id_type> cache_miss;
  cache_miss.reserve(prefetch_size_);
  do {
    prefetch_keys.clear();
    cache_miss.clear();
    std::unique_ptr<IOBlock> blk;
    RETURN_IF_NOT_OK(prefetch_queues_[worker_id]->PopFront(&blk));
    CHECK_FAIL_RETURN_UNEXPECTED(!blk->eof(), "Expect eoe or a regular io block.");
    if (!blk->eoe()) {
      RETURN_IF_NOT_OK(blk->GetKeys(&prefetch_keys));
      Status rc;
      const int32_t max_retries = 5;
      int32_t retry_count = 0;
      do {
        rc = PrefetchRows(prefetch_keys, &cache_miss);
        if (rc == StatusCode::kMDNetWorkError && retry_count < max_retries) {
          // If we get some network error, we will attempt some retries.
          retry_count++;
        } else if (rc.IsError() && rc.StatusCode() != StatusCode::kMDInterrupted) {
          MS_LOG(WARNING) << rc.ToString();
          return rc;
        }
      } while (rc == StatusCode::kMDNetWorkError);
      // In case any thread is waiting for the rows to come back and blocked on a semaphore,
      // we will put an empty row in the local cache.
      if (rc.IsError() && AllowCacheMiss()) {
        for (auto row_id : prefetch_keys) {
          TensorRow row;
          row.setId(row_id);
          RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
          cache_miss.push_back(row_id);
        }
      }
    } else {
      if (AllowCacheMiss()) {
        // This code path is for CacheLookupOp acting as a sampler. If we get an eoe from
        // the sampler, send an eoe to the physical leaf op as well.
        cache_miss.push_back(eoe_row_id);
      }
    }
    if (AllowCacheMiss()) {
      // Because of the way the connector works, we push unconditionally even though cache_miss can be empty.
      RETURN_IF_NOT_OK(keys_miss_->Push(worker_id, cache_miss));
    }
  } while (true);
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore