• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
18 #include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
19 #include "minddata/dataset/engine/datasetops/cache_merge_op.h"
20 #include "minddata/dataset/engine/datasetops/cache_op.h"
21 
22 namespace mindspore {
23 namespace dataset {
24 /// Method to initialize the DatasetCache by creating an instance of a CacheClient
25 /// \return Status Error code
Build()26 Status DatasetCacheImpl::Build() {
27   // The same DatasetCache instance can be re-used for multiple pipelines for cache sharing,
28   // in this case, cache_client_ object might have been created.
29   if (cache_client_) return Status::OK();
30 
31   CacheClient::Builder builder;
32   builder.SetSessionId(session_id_).SetCacheMemSz(cache_mem_sz_).SetSpill(spill_);
33   if (hostname_) {
34     (void)builder.SetHostname(hostname_.value());
35   }
36   if (port_) {
37     (void)builder.SetPort(port_.value());
38   }
39   if (num_connections_) {
40     (void)builder.SetNumConnections(num_connections_.value());
41   }
42   if (prefetch_sz_) {
43     (void)builder.SetPrefetchSize(prefetch_sz_.value());
44   }
45   return builder.Build(&cache_client_);
46 }
47 
CreateCacheOp(int32_t num_workers,int32_t connector_queue_size,std::shared_ptr<SamplerObj> sampler,std::shared_ptr<DatasetOp> * ds)48 Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, int32_t connector_queue_size,
49                                        std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetOp> *ds) {
50   CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheOp requires a CacheClient, but got nullptr.");
51   std::shared_ptr<SamplerRT> sampler_rt = nullptr;
52   RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt));
53   std::shared_ptr<CacheOp> cache_op =
54     std::make_shared<CacheOp>(num_workers, connector_queue_size, cache_client_, std::move(sampler_rt));
55   *ds = cache_op;
56 
57   return Status::OK();
58 }
59 
CreateCacheLookupOp(int32_t num_workers,int32_t connector_queue_size,std::shared_ptr<SamplerObj> sampler,std::shared_ptr<DatasetOp> * ds)60 Status DatasetCacheImpl::CreateCacheLookupOp(int32_t num_workers, int32_t connector_queue_size,
61                                              std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetOp> *ds) {
62   CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheLookupOp requires a CacheClient, but got nullptr.");
63   std::shared_ptr<SamplerRT> sampler_rt = nullptr;
64   RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt));
65   std::shared_ptr<CacheLookupOp> lookup_op =
66     std::make_shared<CacheLookupOp>(num_workers, connector_queue_size, cache_client_, std::move(sampler_rt));
67   *ds = lookup_op;
68 
69   return Status::OK();
70 }
71 
CreateCacheMergeOp(int32_t num_workers,int32_t connector_queue_size,std::shared_ptr<DatasetOp> * ds)72 Status DatasetCacheImpl::CreateCacheMergeOp(int32_t num_workers, int32_t connector_queue_size,
73                                             std::shared_ptr<DatasetOp> *ds) {
74   CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheMergeOp requires a CacheClient, but got nullptr.");
75   std::shared_ptr<CacheMergeOp> merge_op =
76     std::make_shared<CacheMergeOp>(num_workers, connector_queue_size, num_workers, cache_client_);
77   *ds = merge_op;
78 
79   return Status::OK();
80 }
81 
to_json(nlohmann::json * out_json)82 Status DatasetCacheImpl::to_json(nlohmann::json *out_json) {
83   nlohmann::json args;
84   args["session_id"] = session_id_;
85   args["cache_memory_size"] = cache_mem_sz_;
86   args["spill"] = spill_;
87   if (hostname_) args["hostname"] = hostname_.value();
88   if (port_) args["port"] = port_.value();
89   if (num_connections_) args["num_connections"] = num_connections_.value();
90   if (prefetch_sz_) args["prefetch_size"] = prefetch_sz_.value();
91   *out_json = args;
92   return Status::OK();
93 }
94 }  // namespace dataset
95 }  // namespace mindspore
96