/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/datasetops/cache_base_op.h"

#include "minddata/dataset/engine/execution_tree.h"

namespace mindspore {
namespace dataset {
// A print method typically used for debugging
void CacheBase::Print(std::ostream &out, bool show_all) const {
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    ParallelOp::Print(out, show_all);
    out << "\n";
  } else {
    // Call the super class for displaying any common detailed info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nCache client:\n" << *cache_client_ << "\n\n";
  }
}
// Overrides the base class reset method. When an operator does a reset, it cleans up any state
// info from its previous execution and then initializes itself so that it can be executed
// again.
Status CacheBase::Reset() {
  if (sampler_ != nullptr) {
    RETURN_IF_NOT_OK(sampler_->ResetSampler());
  }
  // Wake up the workers to get them going again in a new epoch
  MS_LOG(DEBUG) << Name() << " performing a self-reset.";
  return Status::OK();
}
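// Constructor. The per-worker prefetch size is derived from the cache client's prefetch size
// split across the prefetcher threads, with a minimum of 1.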
CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<CacheClient> cache_client,
                     std::shared_ptr<SamplerRT> sampler)
    : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
      row_cnt_(0),
      num_cache_miss_(0),
      cache_client_(std::move(cache_client)),
      prefetch_size_(1),
      num_prefetchers_(num_workers_) {
  // Adjust the prefetch size based on the number of workers.
  auto prefetch_sz_per_thread = cache_client_->GetPrefetchSize() / num_prefetchers_;
  if (prefetch_size_ < prefetch_sz_per_thread) {
    prefetch_size_ = prefetch_sz_per_thread;
    MS_LOG(DEBUG) << "Per worker prefetch size : " << prefetch_size_;
  }
  worker_in_queues_.Init(num_workers, op_connector_size);
  prefetch_queues_.Init(num_prefetchers_, op_connector_size);
  // We can cause deadlock if this internal Connector size is too small.
  keys_miss_ = std::make_unique<Connector<std::vector<row_id_type>>>(num_prefetchers_, 1, connector_capacity_);
}
// Common function to fetch samples from the sampler and send them using the io_block queues to
// the parallel workers
Status CacheBase::FetchSamplesToWorkers() {
  int64_t buf_cnt = 0;
  int64_t wait_cnt = 0;
  int64_t prefetch_cnt = 0;
  // Kick off several threads which will prefetch prefetch_size_ rows in advance.
  RETURN_UNEXPECTED_IF_NULL(tree_);
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_prefetchers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1), Name()));
  auto send_to_que = [](QueueList<std::unique_ptr<IOBlock>> &qList, int32_t worker_id,
                        std::vector<row_id_type> &keys) -> Status {
    auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kFlagNone));
    RETURN_IF_NOT_OK(qList[worker_id]->Add(std::move(blk)));
    return Status::OK();
  };
  // Instead of sending the sample ids straight to the WorkerEntry, we send them to the Prefetcher,
  // which will redirect them to the WorkerEntry.
  do {
    if (AllowCacheMiss() && wait_cnt > 0 && wait_cnt % GetOpNumRepeatsPerEpoch() == 0) {
      MS_LOG(INFO) << "Epoch: " << op_current_epochs_ << " Cache Miss : " << num_cache_miss_
                   << " Total number of rows : " << row_cnt_;
    }
    num_cache_miss_ = 0;
    row_cnt_ = 0;
    ++wait_cnt;
    std::vector<row_id_type> keys;
    keys.reserve(1);
    std::vector<row_id_type> prefetch_keys;
    prefetch_keys.reserve(prefetch_size_);
    TensorRow sample_row;
    RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row));
    while (!sample_row.eoe()) {
      std::shared_ptr<Tensor> sample_ids = sample_row[0];
      for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
        ++row_cnt_;
        prefetch_keys.push_back(*itr);
        // Batch enough rows together for performance reasons.
        if (row_cnt_ % prefetch_size_ == 0) {
          RETURN_IF_NOT_OK(send_to_que(prefetch_queues_, prefetch_cnt++ % num_prefetchers_, prefetch_keys));
          // Now we tell the WorkerEntry to wait for them to come back.
          for (auto row_id : prefetch_keys) {
            keys.push_back(row_id);
            RETURN_IF_NOT_OK(send_to_que(worker_in_queues_, static_cast<int32_t>(buf_cnt++ % num_workers_), keys));
            keys.clear();
          }
          prefetch_keys.clear();
        }
      }
      RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row));
    }
    // Deal with any partial keys left.
    if (!prefetch_keys.empty()) {
      RETURN_IF_NOT_OK(send_to_que(prefetch_queues_, prefetch_cnt++ % num_prefetchers_, prefetch_keys));
      for (auto row_id : prefetch_keys) {
        keys.push_back(row_id);
        RETURN_IF_NOT_OK(send_to_que(worker_in_queues_, static_cast<int32_t>(buf_cnt++ % num_workers_), keys));
        keys.clear();
      }
    }
    if (!keys.empty()) {
      RETURN_IF_NOT_OK(send_to_que(worker_in_queues_, static_cast<int32_t>(buf_cnt++ % num_workers_), keys));
    }
    // Send the eoe.
    RETURN_IF_NOT_OK(worker_in_queues_[static_cast<const int>((buf_cnt++) % num_workers_)]->Add(
      std::make_unique<IOBlock>(IOBlock::kFlagEOE)));
    RETURN_IF_NOT_OK(
      prefetch_queues_[(prefetch_cnt++) % num_prefetchers_]->Add(std::make_unique<IOBlock>(IOBlock::kFlagEOE)));
    // If repeating but this is not the last repeat, wait for reset.
    if (!IsLastIteration()) {
      MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << wait_cnt << " Buffer sent " << buf_cnt;
    } else {
      // We can break out from the loop.
      break;
    }
    if (epoch_sync_flag_) {
      // If epoch_sync_flag_ is set, the master thread sleeps until all the worker threads have finished their job
      // for the current epoch.
      RETURN_IF_NOT_OK(WaitForWorkers());
    }
    // If not the last repeat, self-reset and go through the loop again.
    if (!IsLastIteration()) {
      RETURN_IF_NOT_OK(Reset());
    }
    UpdateRepeatAndEpochCounter();
  } while (true);
  // Flow the eof before exiting.
  RETURN_IF_NOT_OK(worker_in_queues_[static_cast<const int>((buf_cnt++) % num_workers_)]->Add(
    std::make_unique<IOBlock>(IOBlock::kFlagEOF)));
  // Shut down the worker threads.
  for (int32_t i = 0; i < num_workers_; i++) {
    RETURN_IF_NOT_OK(worker_in_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kFlagNone)));
  }
  // Dump the last epoch result (approximately) without waiting for the worker threads to come back.
  if (AllowCacheMiss()) {
    MS_LOG(INFO) << "Epoch: " << wait_cnt / GetOpNumRepeatsPerEpoch() << " Cache Miss : " << num_cache_miss_
                 << " Total number of rows : " << row_cnt_;
  }
  return Status::OK();
}

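// Worker thread entry point. Pops IOBlocks from this worker's in-queue, waits for the corresponding
// rows to show up in the prefetch pool, and forwards the rows (or eoe/eof markers) to the out-queue.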
Status CacheBase::FetchFromCache(int32_t worker_id) {
  std::unique_ptr<IOBlock> blk;
  do {
    RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->PopFront(&blk));
    if (blk->wait()) {
      // A sync io_block is a signal that the master thread wants us to pause and sync with the other workers.
      // The last worker to reach this sync point should reset the counter and wake up the master thread.
      if (++num_workers_paused_ == num_workers_) {
        wait_for_workers_post_.Set();
      }
    } else if (blk->eof()) {
      RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(TensorRow(TensorRow::TensorRowFlags::kFlagEOF)));
    } else if (blk->eoe()) {
      RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(TensorRow(TensorRow::TensorRowFlags::kFlagEOE)));
    } else {
      std::vector<int64_t> keys;
      RETURN_IF_NOT_OK(blk->GetKeys(&keys));
      if (keys.empty()) {
        // An empty key list is a quit signal for the workers.
        break;
      }
      for (auto row_id : keys) {
        TensorRow row;
        // Block until the row shows up in the pool.
        RETURN_IF_NOT_OK(GetPrefetchRow(row_id, &row));
        if (row.empty()) {
          if (AllowCacheMiss()) {
            ++num_cache_miss_;
          } else {
            std::string errMsg = "[Internal ERROR] Row id " + std::to_string(row_id) + " not found.";
            RETURN_STATUS_UNEXPECTED(errMsg);
          }
        }
        RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(std::move(row)));
      }
    }
  } while (true);
  return Status::OK();
}

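// Launch the worker threads and register the internal prefetch queues with the execution tree's task group.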
Status CacheBase::RegisterResources() {
  RETURN_IF_NOT_OK(RegisterAndLaunchThreads());
  RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks()));
  return Status::OK();
}

CacheBase::~CacheBase() = default;

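// Populate the column name to id mapping from the schema stored on the cache server, if it is available yet.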
Status CacheBase::UpdateColumnMapFromCache() {
  Status rc;
  // Get the schema from the server. It may not be there yet, so tolerate the error.
  if (column_name_id_map_.empty()) {
    rc = cache_client_->FetchSchema(&column_name_id_map_);
    if (rc == Status(StatusCode::kMDFileNotExist)) {
      MS_LOG(DEBUG) << "Schema not in the server yet.";
      rc = Status::OK();
    }
  }
  return rc;
}

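// Pop a single row from the prefetch pool, blocking until a Prefetcher thread has delivered it.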
Status CacheBase::GetPrefetchRow(row_id_type row_id, TensorRow *out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  CHECK_FAIL_RETURN_UNEXPECTED(row_id >= 0,
                               "[Internal ERROR] Expect positive row id, but got:" + std::to_string(row_id));
  RETURN_IF_NOT_OK(prefetch_.PopFront(row_id, out));
  return Status::OK();
}

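// Fetch a batch of rows from the cache server into the prefetch pool. Keys that are known (or found)
// to be cache misses get an empty placeholder row and are reported back through cache_miss.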
Status CacheBase::PrefetchRows(const std::vector<row_id_type> &keys, std::vector<row_id_type> *cache_miss) {
  RETURN_UNEXPECTED_IF_NULL(cache_miss);
  std::vector<row_id_type> prefetch_keys;
  prefetch_keys.reserve(keys.size());

  // Filter out all those keys that we are unlikely to find at the server.
  for (auto row_id : keys) {
    if (cache_client_->KeyIsCacheMiss(row_id)) {
      // Just put an empty row in the cache.
      TensorRow row;
      row.setId(row_id);
      RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
      cache_miss->push_back(row_id);
    } else {
      prefetch_keys.push_back(row_id);
    }
  }
  // Early exit if there is nothing to fetch.
  if (prefetch_keys.empty()) {
    return Status::OK();
  }
  // Get the rows from the server.
  TensorTable ttbl;
  RETURN_IF_NOT_OK(cache_client_->GetRows(prefetch_keys, &ttbl));
  auto row_it = ttbl.begin();
  for (auto row_id : prefetch_keys) {
    auto &row = *row_it;
    if (row.empty()) {
      cache_miss->push_back(row_id);
    }
    // Put the prefetched row into the pool and wake up any WorkerEntry waiting for the row.
    RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
    ++row_it;
  }
  return Status::OK();
}

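// Prefetcher thread entry point. Pops IOBlocks from its prefetch queue, fetches the corresponding rows
// from the cache server (retrying on network errors), and reports any cache misses through keys_miss_.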
Status CacheBase::Prefetcher(int32_t worker_id) {
  TaskManager::FindMe()->Post();
  std::vector<row_id_type> prefetch_keys;
  prefetch_keys.reserve(prefetch_size_);
  std::vector<row_id_type> cache_miss;
  cache_miss.reserve(prefetch_size_);
  do {
    prefetch_keys.clear();
    cache_miss.clear();
    std::unique_ptr<IOBlock> blk;
    RETURN_IF_NOT_OK(prefetch_queues_[worker_id]->PopFront(&blk));
    CHECK_FAIL_RETURN_UNEXPECTED(!blk->eof(), "[Internal ERROR] Expect eoe or a regular io block.");
    if (!blk->eoe()) {
      RETURN_IF_NOT_OK(blk->GetKeys(&prefetch_keys));
      Status rc;
      const int32_t max_retries = 5;
      int32_t retry_count = 0;
      do {
        rc = PrefetchRows(prefetch_keys, &cache_miss);
        if (rc == StatusCode::kMDNetWorkError && retry_count < max_retries) {
          // If we get a network error, attempt a few retries.
          retry_count++;
        } else if (rc.IsError() && rc.StatusCode() != StatusCode::kMDInterrupted) {
          MS_LOG(WARNING) << rc.ToString();
          return rc;
        }
      } while (rc == StatusCode::kMDNetWorkError);
      // In case any thread is waiting for the rows to come back and blocked on a semaphore,
      // we will put an empty row in the local cache.
      if (rc.IsError() && AllowCacheMiss()) {
        for (auto row_id : prefetch_keys) {
          TensorRow row;
          row.setId(row_id);
          RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
          cache_miss.push_back(row_id);
        }
      }
    } else {
      if (AllowCacheMiss()) {
        // This code path is for CacheLookupOp acting as a sampler. If we get an eoe from
        // the sampler, send an eoe to the physical leaf op as well.
        cache_miss.push_back(eoe_row_id);
      }
    }
    if (AllowCacheMiss()) {
      // Because of the way the Connector works, we push unconditionally even if cache_miss is empty.
      RETURN_IF_NOT_OK(keys_miss_->Push(worker_id, cache_miss));
    }
  } while (true);
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore