1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
18
19 #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
20 #include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h"
21 #include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h"
22 #include "minddata/dataset/engine/ir/datasetops/source/samplers/mindrecord_sampler_ir.h"
23 #include "minddata/dataset/engine/opt/pass.h"
24 #include "minddata/dataset/util/status.h"
25
26 namespace mindspore {
27 namespace dataset {
28
MindDataNode(const std::vector<std::string> & dataset_files,const std::vector<std::string> & columns_list,const std::shared_ptr<SamplerObj> & sampler,nlohmann::json padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,std::shared_ptr<DatasetCache> cache)29 MindDataNode::MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list,
30 const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded,
31 ShuffleMode shuffle_mode, std::shared_ptr<DatasetCache> cache)
32 : MappableSourceNode(std::move(cache)),
33 dataset_file_(std::string()),
34 dataset_files_(dataset_files),
35 search_for_pattern_(false),
36 columns_list_(columns_list),
37 input_sampler_(sampler),
38 sampler_(std::make_shared<MindRecordSamplerObj>()),
39 padded_sample_(padded_sample),
40 sample_bytes_({}),
41 num_padded_(num_padded),
42 shuffle_mode_(shuffle_mode) {}
43
MindDataNode(const std::string & dataset_file,const std::vector<std::string> & columns_list,const std::shared_ptr<SamplerObj> & sampler,nlohmann::json padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,std::shared_ptr<DatasetCache> cache)44 MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<std::string> &columns_list,
45 const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded,
46 ShuffleMode shuffle_mode, std::shared_ptr<DatasetCache> cache)
47 : MappableSourceNode(std::move(cache)),
48 dataset_file_(dataset_file),
49 dataset_files_({}),
50 search_for_pattern_(true),
51 columns_list_(columns_list),
52 input_sampler_(sampler),
53 sampler_(std::make_shared<MindRecordSamplerObj>()),
54 padded_sample_(padded_sample),
55 sample_bytes_({}),
56 num_padded_(num_padded),
57 shuffle_mode_(shuffle_mode) {}
58
Copy()59 std::shared_ptr<DatasetNode> MindDataNode::Copy() {
60 std::shared_ptr<MindDataNode> node;
61 std::shared_ptr<SamplerObj> sampler = (input_sampler_ == nullptr) ? nullptr : input_sampler_->SamplerCopy();
62 if (dataset_files_.empty()) {
63 node = std::make_shared<MindDataNode>(dataset_file_, columns_list_, sampler, padded_sample_, num_padded_,
64 shuffle_mode_, cache_);
65 } else {
66 node = std::make_shared<MindDataNode>(dataset_files_, columns_list_, sampler, padded_sample_, num_padded_,
67 shuffle_mode_, cache_);
68 }
69 node->SetSampleBytes(&sample_bytes_);
70 return node;
71 }
72
Print(std::ostream & out) const73 void MindDataNode::Print(std::ostream &out) const { out << (Name() + "(file:" + dataset_file_ + ",...)"); }
74
ValidateParams()75 Status MindDataNode::ValidateParams() {
76 RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
77 constexpr size_t max_len = 4096;
78 if (!search_for_pattern_ && dataset_files_.size() > max_len) {
79 std::string err_msg =
80 "MindDataNode: length of dataset_file must be less than or equal to 4096, dataset_file length: " +
81 std::to_string(dataset_file_.size());
82 MS_LOG(ERROR) << err_msg;
83 RETURN_STATUS_SYNTAX_ERROR(err_msg);
84 }
85
86 if (shuffle_mode_ != ShuffleMode::kFalse && shuffle_mode_ != ShuffleMode::kFiles &&
87 shuffle_mode_ != ShuffleMode::kGlobal && shuffle_mode_ != ShuffleMode::kInfile) {
88 std::string err_msg = "TFRecordNode: Invalid ShuffleMode, check input value of enum.";
89 MS_LOG(ERROR) << err_msg;
90 RETURN_STATUS_SYNTAX_ERROR(err_msg);
91 }
92
93 std::vector<std::string> dataset_file_vec =
94 search_for_pattern_ ? std::vector<std::string>{dataset_file_} : dataset_files_;
95 RETURN_IF_NOT_OK(ValidateDatasetFilesParam("MindDataNode", dataset_file_vec));
96
97 RETURN_IF_NOT_OK(ValidateDatasetSampler("MindDataNode", input_sampler_));
98
99 if (!columns_list_.empty()) {
100 RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MindDataNode", "columns_list", columns_list_));
101 }
102
103 if (padded_sample_ != nullptr) {
104 if (num_padded_ < 0 || num_padded_ > INT_MAX) {
105 std::string err_msg =
106 "MindDataNode: num_padded must to be between 0 and INT32_MAX, but got: " + std::to_string(num_padded_);
107 MS_LOG(ERROR) << err_msg;
108 RETURN_STATUS_SYNTAX_ERROR(err_msg);
109 }
110 if (columns_list_.empty()) {
111 std::string err_msg = "MindDataNode: padded_sample is specified and requires columns_list as well";
112 MS_LOG(ERROR) << err_msg;
113 RETURN_STATUS_SYNTAX_ERROR(err_msg);
114 }
115 for (std::string &column : columns_list_) {
116 if (padded_sample_.find(column) == padded_sample_.end()) {
117 std::string err_msg = "MindDataNode: " + column + " in columns_list does not match any column in padded_sample";
118 MS_LOG(ERROR) << err_msg << ", padded_sample: " << padded_sample_;
119 RETURN_STATUS_SYNTAX_ERROR(err_msg);
120 }
121 }
122 }
123 if (num_padded_ > 0) {
124 if (padded_sample_ == nullptr) {
125 std::string err_msg = "MindDataNode: num_padded is specified but padded_sample is not";
126 MS_LOG(ERROR) << err_msg;
127 RETURN_STATUS_SYNTAX_ERROR(err_msg);
128 }
129 }
130
131 return Status::OK();
132 }
133
134 // Helper function to create runtime sampler for minddata dataset
BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerObj> & sampler,std::vector<std::shared_ptr<mindrecord::ShardOperator>> * operators_,int64_t num_padded,ShuffleMode shuffle_mode)135 Status MindDataNode::BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerObj> &sampler,
136 std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators_,
137 int64_t num_padded, ShuffleMode shuffle_mode) {
138 std::shared_ptr<mindrecord::ShardOperator> op = sampler->BuildForMindDataset();
139 if (op == nullptr) {
140 std::string err_msg =
141 "MindDataNode: Unsupported sampler is supplied for MindDataset. Supported sampler list: "
142 "SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler and DistributedSampler";
143 MS_LOG(ERROR) << err_msg;
144 RETURN_STATUS_SYNTAX_ERROR(err_msg);
145 }
146 std::stack<std::shared_ptr<mindrecord::ShardOperator>> stack_ops;
147 while (op != nullptr) {
148 // update the shuffle mode for sampler op or shuffle op
149 if (shuffle_mode != ShuffleMode::kFalse) {
150 op->UpdateShuffleMode(shuffle_mode);
151 }
152
153 auto distributed_sampler_op = std::dynamic_pointer_cast<mindrecord::ShardDistributedSample>(op);
154 if (distributed_sampler_op && num_padded > 0) {
155 distributed_sampler_op->SetNumPaddedSamples(num_padded);
156 stack_ops.push(distributed_sampler_op);
157 } else {
158 stack_ops.push(op);
159 }
160 op = op->GetChildOp();
161 }
162 while (!stack_ops.empty()) {
163 operators_->push_back(stack_ops.top());
164 stack_ops.pop();
165 }
166 return Status::OK();
167 }
168
169 // Helper function to set sample_bytes from py::byte type
SetSampleBytes(std::map<std::string,std::string> * sample_bytes)170 void MindDataNode::SetSampleBytes(std::map<std::string, std::string> *sample_bytes) { sample_bytes_ = *sample_bytes; }
171
Build(std::vector<std::shared_ptr<DatasetOp>> * const node_ops)172 Status MindDataNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
173 RETURN_IF_NOT_OK(BuildMindDatasetSamplerChain(input_sampler_, &operators_, num_padded_, shuffle_mode_));
174
175 std::shared_ptr<SamplerRT> sampler_rt = nullptr;
176 // Build the sampler IR into a runtime sampler.
177 // This will also create a shard reader object, saved in this node's sampler_.
178 RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
179
180 // Now we need to acquire the newly created shard reader from this node's sampler_.
181 // There are two cases:
182 // 1. If this node is cached, now after cache transform pass, its sampler_ has already been replaced by cache lookup
183 // node, and we should find the shard reader from cache lookup node's sampler_.
184 // 2. If this node is not cached, just acquire the shard reader from this node's sampler_.
185 std::unique_ptr<ShardReader> shard_reader;
186 if (IsDescendantOfCache()) {
187 auto cache_lookup_sampler = std::dynamic_pointer_cast<CacheLookupNode>(sampler_);
188 CHECK_FAIL_RETURN_UNEXPECTED(cache_lookup_sampler != nullptr,
189 "Internal error. MindDataNode is cached, its sampler should be cache lookup node");
190 auto mr_sampler = std::dynamic_pointer_cast<MindRecordSamplerObj>(cache_lookup_sampler->Sampler());
191 CHECK_FAIL_RETURN_UNEXPECTED(mr_sampler != nullptr,
192 "Internal error. CacheLookupNode's sampler should be a MindRecordSamplerObj object");
193 RETURN_IF_NOT_OK(mr_sampler->GetShardReader(&shard_reader));
194 } else {
195 auto mr_sampler = std::dynamic_pointer_cast<MindRecordSamplerObj>(sampler_);
196 CHECK_FAIL_RETURN_UNEXPECTED(mr_sampler != nullptr,
197 "Internal error. MindDataNode's sampler should be a MindRecordSamplerObj object");
198 RETURN_IF_NOT_OK(mr_sampler->GetShardReader(&shard_reader));
199 }
200
201 std::shared_ptr<MindRecordOp> mindrecord_op;
202 // If pass a string to MindData(), it will be treated as a pattern to search for matched files,
203 // else if pass a vector to MindData(), it will be treated as specified files to be read
204 if (search_for_pattern_) {
205 std::vector<std::string> dataset_file_vec_ = {dataset_file_};
206 mindrecord_op = std::make_shared<MindRecordOp>(
207 num_workers_, dataset_file_vec_, search_for_pattern_, connector_que_size_, columns_list_, operators_, num_padded_,
208 padded_sample_, sample_bytes_, shuffle_mode_, std::move(shard_reader), std::move(sampler_rt));
209 } else {
210 mindrecord_op = std::make_shared<MindRecordOp>(
211 num_workers_, dataset_files_, search_for_pattern_, connector_que_size_, columns_list_, operators_, num_padded_,
212 padded_sample_, sample_bytes_, shuffle_mode_, std::move(shard_reader), std::move(sampler_rt));
213 }
214
215 RETURN_IF_NOT_OK(mindrecord_op->Init());
216 mindrecord_op->SetTotalRepeats(GetTotalRepeats());
217 mindrecord_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
218 node_ops->push_back(mindrecord_op);
219
220 return Status::OK();
221 }
222
223 // Get the shard id of node
GetShardId(int32_t * shard_id)224 Status MindDataNode::GetShardId(int32_t *shard_id) {
225 *shard_id = input_sampler_->ShardId();
226
227 return Status::OK();
228 }
229
230 // Get Dataset size
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)231 Status MindDataNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
232 int64_t *dataset_size) {
233 if (dataset_size_ > 0) {
234 *dataset_size = dataset_size_;
235 return Status::OK();
236 }
237 int64_t num_rows = -1;
238 std::vector<std::shared_ptr<ShardOperator>> operators;
239 RETURN_IF_NOT_OK(BuildMindDatasetSamplerChain(input_sampler_, &operators, num_padded_, shuffle_mode_));
240
241 if (search_for_pattern_) {
242 dataset_files_ = {dataset_file_};
243 }
244
245 // The last operator is parent sampler
246 std::shared_ptr<ShardOperator> op = operators.back();
247 RETURN_IF_NOT_OK(MindRecordOp::CountTotalRows(dataset_files_, search_for_pattern_, op, &num_rows, num_padded_));
248 *dataset_size = num_rows;
249 dataset_size_ = *dataset_size;
250 return Status::OK();
251 }
252
253 // Visitor accepting method for IRNodePass
Accept(IRNodePass * const p,bool * const modified)254 Status MindDataNode::Accept(IRNodePass *const p, bool *const modified) {
255 // Downcast shared pointer then call visitor
256 return p->Visit(shared_from_base<MindDataNode>(), modified);
257 }
258
259 // Visitor accepting method for IRNodePass
AcceptAfter(IRNodePass * const p,bool * const modified)260 Status MindDataNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
261 // Downcast shared pointer then call visitor
262 return p->VisitAfter(shared_from_base<MindDataNode>(), modified);
263 }
264 } // namespace dataset
265 } // namespace mindspore
266