1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
18
19 #include <stack>
20 #include <utility>
21
22 #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
23 #include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h"
24 #include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h"
25 #include "minddata/dataset/engine/ir/datasetops/source/samplers/mindrecord_sampler_ir.h"
26 #include "minddata/dataset/engine/ir/datasetops/source/samplers/skip_first_epoch_sampler_ir.h"
27 #include "minddata/dataset/engine/opt/pass.h"
28 #include "minddata/dataset/util/status.h"
29
30 namespace mindspore {
31 namespace dataset {
32
MindDataNode(const std::vector<std::string> & dataset_files,const std::vector<std::string> & columns_list,const std::shared_ptr<SamplerObj> & sampler,nlohmann::json padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,std::shared_ptr<DatasetCache> cache)33 MindDataNode::MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list,
34 const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded,
35 ShuffleMode shuffle_mode, std::shared_ptr<DatasetCache> cache)
36 : MappableSourceNode(std::move(cache)),
37 dataset_file_(std::string()),
38 dataset_files_(dataset_files),
39 search_for_pattern_(false),
40 columns_list_(columns_list),
41 input_sampler_(sampler),
42 sampler_(std::make_shared<MindRecordSamplerObj>()),
43 padded_sample_(padded_sample),
44 sample_bytes_({}),
45 num_padded_(num_padded),
46 shuffle_mode_(shuffle_mode) {}
47
MindDataNode(const std::string & dataset_file,const std::vector<std::string> & columns_list,const std::shared_ptr<SamplerObj> & sampler,nlohmann::json padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,std::shared_ptr<DatasetCache> cache)48 MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<std::string> &columns_list,
49 const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded,
50 ShuffleMode shuffle_mode, std::shared_ptr<DatasetCache> cache)
51 : MappableSourceNode(std::move(cache)),
52 dataset_file_(dataset_file),
53 dataset_files_({}),
54 search_for_pattern_(true),
55 columns_list_(columns_list),
56 input_sampler_(sampler),
57 sampler_(std::make_shared<MindRecordSamplerObj>()),
58 padded_sample_(padded_sample),
59 sample_bytes_({}),
60 num_padded_(num_padded),
61 shuffle_mode_(shuffle_mode) {}
62
Copy()63 std::shared_ptr<DatasetNode> MindDataNode::Copy() {
64 std::shared_ptr<MindDataNode> node;
65 std::shared_ptr<SamplerObj> sampler = (input_sampler_ == nullptr) ? nullptr : input_sampler_->SamplerCopy();
66 if (dataset_files_.empty()) {
67 node = std::make_shared<MindDataNode>(dataset_file_, columns_list_, sampler, padded_sample_, num_padded_,
68 shuffle_mode_, cache_);
69 } else {
70 node = std::make_shared<MindDataNode>(dataset_files_, columns_list_, sampler, padded_sample_, num_padded_,
71 shuffle_mode_, cache_);
72 }
73 node->SetSampleBytes(&sample_bytes_);
74 return node;
75 }
76
Print(std::ostream & out) const77 void MindDataNode::Print(std::ostream &out) const { out << (Name() + "(file:" + dataset_file_ + ",...)"); }
78
ValidateParams()79 Status MindDataNode::ValidateParams() {
80 RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
81
82 RETURN_IF_NOT_OK(
83 ValidateEnum("MindDataset", "ShuffleMode", shuffle_mode_,
84 {ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal, ShuffleMode::kInfile}));
85
86 std::vector<std::string> dataset_file_vec =
87 search_for_pattern_ ? std::vector<std::string>{dataset_file_} : dataset_files_;
88 RETURN_IF_NOT_OK(ValidateDatasetFilesParam("MindDataset", dataset_file_vec));
89
90 RETURN_IF_NOT_OK(ValidateDatasetSampler("MindDataset", input_sampler_));
91
92 if (!columns_list_.empty()) {
93 RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MindDataset", "columns_list", columns_list_));
94 }
95
96 if (padded_sample_ != nullptr) {
97 if (num_padded_ < 0 || num_padded_ > INT_MAX) {
98 std::string err_msg =
99 "MindDataset: 'num_padded' must to be between 0 and INT32_MAX, but got: " + std::to_string(num_padded_);
100 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
101 }
102 if (columns_list_.empty()) {
103 std::string err_msg = "MindDataset: 'padded_sample' is specified and requires 'columns_list' as well";
104 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
105 }
106 for (std::string &column : columns_list_) {
107 if (padded_sample_.find(column) == padded_sample_.end()) {
108 std::string err_msg =
109 "MindDataset: " + column + " in 'columns_list' does not match any column in 'padded_sample'";
110 MS_LOG(ERROR) << err_msg << ", padded_sample: " << padded_sample_;
111 return Status(StatusCode::kMDSyntaxError, err_msg);
112 }
113 }
114 }
115 if (num_padded_ > 0) {
116 if (padded_sample_ == nullptr) {
117 std::string err_msg = "MindDataset: num_padded is specified but padded_sample is not";
118 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
119 }
120 }
121
122 return Status::OK();
123 }
124
125 // Helper function to create runtime sampler for minddata dataset
BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerObj> & sampler,std::vector<std::shared_ptr<mindrecord::ShardOperator>> * operators_,int64_t num_padded,ShuffleMode shuffle_mode)126 Status MindDataNode::BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerObj> &sampler,
127 std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators_,
128 int64_t num_padded, ShuffleMode shuffle_mode) {
129 std::shared_ptr<mindrecord::ShardOperator> op = sampler->BuildForMindDataset();
130 if (op == nullptr) {
131 std::string err_msg =
132 "MindDataset: Unsupported sampler is supplied for MindDataset. Supported sampler list: "
133 "SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler and DistributedSampler";
134 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
135 }
136 std::stack<std::shared_ptr<mindrecord::ShardOperator>> stack_ops;
137 while (op != nullptr) {
138 // update the shuffle mode for sampler op or shuffle op
139 if (shuffle_mode != ShuffleMode::kFalse) {
140 op->UpdateShuffleMode(shuffle_mode);
141 }
142 if (op->GetNumSamples() != 0 &&
143 (op->GetShuffleMode() == ShuffleMode::kFiles || op->GetShuffleMode() == ShuffleMode::kInfile)) {
144 std::string err_msg =
145 "MindDataset: Shuffle.kFiles or Shuffle.kInfile and num_samples cannot be specified at the same time.";
146 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
147 }
148
149 auto distributed_sampler_op = std::dynamic_pointer_cast<mindrecord::ShardDistributedSample>(op);
150 if (distributed_sampler_op && num_padded > 0) {
151 distributed_sampler_op->SetNumPaddedSamples(num_padded);
152 stack_ops.push(distributed_sampler_op);
153 } else {
154 stack_ops.push(op);
155 }
156 op = op->GetChildOp();
157 }
158 while (!stack_ops.empty()) {
159 operators_->push_back(stack_ops.top());
160 stack_ops.pop();
161 }
162 return Status::OK();
163 }
164
165 // Helper function to set sample_bytes from py::byte type
SetSampleBytes(std::map<std::string,std::string> * sample_bytes)166 void MindDataNode::SetSampleBytes(std::map<std::string, std::string> *sample_bytes) { sample_bytes_ = *sample_bytes; }
167
Build(std::vector<std::shared_ptr<DatasetOp>> * const node_ops)168 Status MindDataNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
169 RETURN_IF_NOT_OK(BuildMindDatasetSamplerChain(input_sampler_, &operators_, num_padded_, shuffle_mode_));
170
171 std::shared_ptr<SamplerRT> sampler_rt = nullptr;
172 // Build the sampler IR into a runtime sampler.
173 // This will also create a shard reader object, saved in this node's sampler_.
174 RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
175
176 // Now we need to acquire the newly created shard reader from this node's sampler_.
177 // There are two cases:
178 // 1. If this node is cached, now after cache transform pass, its sampler_ has already been replaced by cache lookup
179 // node, and we should find the shard reader from cache lookup node's sampler_.
180 // 2. If this node is not cached, just acquire the shard reader from this node's sampler_.
181 std::unique_ptr<ShardReader> shard_reader;
182 if (IsDescendantOfCache()) {
183 auto cache_lookup_sampler = std::dynamic_pointer_cast<CacheLookupNode>(sampler_);
184 CHECK_FAIL_RETURN_UNEXPECTED(cache_lookup_sampler != nullptr,
185 "Internal error. MindDataNode is cached, its sampler should be cache lookup node");
186 auto mr_sampler = std::dynamic_pointer_cast<MindRecordSamplerObj>(cache_lookup_sampler->Sampler());
187 CHECK_FAIL_RETURN_UNEXPECTED(mr_sampler != nullptr,
188 "Internal error. CacheLookupNode's sampler should be a MindRecordSamplerObj object");
189 RETURN_IF_NOT_OK(mr_sampler->GetShardReader(&shard_reader));
190 } else {
191 auto mindrecord_sampler = sampler_;
192 // In checkpoint recovery, a SkipFirstEpochSampler will be inserted, so MindRecordSampler will be its child
193 if (std::dynamic_pointer_cast<SkipFirstEpochSamplerObj>(mindrecord_sampler) != nullptr) {
194 mindrecord_sampler = mindrecord_sampler->GetChild()[0];
195 }
196 auto mr_sampler = std::dynamic_pointer_cast<MindRecordSamplerObj>(mindrecord_sampler);
197 CHECK_FAIL_RETURN_UNEXPECTED(mr_sampler != nullptr,
198 "Internal error. MindDataNode's sampler should be a MindRecordSamplerObj object");
199 RETURN_IF_NOT_OK(mr_sampler->GetShardReader(&shard_reader));
200 }
201
202 std::shared_ptr<MindRecordOp> mindrecord_op;
203 // If pass a string to MindData(), it will be treated as a pattern to search for matched files,
204 // else if pass a vector to MindData(), it will be treated as specified files to be read
205 if (search_for_pattern_) {
206 std::vector<std::string> dataset_file_vec_ = {dataset_file_};
207 mindrecord_op = std::make_shared<MindRecordOp>(
208 num_workers_, dataset_file_vec_, search_for_pattern_, connector_que_size_, columns_list_, operators_, num_padded_,
209 padded_sample_, sample_bytes_, shuffle_mode_, std::move(shard_reader), std::move(sampler_rt));
210 } else {
211 mindrecord_op = std::make_shared<MindRecordOp>(
212 num_workers_, dataset_files_, search_for_pattern_, connector_que_size_, columns_list_, operators_, num_padded_,
213 padded_sample_, sample_bytes_, shuffle_mode_, std::move(shard_reader), std::move(sampler_rt));
214 }
215
216 RETURN_IF_NOT_OK(mindrecord_op->Init());
217 mindrecord_op->SetTotalRepeats(GetTotalRepeats());
218 mindrecord_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
219 node_ops->push_back(mindrecord_op);
220
221 return Status::OK();
222 }
223
224 // Get the shard id of node
GetShardId(int32_t * shard_id)225 Status MindDataNode::GetShardId(int32_t *shard_id) {
226 *shard_id = input_sampler_->ShardId();
227
228 return Status::OK();
229 }
230
231 // Get Dataset size
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)232 Status MindDataNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
233 int64_t *dataset_size) {
234 if (dataset_size_ > 0) {
235 *dataset_size = dataset_size_;
236 return Status::OK();
237 }
238 int64_t num_rows = -1;
239 std::vector<std::shared_ptr<ShardOperator>> operators;
240 RETURN_IF_NOT_OK(BuildMindDatasetSamplerChain(input_sampler_, &operators, num_padded_, shuffle_mode_));
241
242 if (search_for_pattern_) {
243 dataset_files_ = {dataset_file_};
244 }
245
246 // The last operator is parent sampler
247 std::shared_ptr<ShardOperator> op = operators.back();
248 RETURN_IF_NOT_OK(MindRecordOp::CountTotalRows(dataset_files_, search_for_pattern_, op, &num_rows, num_padded_));
249 *dataset_size = num_rows;
250 dataset_size_ = *dataset_size;
251 return Status::OK();
252 }
253
254 // Visitor accepting method for IRNodePass
Accept(IRNodePass * const p,bool * const modified)255 Status MindDataNode::Accept(IRNodePass *const p, bool *const modified) {
256 // Downcast shared pointer then call visitor
257 return p->Visit(shared_from_base<MindDataNode>(), modified);
258 }
259
260 // Visitor accepting method for IRNodePass
AcceptAfter(IRNodePass * const p,bool * const modified)261 Status MindDataNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
262 // Downcast shared pointer then call visitor
263 return p->VisitAfter(shared_from_base<MindDataNode>(), modified);
264 }
265 } // namespace dataset
266 } // namespace mindspore
267