1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/cache/cache_client.h"
18 #include "minddata/dataset/engine/cache/cache_fbb.h"
19 #include "minddata/dataset/engine/cache/cache_request.h"
20 #include "minddata/dataset/util/bit.h"
21 #include "minddata/dataset/util/task_manager.h"
22
23 namespace mindspore {
24 namespace dataset {
Builder()25 CacheClient::Builder::Builder()
26 : session_id_(0), cache_mem_sz_(0), spill_(false), hostname_(""), port_(0), num_connections_(0), prefetch_size_(0) {
27 std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
28 hostname_ = cfg->cache_host();
29 port_ = cfg->cache_port();
30 num_connections_ = cfg->num_connections(); // number of async tcp/ip connections
31 prefetch_size_ = cfg->cache_prefetch_size(); // prefetch size
32 }
33
Build(std::shared_ptr<CacheClient> * out)34 Status CacheClient::Builder::Build(std::shared_ptr<CacheClient> *out) {
35 RETURN_UNEXPECTED_IF_NULL(out);
36 RETURN_IF_NOT_OK(SanityCheck());
37 *out = std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_connections_,
38 prefetch_size_);
39 return Status::OK();
40 }
41
SanityCheck()42 Status CacheClient::Builder::SanityCheck() {
43 CHECK_FAIL_RETURN_SYNTAX_ERROR(session_id_ > 0, "session id must be positive.");
44 CHECK_FAIL_RETURN_SYNTAX_ERROR(cache_mem_sz_ >= 0, "cache memory size must not be negative (0 implies unlimited).");
45 CHECK_FAIL_RETURN_SYNTAX_ERROR(num_connections_ > 0, "number of tcp/ip connections must be positive.");
46 CHECK_FAIL_RETURN_SYNTAX_ERROR(prefetch_size_ > 0, "prefetch size must be positive.");
47 CHECK_FAIL_RETURN_SYNTAX_ERROR(!hostname_.empty(), "hostname must not be empty.");
48 CHECK_FAIL_RETURN_SYNTAX_ERROR(port_ >= kMinLegalPort, "Port must be in range (1025..65535).");
49 CHECK_FAIL_RETURN_SYNTAX_ERROR(port_ <= kMaxLegalPort, "Port must be in range (1025..65535).");
50 CHECK_FAIL_RETURN_SYNTAX_ERROR(hostname_ == "127.0.0.1",
51 "now cache client has to be on the same host with cache server.");
52 return Status::OK();
53 }
54
55 // Constructor
CacheClient(session_id_type session_id,uint64_t cache_mem_sz,bool spill,std::string hostname,int32_t port,int32_t num_connections,int32_t prefetch_size)56 CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname,
57 int32_t port, int32_t num_connections, int32_t prefetch_size)
58 : cache_mem_sz_(cache_mem_sz),
59 spill_(spill),
60 server_connection_id_(0),
61 client_id_(-1),
62 local_bypass_(false),
63 num_connections_(num_connections),
64 prefetch_size_(prefetch_size),
65 fetch_all_keys_(true) {
66 cinfo_.set_session_id(session_id);
67 comm_ = std::make_shared<CacheClientGreeter>(hostname, port, num_connections_);
68 }
69
~CacheClient()70 CacheClient::~CacheClient() {
71 cache_miss_keys_wp_.Set();
72 // Manually release the async buffer because we need the comm layer.
73 if (async_buffer_stream_) {
74 Status rc = async_buffer_stream_->ReleaseBuffer();
75 if (rc.IsError()) {
76 MS_LOG(ERROR) << rc;
77 }
78 }
79 if (client_id_ != -1) {
80 try {
81 // Send a message to the server, saying I am done.
82 auto rq = std::make_shared<ConnectResetRequest>(server_connection_id_, client_id_);
83 Status rc = PushRequest(rq);
84 if (rc.IsOk()) {
85 rc = rq->Wait();
86 if (rc.IsOk()) {
87 MS_LOG(INFO) << "Disconnect from server successful";
88 }
89 }
90 } catch (const std::exception &e) {
91 // Can't do anything in destructor. So just log the error.
92 MS_LOG(ERROR) << e.what();
93 }
94 }
95 (void)comm_->ServiceStop();
96 }
97
98 // print method for display cache details
Print(std::ostream & out) const99 void CacheClient::Print(std::ostream &out) const {
100 out << " Session id: " << session_id() << "\n Cache crc: " << cinfo_.crc()
101 << "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << GetCacheMemSz()
102 << "\n Spilling: " << std::boolalpha << isSpill() << "\n Number of rpc workers: " << GetNumConnections()
103 << "\n Prefetch size: " << GetPrefetchSize() << "\n Local client support: " << std::boolalpha
104 << SupportLocalClient();
105 }
106
GetHostname() const107 std::string CacheClient::GetHostname() const { return comm_->GetHostname(); }
GetPort() const108 int32_t CacheClient::GetPort() const { return comm_->GetPort(); }
109
WriteRow(const TensorRow & row,row_id_type * row_id_from_server) const110 Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const {
111 auto rq = std::make_shared<CacheRowRequest>(this);
112 RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(this, row));
113 RETURN_IF_NOT_OK(PushRequest(rq));
114 RETURN_IF_NOT_OK(rq->Wait());
115 if (row_id_from_server != nullptr) {
116 *row_id_from_server = rq->GetRowIdAfterCache();
117 }
118 return Status::OK();
119 }
120
AsyncWriteRow(const TensorRow & row)121 Status CacheClient::AsyncWriteRow(const TensorRow &row) {
122 if (async_buffer_stream_ == nullptr) {
123 return Status(StatusCode::kMDNotImplementedYet);
124 }
125 RETURN_IF_NOT_OK(async_buffer_stream_->AsyncWrite(row));
126 return Status::OK();
127 }
128
GetRows(const std::vector<row_id_type> & row_id,TensorTable * out) const129 Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const {
130 RETURN_UNEXPECTED_IF_NULL(out);
131 auto rq = std::make_shared<BatchFetchRequest>(this, row_id);
132 RETURN_IF_NOT_OK(PushRequest(rq));
133 RETURN_IF_NOT_OK(rq->Wait());
134 int64_t mem_addr;
135 Status rc = rq->RestoreRows(out, comm_->SharedMemoryBaseAddr(), &mem_addr);
136 // Free the memory by sending a request back to the server.
137 if (mem_addr != -1) {
138 auto mfree_req = std::make_shared<FreeSharedBlockRequest>(server_connection_id_, client_id_, mem_addr);
139 Status rc2 = PushRequest(mfree_req);
140 // But we won't wait for the result for the sake of performance.
141 if (rc.IsOk() && rc2.IsError()) {
142 rc = rc2;
143 }
144 }
145 return rc;
146 }
147
// Create a cache on the server for this client, or re-attach this client to the cache it already created.
//
// @param tree_crc    crc computed by the tree prepare phase for the tree nodes from the cache downward.
// @param generate_id if true, the create request carries CreateCacheFlag::kGenerateRowId.
// @return
//   - Status::OK() when a new cache was created, or when re-attaching and the build phase still has to run.
//   - StatusCode::kMDDuplicateKey is not a real failure. It is passed back to the CacheOp to tell it to
//     bypass the build phase. This happens when the cache already exists on the server, or, on re-attach,
//     when the server is in the fetch phase (or in the build phase and this client has no cookie).
//   - Any other error status is a genuine failure.
Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
  UniqueLock lck(&mux_);
  // To create a cache, we identify ourselves by:
  // - the shared session id
  // - a crc for the tree nodes from the cache downward
  // Pack these 2 into a single 64 bit request id
  //
  // Consider this example:
  // tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch
  // tree2: cifar10 --> map(rotate) --> cache (session id = 1, crc = 456) --> batch
  // These are different trees in a single session, but the user wants to share the cache.
  // This is not allowed because the data of these caches are different.
  //
  // Consider this example:
  // tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch
  // tree2: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> map(rotate) --> batch
  // These are different trees in the same session, but the cached data is the same, so it is okay
  // to allow the sharing of this cache between these pipelines.

  // The CRC is computed by the tree prepare phase and passed to this function when creating the cache.
  // If we already have a server_connection_id_, then this same cache client has already been used
  // to create a cache and some other tree is trying to use the same cache.
  // That is allowed, but the crc must match.
  if (server_connection_id_) {
    if (cinfo_.crc() != tree_crc) {
      RETURN_STATUS_UNEXPECTED("Cannot re-use a cache for a different tree!");
    }
    // Check the state of the server. For the non-mappable case, where there is a build phase and a fetch phase,
    // we should skip the build phase.
    lck.Unlock();  // GetState will grab the mutex again, so unlock it to prevent a deadlock.
    int8_t out;
    RETURN_IF_NOT_OK(GetState(&out));
    auto cache_state = static_cast<CacheServiceState>(out);
    // Bypass when the server is already fetching, or is still building while this client holds no cookie
    // (presumably another client owns the build -- TODO confirm).
    if (cache_state == CacheServiceState::kFetchPhase ||
        (cache_state == CacheServiceState::kBuildPhase && cookie_.empty())) {
      RETURN_STATUS_ERROR(StatusCode::kMDDuplicateKey, "Not an error and we should bypass the build phase");
    }
    if (async_buffer_stream_) {
      // Reset the async buffer stream to its initial state. Any stale status and data get cleaned up.
      RETURN_IF_NOT_OK(async_buffer_stream_->Reset());
    }
  } else {
    cinfo_.set_crc(tree_crc);  // It's really a new cache we're creating, so save our crc in the client
    // Now execute the cache create request using this identifier and the other configs
    CreateCacheRequest::CreateCacheFlag createFlag = CreateCacheRequest::CreateCacheFlag::kNone;
    if (spill_) {
      createFlag |= CreateCacheRequest::CreateCacheFlag::kSpillToDisk;
    }
    if (generate_id) {
      createFlag |= CreateCacheRequest::CreateCacheFlag::kGenerateRowId;
    }
    // Start the comm layer to receive the reply
    RETURN_IF_NOT_OK(comm_->ServiceStart());
    // Initiate the connection
    auto rq = std::make_shared<CreateCacheRequest>(this, cinfo_, cache_mem_sz_, createFlag);
    RETURN_IF_NOT_OK(PushRequest(rq));
    Status rc = rq->Wait();
    bool success = (rc.IsOk() || rc.StatusCode() == StatusCode::kMDDuplicateKey);
    // kDuplicateKey just means we aren't the first one to create the cache,
    // so we still parse the reply.
    if (rc.StatusCode() == StatusCode::kMDDuplicateKey) {
      RETURN_IF_NOT_OK(rq->PostReply());
    }
    if (success) {
      // Attach to shared memory for a local client. If local bypass is available, also set up the async
      // write buffer stream, which allocates its buffers from that shared memory.
      RETURN_IF_NOT_OK(comm_->AttachToSharedMemory(&local_bypass_));
      if (local_bypass_) {
        async_buffer_stream_ = std::make_shared<AsyncBufferStream>();
        RETURN_IF_NOT_OK(async_buffer_stream_->Init(this));
      }
    }
    // We are not resetting the Duplicate key return code. We are passing it back to the CacheOp. This tells the
    // CacheOp to bypass the build phase.
    return rc;
  }
  return Status::OK();
}
225
DestroyCache()226 Status CacheClient::DestroyCache() {
227 UniqueLock lck(&mux_);
228 auto rq = std::make_shared<DestroyCacheRequest>(server_connection_id_);
229 RETURN_IF_NOT_OK(PushRequest(rq));
230 RETURN_IF_NOT_OK(rq->Wait());
231 return Status::OK();
232 }
233
GetStat(CacheServiceStat * stat)234 Status CacheClient::GetStat(CacheServiceStat *stat) {
235 SharedLock lck(&mux_);
236 RETURN_UNEXPECTED_IF_NULL(stat);
237 // GetStat has an external interface, so we have to make sure we have a valid connection id first
238 CHECK_FAIL_RETURN_UNEXPECTED(server_connection_id_ != 0, "GetStat called but the cache is not in use yet.");
239
240 auto rq = std::make_shared<GetStatRequest>(server_connection_id_);
241 RETURN_IF_NOT_OK(PushRequest(rq));
242 RETURN_IF_NOT_OK(rq->Wait());
243 rq->GetStat(stat);
244 return Status::OK();
245 }
246
GetState(int8_t * out)247 Status CacheClient::GetState(int8_t *out) {
248 SharedLock lck(&mux_);
249 RETURN_UNEXPECTED_IF_NULL(out);
250 CHECK_FAIL_RETURN_UNEXPECTED(server_connection_id_ != 0, "GetState called but the cache is not in use yet.");
251 auto rq = std::make_shared<GetCacheStateRequest>(server_connection_id_);
252 RETURN_IF_NOT_OK(PushRequest(rq));
253 RETURN_IF_NOT_OK(rq->Wait());
254 *out = rq->GetState();
255 return Status::OK();
256 }
257
CacheSchema(const std::unordered_map<std::string,int32_t> & map)258 Status CacheClient::CacheSchema(const std::unordered_map<std::string, int32_t> &map) {
259 SharedLock lck(&mux_);
260 auto rq = std::make_shared<CacheSchemaRequest>(server_connection_id_);
261 RETURN_IF_NOT_OK(rq->SerializeCacheSchemaRequest(map));
262 RETURN_IF_NOT_OK(PushRequest(rq));
263 RETURN_IF_NOT_OK(rq->Wait());
264 return Status::OK();
265 }
266
FetchSchema(std::unordered_map<std::string,int32_t> * map)267 Status CacheClient::FetchSchema(std::unordered_map<std::string, int32_t> *map) {
268 SharedLock lck(&mux_);
269 RETURN_UNEXPECTED_IF_NULL(map);
270 auto rq = std::make_shared<FetchSchemaRequest>(server_connection_id_);
271 RETURN_IF_NOT_OK(PushRequest(rq));
272 RETURN_IF_NOT_OK(rq->Wait());
273 *map = rq->GetColumnMap();
274 return Status::OK();
275 }
276
BuildPhaseDone() const277 Status CacheClient::BuildPhaseDone() const {
278 SharedLock lck(&mux_);
279 auto rq = std::make_shared<BuildPhaseDoneRequest>(server_connection_id_, cookie());
280 RETURN_IF_NOT_OK(PushRequest(rq));
281 RETURN_IF_NOT_OK(rq->Wait());
282 return Status::OK();
283 }
284
PushRequest(std::shared_ptr<BaseRequest> rq) const285 Status CacheClient::PushRequest(std::shared_ptr<BaseRequest> rq) const { return comm_->HandleRequest(std::move(rq)); }
286
ServerRunningOutOfResources()287 void CacheClient::ServerRunningOutOfResources() {
288 bool expected = true;
289 if (fetch_all_keys_.compare_exchange_strong(expected, false)) {
290 Status rc;
291 // Server runs out of memory or disk space to cache any more rows.
292 // First of all, we will turn off the locking.
293 auto toggle_write_mode_rq = std::make_shared<ToggleWriteModeRequest>(server_connection_id_, false);
294 rc = PushRequest(toggle_write_mode_rq);
295 if (rc.IsError()) {
296 return;
297 }
298 // Wait until we can toggle the state of the server to non-locking
299 rc = toggle_write_mode_rq->Wait();
300 if (rc.IsError()) {
301 return;
302 }
303 // Now we get a list of all the keys not cached at the server so
304 // we can filter out at the prefetch level.
305 auto cache_miss_rq = std::make_shared<GetCacheMissKeysRequest>(server_connection_id_);
306 rc = PushRequest(cache_miss_rq);
307 if (rc.IsError()) {
308 return;
309 }
310 rc = cache_miss_rq->Wait();
311 if (rc.IsError()) {
312 return;
313 }
314 // We will get back a vector of row id between [min,max] that are absent in the server.
315 auto &row_id_buf = cache_miss_rq->reply_.result();
316 auto p = flatbuffers::GetRoot<TensorRowIds>(row_id_buf.data());
317 std::vector<row_id_type> row_ids;
318 auto sz = p->row_id()->size();
319 row_ids.reserve(sz);
320 for (uint32_t i = 0; i < sz; ++i) {
321 row_ids.push_back(p->row_id()->Get(i));
322 }
323 cache_miss_keys_ = std::make_unique<CacheMissKeys>(row_ids);
324 // We are all set.
325 cache_miss_keys_wp_.Set();
326 }
327 }
328
CacheMissKeys(const std::vector<row_id_type> & v)329 CacheClient::CacheMissKeys::CacheMissKeys(const std::vector<row_id_type> &v) {
330 auto it = v.begin();
331 min_ = *it;
332 ++it;
333 max_ = *it;
334 ++it;
335 while (it != v.end()) {
336 gap_.insert(*it);
337 ++it;
338 }
339 MS_LOG(INFO) << "# of cache miss keys between min(" << min_ << ") and max(" << max_ << ") is " << gap_.size();
340 }
341
KeyIsCacheMiss(row_id_type key)342 bool CacheClient::CacheMissKeys::KeyIsCacheMiss(row_id_type key) {
343 if (key > max_ || key < min_) {
344 return true;
345 } else if (key == min_ || key == max_) {
346 return false;
347 } else {
348 auto it = gap_.find(key);
349 return it != gap_.end();
350 }
351 }
352
AsyncBufferStream()353 CacheClient::AsyncBufferStream::AsyncBufferStream()
354 : cc_(nullptr), offset_addr_(-1), cur_(0), buf_arr_(std::vector<AsyncWriter>(kNumAsyncBuffer)) {}
355
~AsyncBufferStream()356 CacheClient::AsyncBufferStream::~AsyncBufferStream() {
357 (void)vg_.ServiceStop();
358 (void)ReleaseBuffer();
359 }
360
ReleaseBuffer()361 Status CacheClient::AsyncBufferStream::ReleaseBuffer() {
362 if (offset_addr_ != -1) {
363 auto mfree_req =
364 std::make_shared<FreeSharedBlockRequest>(cc_->server_connection_id_, cc_->GetClientId(), offset_addr_);
365 offset_addr_ = -1;
366 RETURN_IF_NOT_OK(cc_->PushRequest(mfree_req));
367 RETURN_IF_NOT_OK(mfree_req->Wait());
368 }
369 return Status::OK();
370 }
371
Init(CacheClient * cc)372 Status CacheClient::AsyncBufferStream::Init(CacheClient *cc) {
373 cc_ = cc;
374 // Allocate shared memory from the server
375 auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(cc_->server_connection_id_, cc_->GetClientId(),
376 kAsyncBufferSize * kNumAsyncBuffer);
377 RETURN_IF_NOT_OK(cc->PushRequest(mem_rq));
378 RETURN_IF_NOT_OK(mem_rq->Wait());
379 offset_addr_ = mem_rq->GetAddr();
380 // Now we need to add that to the base address of where we attach.
381 auto base = cc->SharedMemoryBaseAddr();
382 auto start = reinterpret_cast<int64_t>(base) + offset_addr_;
383 for (auto i = 0; i < kNumAsyncBuffer; ++i) {
384 // We only need to set the pointer during init. Other fields will be set dynamically.
385 buf_arr_[i].buffer_ = reinterpret_cast<void *>(start + i * kAsyncBufferSize);
386 }
387 buf_arr_[0].bytes_avail_ = kAsyncBufferSize;
388 buf_arr_[0].num_ele_ = 0;
389 RETURN_IF_NOT_OK(vg_.ServiceStart());
390 return Status::OK();
391 }
392
AsyncWrite(const TensorRow & row)393 Status CacheClient::AsyncBufferStream::AsyncWrite(const TensorRow &row) {
394 std::vector<ReadableSlice> v;
395 v.reserve(row.size() + 1);
396 std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb;
397 RETURN_IF_NOT_OK(::mindspore::dataset::SerializeTensorRowHeader(row, &fbb));
398 int64_t sz = fbb->GetSize();
399 v.emplace_back(fbb->GetBufferPointer(), sz);
400 for (const auto &ts : row) {
401 sz += ts->SizeInBytes();
402 v.emplace_back(ts->GetBuffer(), ts->SizeInBytes());
403 }
404 // If the size is too big, tell the user to send it directly.
405 if (sz > kAsyncBufferSize) {
406 return Status(StatusCode::kMDNotImplementedYet);
407 }
408 std::unique_lock<std::mutex> lock(mux_);
409 // Check error from the server side while we have the lock;
410 RETURN_IF_NOT_OK(flush_rc_);
411 AsyncWriter *asyncWriter = &buf_arr_[cur_];
412 if (asyncWriter->bytes_avail_ < sz) {
413 // Flush and switch to a new buffer while we have the lock.
414 RETURN_IF_NOT_OK(SyncFlush(AsyncFlushFlag::kCallerHasXLock));
415 // Refresh the pointer after we switch
416 asyncWriter = &buf_arr_[cur_];
417 }
418 RETURN_IF_NOT_OK(asyncWriter->Write(sz, v));
419 return Status::OK();
420 }
421
// Flush the current async buffer to the server and advance to the next buffer.
//
// @param flag
//   - kCallerHasXLock: the caller (e.g. AsyncWrite) already holds mux_, so it is not locked again here.
//   - kFlushBlocking:  after pushing the current buffer, wait for every outstanding buffer request
//                      (intended for the final flush).
// @return flush_rc_, the sticky status shared with writers. It holds the first error seen from pushing or
//         waiting on a buffer request.
//
// NOTE(review): all work, including the kFlushBlocking wait, happens only when the current buffer holds at least
// one element. A blocking flush with an empty current buffer therefore does not wait on requests still in flight
// for other buffers. AsyncWrite always writes into the new buffer right after switching, which appears to make
// this safe -- confirm before relying on it elsewhere.
Status CacheClient::AsyncBufferStream::SyncFlush(AsyncFlushFlag flag) {
  std::unique_lock lock(mux_, std::defer_lock_t());
  bool callerHasXLock = (flag & AsyncFlushFlag::kCallerHasXLock) == AsyncFlushFlag::kCallerHasXLock;
  if (!callerHasXLock) {
    lock.lock();
  }
  auto *asyncWriter = &buf_arr_[cur_];
  if (asyncWriter->num_ele_) {
    // Ask the server to cache the rows stored in this buffer, identified by its offset in the shared block.
    asyncWriter->rq.reset(
      new BatchCacheRowsRequest(cc_, offset_addr_ + cur_ * kAsyncBufferSize, asyncWriter->num_ele_));
    flush_rc_ = cc_->PushRequest(asyncWriter->rq);
    RETURN_IF_NOT_OK(flush_rc_);

    // If we are asked to wait (say this is the final flush), just wait for its completion.
    bool blocking = (flag & AsyncFlushFlag::kFlushBlocking) == AsyncFlushFlag::kFlushBlocking;
    if (blocking) {
      // Make sure we are done with all the buffers. Only errors are recorded, so an earlier error is not
      // overwritten by a later success.
      for (auto i = 0; i < kNumAsyncBuffer; ++i) {
        if (buf_arr_[i].rq) {
          Status rc = buf_arr_[i].rq->Wait();
          if (rc.IsError()) {
            flush_rc_ = rc;
          }
          buf_arr_[i].rq.reset();
        }
      }
    }
    // Prepare for the next buffer (round robin).
    cur_ = (cur_ + 1) % kNumAsyncBuffer;
    asyncWriter = &buf_arr_[cur_];
    // Update the cur_ while we have the lock.
    // Before we do anything, make sure the cache server is done with this buffer, or we will corrupt its content.
    // This also picks up any error from a previous flush.
    if (asyncWriter->rq) {
      // Save the result into a common area, so the worker can see it and quit.
      flush_rc_ = asyncWriter->rq->Wait();
      asyncWriter->rq.reset();
    }
    // The next buffer starts empty.
    asyncWriter->bytes_avail_ = kAsyncBufferSize;
    asyncWriter->num_ele_ = 0;
  }

  return flush_rc_;
}
466
Write(int64_t sz,const std::vector<ReadableSlice> & v)467 Status CacheClient::AsyncBufferStream::AsyncWriter::Write(int64_t sz, const std::vector<ReadableSlice> &v) {
468 CHECK_FAIL_RETURN_UNEXPECTED(sz <= bytes_avail_, "Programming error");
469 for (auto &p : v) {
470 auto write_sz = p.GetSize();
471 WritableSlice dest(reinterpret_cast<char *>(buffer_) + kAsyncBufferSize - bytes_avail_, write_sz);
472 RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, p));
473 bytes_avail_ -= write_sz;
474 }
475 ++num_ele_;
476 return Status::OK();
477 }
478
Reset()479 Status CacheClient::AsyncBufferStream::Reset() {
480 // Clean up previous running state to be prepared for a new run.
481 cur_ = 0;
482 flush_rc_ = Status::OK();
483 for (auto i = 0; i < kNumAsyncBuffer; ++i) {
484 buf_arr_[i].bytes_avail_ = kAsyncBufferSize;
485 buf_arr_[i].num_ele_ = 0;
486 buf_arr_[i].rq.reset();
487 }
488 return Status::OK();
489 }
490 } // namespace dataset
491 } // namespace mindspore
492