/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/datasetops/cache_base_op.h"

#include "minddata/dataset/engine/execution_tree.h"

namespace mindspore {
namespace dataset {
// A print method typically used for debugging
void CacheBase::Print(std::ostream &out, bool show_all) const {
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    ParallelOp::Print(out, show_all);
    out << "\n";
  } else {
    // Call the super class for displaying any common detailed info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nCache client:\n" << *cache_client_ << "\n\n";
  }
}
// Overrides base class reset method. When an operator does a reset, it cleans up any state
// info from its previous execution and then initializes itself so that it can be executed
// again.
Status CacheBase::Reset() {
  if (sampler_ != nullptr) {
    RETURN_IF_NOT_OK(sampler_->ResetSampler());
  }
  // Wake up the workers to get them going again in a new epoch
  MS_LOG(DEBUG) << Name() << " performing a self-reset.";
  return Status::OK();
}
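// Constructor. The number of prefetcher threads is set to match the number of parallel workers.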
CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<CacheClient> cache_client,
                     std::shared_ptr<SamplerRT> sampler)
    : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
      row_cnt_(0),
      num_cache_miss_(0),
      cache_client_(std::move(cache_client)),
      prefetch_size_(1),
      num_prefetchers_(num_workers_) {
  // Adjust the prefetch size based on the number of workers.
  auto prefetch_sz_per_thread = cache_client_->GetPrefetchSize() / num_prefetchers_;
  if (prefetch_size_ < prefetch_sz_per_thread) {
    prefetch_size_ = prefetch_sz_per_thread;
    MS_LOG(DEBUG) << "Per worker prefetch size : " << prefetch_size_;
  }
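  // Set up one input queue per parallel worker and one per prefetcher thread.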
  worker_in_queues_.Init(num_workers, op_connector_size);
  prefetch_queues_.Init(num_prefetchers_, op_connector_size);
  // We can cause deadlock if this internal Connector size is too small.
  keys_miss_ = std::make_unique<Connector<std::vector<row_id_type>>>(num_prefetchers_, 1, connector_capacity_);
}
// Common function to fetch samples from the sampler and send them using the io_block_queues to
// the parallel workers
Status CacheBase::FetchSamplesToWorkers() {
  int64_t buf_cnt = 0;
  int64_t wait_cnt = 0;
  int64_t prefetch_cnt = 0;
  // Kick off several threads which will prefetch cache_prefetch_size_ rows in advance.
  RETURN_UNEXPECTED_IF_NULL(tree_);
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_prefetchers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1), Name()));
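  // Helper: wrap the given keys in an IOBlock and push it onto the selected worker's queue.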
  auto send_to_que = [](QueueList<std::unique_ptr<IOBlock>> &qList, int32_t worker_id,
                        std::vector<row_id_type> &keys) -> Status {
    auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kFlagNone));
    RETURN_IF_NOT_OK(qList[worker_id]->Add(std::move(blk)));
    return Status::OK();
  };
  // Instead of sending sample ids to the WorkerEntry, we send them to the Prefetcher, which will redirect them
  // to the WorkerEntry.
  do {
    if (AllowCacheMiss() && wait_cnt > 0 && wait_cnt % GetOpNumRepeatsPerEpoch() == 0) {
      MS_LOG(INFO) << "Epoch: " << op_current_epochs_ << " Cache Miss : " << num_cache_miss_
                   << " Total number of rows : " << row_cnt_;
    }
    num_cache_miss_ = 0;
    row_cnt_ = 0;
    ++wait_cnt;
    std::vector<row_id_type> keys;
    keys.reserve(1);
    std::vector<row_id_type> prefetch_keys;
    prefetch_keys.reserve(prefetch_size_);
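    // Walk the sample ids from the sampler: batch them into groups of prefetch_size_ for the
    // prefetchers, while handing out each individual id to the workers in round-robin order.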
    TensorRow sample_row;
    RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row));
    while (!sample_row.eoe()) {
      std::shared_ptr<Tensor> sample_ids = sample_row[0];
      for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
        ++row_cnt_;
        prefetch_keys.push_back(*itr);
        // Batch enough rows together for performance reasons.
        if (row_cnt_ % prefetch_size_ == 0) {
          RETURN_IF_NOT_OK(send_to_que(prefetch_queues_, prefetch_cnt++ % num_prefetchers_, prefetch_keys));
          // Now we tell the WorkerEntry to wait for them to come back.
          for (auto row_id : prefetch_keys) {
            keys.push_back(row_id);
            RETURN_IF_NOT_OK(send_to_que(worker_in_queues_, static_cast<int32_t>(buf_cnt++ % num_workers_), keys));
            keys.clear();
          }
          prefetch_keys.clear();
        }
      }
      RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row));
    }
    // Deal with any partial keys left.
    if (!prefetch_keys.empty()) {
      RETURN_IF_NOT_OK(send_to_que(prefetch_queues_, prefetch_cnt++ % num_prefetchers_, prefetch_keys));
      for (auto row_id : prefetch_keys) {
        keys.push_back(row_id);
        RETURN_IF_NOT_OK(send_to_que(worker_in_queues_, static_cast<int32_t>(buf_cnt++ % num_workers_), keys));
        keys.clear();
      }
    }
    if (!keys.empty()) {
      RETURN_IF_NOT_OK(send_to_que(worker_in_queues_, static_cast<int32_t>(buf_cnt++ % num_workers_), keys));
    }
    // send the eoe
    RETURN_IF_NOT_OK(worker_in_queues_[static_cast<const int>((buf_cnt++) % num_workers_)]->Add(
      std::make_unique<IOBlock>(IOBlock::kFlagEOE)));
    RETURN_IF_NOT_OK(
      prefetch_queues_[(prefetch_cnt++) % num_prefetchers_]->Add(std::make_unique<IOBlock>(IOBlock::kFlagEOE)));
    // If repeating but not the last repeat, wait for reset.
    if (!IsLastIteration()) {
      MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << wait_cnt << " Buffer sent " << buf_cnt;
    } else {
      // We can break out from the loop.
      break;
    }
    if (epoch_sync_flag_) {
      // If epoch_sync_flag_ is set, then the master thread sleeps until all the worker threads have finished
      // their job for the current epoch.
      RETURN_IF_NOT_OK(WaitForWorkers());
    }
    // If not the last repeat, self-reset and go through the loop again.
    if (!IsLastIteration()) {
      RETURN_IF_NOT_OK(Reset());
    }
    UpdateRepeatAndEpochCounter();
  } while (true);
  // Flow the eof before exit
  RETURN_IF_NOT_OK(worker_in_queues_[static_cast<const int>((buf_cnt++) % num_workers_)]->Add(
    std::make_unique<IOBlock>(IOBlock::kFlagEOF)));
  // Shutdown threads
  for (int32_t i = 0; i < num_workers_; i++) {
    RETURN_IF_NOT_OK(worker_in_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kFlagNone)));
  }
  // Dump the last epoch result (approximately) without waiting for the worker threads to come back.
  if (AllowCacheMiss()) {
    MS_LOG(INFO) << "Epoch: " << wait_cnt / GetOpNumRepeatsPerEpoch() << " Cache Miss : " << num_cache_miss_
                 << " Total number of rows : " << row_cnt_;
  }
  return Status::OK();
}

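// Worker entry point: pop IOBlocks from this worker's input queue and, for each row id, pull the
// corresponding row out of the prefetch pool and push it to the output queue. Empty keys signal shutdown.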
Status CacheBase::FetchFromCache(int32_t worker_id) {
  std::unique_ptr<IOBlock> blk;
  do {
    RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->PopFront(&blk));
    if (blk->wait()) {
      // A sync io_block is a signal that the master thread wants us to pause and sync with other workers.
      // The last worker to reach this sync point should reset the counter and wake up the master thread.
      if (++num_workers_paused_ == num_workers_) {
        wait_for_workers_post_.Set();
      }
    } else if (blk->eof()) {
      RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(TensorRow(TensorRow::TensorRowFlags::kFlagEOF)));
    } else if (blk->eoe()) {
      RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(TensorRow(TensorRow::TensorRowFlags::kFlagEOE)));
    } else {
      std::vector<int64_t> keys;
      RETURN_IF_NOT_OK(blk->GetKeys(&keys));
      if (keys.empty()) {
        // An empty key is a quit signal for workers
        break;
      }
      for (auto row_id : keys) {
        TensorRow row;
        // Block until the row shows up in the pool.
        RETURN_IF_NOT_OK(GetPrefetchRow(row_id, &row));
        if (row.empty()) {
          if (AllowCacheMiss()) {
            ++num_cache_miss_;
          } else {
            std::string errMsg = "[Internal ERROR] Row id " + std::to_string(row_id) + " not found.";
            RETURN_STATUS_UNEXPECTED(errMsg);
          }
        }
        RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(std::move(row)));
      }
    }
  } while (true);
  return Status::OK();
}

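// Register the prefetch queues with the tree's task group in addition to the standard ParallelOp resources.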
Status CacheBase::RegisterResources() {
  RETURN_IF_NOT_OK(RegisterAndLaunchThreads());
  RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks()));
  return Status::OK();
}

CacheBase::~CacheBase() = default;

Status CacheBase::UpdateColumnMapFromCache() {
  Status rc;
  // Get the schema from the server. It may not be there yet. So tolerate the error.
  if (column_name_id_map_.empty()) {
    rc = cache_client_->FetchSchema(&column_name_id_map_);
    if (rc == Status(StatusCode::kMDFileNotExist)) {
      MS_LOG(DEBUG) << "Schema not in the server yet.";
      rc = Status::OK();
    }
  }
  return rc;
}

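// Fetch a single row from the prefetch pool, blocking until a prefetcher has deposited it.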
Status CacheBase::GetPrefetchRow(row_id_type row_id, TensorRow *out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  CHECK_FAIL_RETURN_UNEXPECTED(row_id >= 0,
                               "[Internal ERROR] Expect non-negative row id, but got:" + std::to_string(row_id));
  RETURN_IF_NOT_OK(prefetch_.PopFront(row_id, out));
  return Status::OK();
}

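// Prefetch a batch of rows from the cache server. Known-miss keys get an empty placeholder row in the
// pool right away; the remaining keys are requested from the server, and any row that comes back empty
// is recorded as a cache miss.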
Status CacheBase::PrefetchRows(const std::vector<row_id_type> &keys, std::vector<row_id_type> *cache_miss) {
  RETURN_UNEXPECTED_IF_NULL(cache_miss);
  std::vector<row_id_type> prefetch_keys;
  prefetch_keys.reserve(keys.size());

  // Filter out all those keys that we are unlikely to find at the server
  for (auto row_id : keys) {
    if (cache_client_->KeyIsCacheMiss(row_id)) {
      // Just put an empty row in the cache.
      TensorRow row;
      row.setId(row_id);
      RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
      cache_miss->push_back(row_id);
    } else {
      prefetch_keys.push_back(row_id);
    }
  }
  // Early exit if nothing to fetch
  if (prefetch_keys.empty()) {
    return Status::OK();
  }
  // Get the rows from the server
  TensorTable ttbl;
  RETURN_IF_NOT_OK(cache_client_->GetRows(prefetch_keys, &ttbl));
  auto row_it = ttbl.begin();
  for (auto row_id : prefetch_keys) {
    auto &row = *row_it;
    if (row.empty()) {
      cache_miss->push_back(row_id);
    }
    // Put the prefetched row into the pool and wake up any WorkerEntry waiting for the row
    RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
    ++row_it;
  }
  return Status::OK();
}

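// Prefetcher thread: drain io blocks from the prefetch queue and pull the requested rows from the cache
// server into the local pool, retrying a few times on network errors. When cache misses are allowed,
// the miss keys are forwarded through the keys_miss_ connector.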
Status CacheBase::Prefetcher(int32_t worker_id) {
  TaskManager::FindMe()->Post();
  std::vector<row_id_type> prefetch_keys;
  prefetch_keys.reserve(prefetch_size_);
  std::vector<row_id_type> cache_miss;
  cache_miss.reserve(prefetch_size_);
  do {
    prefetch_keys.clear();
    cache_miss.clear();
    std::unique_ptr<IOBlock> blk;
    RETURN_IF_NOT_OK(prefetch_queues_[worker_id]->PopFront(&blk));
    CHECK_FAIL_RETURN_UNEXPECTED(!blk->eof(), "[Internal ERROR] Expect eoe or a regular io block.");
    if (!blk->eoe()) {
      RETURN_IF_NOT_OK(blk->GetKeys(&prefetch_keys));
      Status rc;
      const int32_t max_retries = 5;
      int32_t retry_count = 0;
      do {
        rc = PrefetchRows(prefetch_keys, &cache_miss);
        if (rc == StatusCode::kMDNetWorkError && retry_count < max_retries) {
          // If we get some network error, we will attempt some retries
          retry_count++;
        } else if (rc.IsError() && rc.StatusCode() != StatusCode::kMDInterrupted) {
          MS_LOG(WARNING) << rc.ToString();
          return rc;
        }
      } while (rc == StatusCode::kMDNetWorkError);
      // In case any thread is waiting for the rows to come back and blocked on a semaphore,
      // we will put an empty row in the local cache.
      if (rc.IsError() && AllowCacheMiss()) {
        for (auto row_id : prefetch_keys) {
          TensorRow row;
          row.setId(row_id);
          RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
          cache_miss.push_back(row_id);
        }
      }
    } else {
      if (AllowCacheMiss()) {
        // This code path is for CacheLookupOp acting as a sampler. If we get an eoe from
        // a sampler, send an eoe to the physical leaf op as well.
        cache_miss.push_back(eoe_row_id);
      }
    }
    if (AllowCacheMiss()) {
      // Because of the way the connector works, we push unconditionally even if cache_miss is empty.
      RETURN_IF_NOT_OK(keys_miss_->Push(worker_id, cache_miss));
    }
  } while (true);
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore