• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
18 #include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
19 #include "minddata/dataset/engine/datasetops/cache_merge_op.h"
20 #include "minddata/dataset/engine/datasetops/cache_op.h"
21 
22 namespace mindspore {
23 namespace dataset {
24 /// Method to initialize the DatasetCache by creating an instance of a CacheClient
25 /// \return Status Error code
Build()26 Status DatasetCacheImpl::Build() {
27   // The same DatasetCache instance can be re-used for multiple pipelines for cache sharing,
28   // in this case, cache_client_ object might have been created.
29   if (cache_client_) {
30     return Status::OK();
31   }
32 
33   CacheClient::Builder builder;
34   builder.SetSessionId(session_id_).SetCacheMemSz(cache_mem_sz_).SetSpill(spill_);
35   if (hostname_) {
36     (void)builder.SetHostname(hostname_.value());
37   }
38   if (port_) {
39     (void)builder.SetPort(port_.value());
40   }
41   if (num_connections_) {
42     (void)builder.SetNumConnections(num_connections_.value());
43   }
44   if (prefetch_sz_) {
45     (void)builder.SetPrefetchSize(prefetch_sz_.value());
46   }
47   return builder.Build(&cache_client_);
48 }
49 
CreateCacheOp(int32_t num_workers,int32_t connector_queue_size,std::shared_ptr<SamplerObj> sampler,std::shared_ptr<DatasetOp> * ds)50 Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, int32_t connector_queue_size,
51                                        std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetOp> *ds) {
52   CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheOp requires a CacheClient, but got nullptr.");
53   std::shared_ptr<SamplerRT> sampler_rt = nullptr;
54   RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt));
55   std::shared_ptr<CacheOp> cache_op =
56     std::make_shared<CacheOp>(num_workers, connector_queue_size, cache_client_, std::move(sampler_rt));
57   *ds = cache_op;
58 
59   return Status::OK();
60 }
61 
CreateCacheLookupOp(int32_t num_workers,int32_t connector_queue_size,std::shared_ptr<SamplerObj> sampler,std::shared_ptr<DatasetOp> * ds)62 Status DatasetCacheImpl::CreateCacheLookupOp(int32_t num_workers, int32_t connector_queue_size,
63                                              std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetOp> *ds) {
64   CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheLookupOp requires a CacheClient, but got nullptr.");
65   std::shared_ptr<SamplerRT> sampler_rt = nullptr;
66   RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt));
67   std::shared_ptr<CacheLookupOp> lookup_op =
68     std::make_shared<CacheLookupOp>(num_workers, connector_queue_size, cache_client_, std::move(sampler_rt));
69   *ds = lookup_op;
70 
71   return Status::OK();
72 }
73 
CreateCacheMergeOp(int32_t num_workers,int32_t connector_queue_size,std::shared_ptr<DatasetOp> * ds)74 Status DatasetCacheImpl::CreateCacheMergeOp(int32_t num_workers, int32_t connector_queue_size,
75                                             std::shared_ptr<DatasetOp> *ds) {
76   CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheMergeOp requires a CacheClient, but got nullptr.");
77   std::shared_ptr<CacheMergeOp> merge_op =
78     std::make_shared<CacheMergeOp>(num_workers, connector_queue_size, num_workers, cache_client_);
79   *ds = merge_op;
80 
81   return Status::OK();
82 }
83 
to_json(nlohmann::json * out_json)84 Status DatasetCacheImpl::to_json(nlohmann::json *out_json) {
85   nlohmann::json args;
86   args["session_id"] = session_id_;
87   args["cache_memory_size"] = cache_mem_sz_;
88   args["spill"] = spill_;
89   if (hostname_) {
90     args["hostname"] = hostname_.value();
91   }
92   if (port_) {
93     args["port"] = port_.value();
94   }
95   if (num_connections_) {
96     args["num_connections"] = num_connections_.value();
97   }
98   if (prefetch_sz_) {
99     args["cache_prefetch_size"] = prefetch_sz_.value();
100   }
101   *out_json = args;
102   return Status::OK();
103 }
104 }  // namespace dataset
105 }  // namespace mindspore
106