1 /**
2 * Copyright 2020-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"

#include <memory>
#include <utility>

#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
21
22 namespace mindspore {
23 namespace dataset {
24 /// Method to initialize the DatasetCache by creating an instance of a CacheClient
25 /// \return Status Error code
Build()26 Status DatasetCacheImpl::Build() {
27 // The same DatasetCache instance can be re-used for multiple pipelines for cache sharing,
28 // in this case, cache_client_ object might have been created.
29 if (cache_client_) {
30 return Status::OK();
31 }
32
33 CacheClient::Builder builder;
34 builder.SetSessionId(session_id_).SetCacheMemSz(cache_mem_sz_).SetSpill(spill_);
35 if (hostname_) {
36 (void)builder.SetHostname(hostname_.value());
37 }
38 if (port_) {
39 (void)builder.SetPort(port_.value());
40 }
41 if (num_connections_) {
42 (void)builder.SetNumConnections(num_connections_.value());
43 }
44 if (prefetch_sz_) {
45 (void)builder.SetPrefetchSize(prefetch_sz_.value());
46 }
47 return builder.Build(&cache_client_);
48 }
49
CreateCacheOp(int32_t num_workers,int32_t connector_queue_size,std::shared_ptr<SamplerObj> sampler,std::shared_ptr<DatasetOp> * ds)50 Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, int32_t connector_queue_size,
51 std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetOp> *ds) {
52 CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheOp requires a CacheClient, but got nullptr.");
53 std::shared_ptr<SamplerRT> sampler_rt = nullptr;
54 RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt));
55 std::shared_ptr<CacheOp> cache_op =
56 std::make_shared<CacheOp>(num_workers, connector_queue_size, cache_client_, std::move(sampler_rt));
57 *ds = cache_op;
58
59 return Status::OK();
60 }
61
CreateCacheLookupOp(int32_t num_workers,int32_t connector_queue_size,std::shared_ptr<SamplerObj> sampler,std::shared_ptr<DatasetOp> * ds)62 Status DatasetCacheImpl::CreateCacheLookupOp(int32_t num_workers, int32_t connector_queue_size,
63 std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetOp> *ds) {
64 CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheLookupOp requires a CacheClient, but got nullptr.");
65 std::shared_ptr<SamplerRT> sampler_rt = nullptr;
66 RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt));
67 std::shared_ptr<CacheLookupOp> lookup_op =
68 std::make_shared<CacheLookupOp>(num_workers, connector_queue_size, cache_client_, std::move(sampler_rt));
69 *ds = lookup_op;
70
71 return Status::OK();
72 }
73
CreateCacheMergeOp(int32_t num_workers,int32_t connector_queue_size,std::shared_ptr<DatasetOp> * ds)74 Status DatasetCacheImpl::CreateCacheMergeOp(int32_t num_workers, int32_t connector_queue_size,
75 std::shared_ptr<DatasetOp> *ds) {
76 CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheMergeOp requires a CacheClient, but got nullptr.");
77 std::shared_ptr<CacheMergeOp> merge_op =
78 std::make_shared<CacheMergeOp>(num_workers, connector_queue_size, num_workers, cache_client_);
79 *ds = merge_op;
80
81 return Status::OK();
82 }
83
to_json(nlohmann::json * out_json)84 Status DatasetCacheImpl::to_json(nlohmann::json *out_json) {
85 nlohmann::json args;
86 args["session_id"] = session_id_;
87 args["cache_memory_size"] = cache_mem_sz_;
88 args["spill"] = spill_;
89 if (hostname_) {
90 args["hostname"] = hostname_.value();
91 }
92 if (port_) {
93 args["port"] = port_.value();
94 }
95 if (num_connections_) {
96 args["num_connections"] = num_connections_.value();
97 }
98 if (prefetch_sz_) {
99 args["cache_prefetch_size"] = prefetch_sz_.value();
100 }
101 *out_json = args;
102 return Status::OK();
103 }
104 } // namespace dataset
105 } // namespace mindspore
106