1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
18 #include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
19 #include "minddata/dataset/engine/datasetops/cache_merge_op.h"
20 #include "minddata/dataset/engine/datasetops/cache_op.h"
21
22 namespace mindspore {
23 namespace dataset {
24 /// Method to initialize the DatasetCache by creating an instance of a CacheClient
25 /// \return Status Error code
Build()26 Status DatasetCacheImpl::Build() {
27 // The same DatasetCache instance can be re-used for multiple pipelines for cache sharing,
28 // in this case, cache_client_ object might have been created.
29 if (cache_client_) return Status::OK();
30
31 CacheClient::Builder builder;
32 builder.SetSessionId(session_id_).SetCacheMemSz(cache_mem_sz_).SetSpill(spill_);
33 if (hostname_) {
34 (void)builder.SetHostname(hostname_.value());
35 }
36 if (port_) {
37 (void)builder.SetPort(port_.value());
38 }
39 if (num_connections_) {
40 (void)builder.SetNumConnections(num_connections_.value());
41 }
42 if (prefetch_sz_) {
43 (void)builder.SetPrefetchSize(prefetch_sz_.value());
44 }
45 return builder.Build(&cache_client_);
46 }
47
CreateCacheOp(int32_t num_workers,int32_t connector_queue_size,std::shared_ptr<SamplerObj> sampler,std::shared_ptr<DatasetOp> * ds)48 Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, int32_t connector_queue_size,
49 std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetOp> *ds) {
50 CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheOp requires a CacheClient, but got nullptr.");
51 std::shared_ptr<SamplerRT> sampler_rt = nullptr;
52 RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt));
53 std::shared_ptr<CacheOp> cache_op =
54 std::make_shared<CacheOp>(num_workers, connector_queue_size, cache_client_, std::move(sampler_rt));
55 *ds = cache_op;
56
57 return Status::OK();
58 }
59
CreateCacheLookupOp(int32_t num_workers,int32_t connector_queue_size,std::shared_ptr<SamplerObj> sampler,std::shared_ptr<DatasetOp> * ds)60 Status DatasetCacheImpl::CreateCacheLookupOp(int32_t num_workers, int32_t connector_queue_size,
61 std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetOp> *ds) {
62 CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheLookupOp requires a CacheClient, but got nullptr.");
63 std::shared_ptr<SamplerRT> sampler_rt = nullptr;
64 RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt));
65 std::shared_ptr<CacheLookupOp> lookup_op =
66 std::make_shared<CacheLookupOp>(num_workers, connector_queue_size, cache_client_, std::move(sampler_rt));
67 *ds = lookup_op;
68
69 return Status::OK();
70 }
71
CreateCacheMergeOp(int32_t num_workers,int32_t connector_queue_size,std::shared_ptr<DatasetOp> * ds)72 Status DatasetCacheImpl::CreateCacheMergeOp(int32_t num_workers, int32_t connector_queue_size,
73 std::shared_ptr<DatasetOp> *ds) {
74 CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheMergeOp requires a CacheClient, but got nullptr.");
75 std::shared_ptr<CacheMergeOp> merge_op =
76 std::make_shared<CacheMergeOp>(num_workers, connector_queue_size, num_workers, cache_client_);
77 *ds = merge_op;
78
79 return Status::OK();
80 }
81
to_json(nlohmann::json * out_json)82 Status DatasetCacheImpl::to_json(nlohmann::json *out_json) {
83 nlohmann::json args;
84 args["session_id"] = session_id_;
85 args["cache_memory_size"] = cache_mem_sz_;
86 args["spill"] = spill_;
87 if (hostname_) args["hostname"] = hostname_.value();
88 if (port_) args["port"] = port_.value();
89 if (num_connections_) args["num_connections"] = num_connections_.value();
90 if (prefetch_sz_) args["prefetch_size"] = prefetch_sz_.value();
91 *out_json = args;
92 return Status::OK();
93 }
94 } // namespace dataset
95 } // namespace mindspore
96