• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
18 
19 #include <stack>
20 #include <utility>
21 
22 #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
23 #include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h"
24 #include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h"
25 #include "minddata/dataset/engine/ir/datasetops/source/samplers/mindrecord_sampler_ir.h"
26 #include "minddata/dataset/engine/ir/datasetops/source/samplers/skip_first_epoch_sampler_ir.h"
27 #include "minddata/dataset/engine/opt/pass.h"
28 #include "minddata/dataset/util/status.h"
29 
30 namespace mindspore {
31 namespace dataset {
32 
MindDataNode(const std::vector<std::string> & dataset_files,const std::vector<std::string> & columns_list,const std::shared_ptr<SamplerObj> & sampler,nlohmann::json padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,std::shared_ptr<DatasetCache> cache)33 MindDataNode::MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list,
34                            const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded,
35                            ShuffleMode shuffle_mode, std::shared_ptr<DatasetCache> cache)
36     : MappableSourceNode(std::move(cache)),
37       dataset_file_(std::string()),
38       dataset_files_(dataset_files),
39       search_for_pattern_(false),
40       columns_list_(columns_list),
41       input_sampler_(sampler),
42       sampler_(std::make_shared<MindRecordSamplerObj>()),
43       padded_sample_(padded_sample),
44       sample_bytes_({}),
45       num_padded_(num_padded),
46       shuffle_mode_(shuffle_mode) {}
47 
MindDataNode(const std::string & dataset_file,const std::vector<std::string> & columns_list,const std::shared_ptr<SamplerObj> & sampler,nlohmann::json padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,std::shared_ptr<DatasetCache> cache)48 MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<std::string> &columns_list,
49                            const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded,
50                            ShuffleMode shuffle_mode, std::shared_ptr<DatasetCache> cache)
51     : MappableSourceNode(std::move(cache)),
52       dataset_file_(dataset_file),
53       dataset_files_({}),
54       search_for_pattern_(true),
55       columns_list_(columns_list),
56       input_sampler_(sampler),
57       sampler_(std::make_shared<MindRecordSamplerObj>()),
58       padded_sample_(padded_sample),
59       sample_bytes_({}),
60       num_padded_(num_padded),
61       shuffle_mode_(shuffle_mode) {}
62 
Copy()63 std::shared_ptr<DatasetNode> MindDataNode::Copy() {
64   std::shared_ptr<MindDataNode> node;
65   std::shared_ptr<SamplerObj> sampler = (input_sampler_ == nullptr) ? nullptr : input_sampler_->SamplerCopy();
66   if (dataset_files_.empty()) {
67     node = std::make_shared<MindDataNode>(dataset_file_, columns_list_, sampler, padded_sample_, num_padded_,
68                                           shuffle_mode_, cache_);
69   } else {
70     node = std::make_shared<MindDataNode>(dataset_files_, columns_list_, sampler, padded_sample_, num_padded_,
71                                           shuffle_mode_, cache_);
72   }
73   node->SetSampleBytes(&sample_bytes_);
74   return node;
75 }
76 
Print(std::ostream & out) const77 void MindDataNode::Print(std::ostream &out) const { out << (Name() + "(file:" + dataset_file_ + ",...)"); }
78 
ValidateParams()79 Status MindDataNode::ValidateParams() {
80   RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
81 
82   RETURN_IF_NOT_OK(
83     ValidateEnum("MindDataset", "ShuffleMode", shuffle_mode_,
84                  {ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal, ShuffleMode::kInfile}));
85 
86   std::vector<std::string> dataset_file_vec =
87     search_for_pattern_ ? std::vector<std::string>{dataset_file_} : dataset_files_;
88   RETURN_IF_NOT_OK(ValidateDatasetFilesParam("MindDataset", dataset_file_vec));
89 
90   RETURN_IF_NOT_OK(ValidateDatasetSampler("MindDataset", input_sampler_));
91 
92   if (!columns_list_.empty()) {
93     RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MindDataset", "columns_list", columns_list_));
94   }
95 
96   if (padded_sample_ != nullptr) {
97     if (num_padded_ < 0 || num_padded_ > INT_MAX) {
98       std::string err_msg =
99         "MindDataset: 'num_padded' must to be between 0 and INT32_MAX, but got: " + std::to_string(num_padded_);
100       LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
101     }
102     if (columns_list_.empty()) {
103       std::string err_msg = "MindDataset: 'padded_sample' is specified and requires 'columns_list' as well";
104       LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
105     }
106     for (std::string &column : columns_list_) {
107       if (padded_sample_.find(column) == padded_sample_.end()) {
108         std::string err_msg =
109           "MindDataset: " + column + " in 'columns_list' does not match any column in 'padded_sample'";
110         MS_LOG(ERROR) << err_msg << ", padded_sample: " << padded_sample_;
111         return Status(StatusCode::kMDSyntaxError, err_msg);
112       }
113     }
114   }
115   if (num_padded_ > 0) {
116     if (padded_sample_ == nullptr) {
117       std::string err_msg = "MindDataset: num_padded is specified but padded_sample is not";
118       LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
119     }
120   }
121 
122   return Status::OK();
123 }
124 
125 // Helper function to create runtime sampler for minddata dataset
BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerObj> & sampler,std::vector<std::shared_ptr<mindrecord::ShardOperator>> * operators_,int64_t num_padded,ShuffleMode shuffle_mode)126 Status MindDataNode::BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerObj> &sampler,
127                                                   std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators_,
128                                                   int64_t num_padded, ShuffleMode shuffle_mode) {
129   std::shared_ptr<mindrecord::ShardOperator> op = sampler->BuildForMindDataset();
130   if (op == nullptr) {
131     std::string err_msg =
132       "MindDataset: Unsupported sampler is supplied for MindDataset. Supported sampler list: "
133       "SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler and DistributedSampler";
134     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
135   }
136   std::stack<std::shared_ptr<mindrecord::ShardOperator>> stack_ops;
137   while (op != nullptr) {
138     // update the shuffle mode for sampler op or shuffle op
139     if (shuffle_mode != ShuffleMode::kFalse) {
140       op->UpdateShuffleMode(shuffle_mode);
141     }
142     if (op->GetNumSamples() != 0 &&
143         (op->GetShuffleMode() == ShuffleMode::kFiles || op->GetShuffleMode() == ShuffleMode::kInfile)) {
144       std::string err_msg =
145         "MindDataset: Shuffle.kFiles or Shuffle.kInfile and num_samples cannot be specified at the same time.";
146       LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
147     }
148 
149     auto distributed_sampler_op = std::dynamic_pointer_cast<mindrecord::ShardDistributedSample>(op);
150     if (distributed_sampler_op && num_padded > 0) {
151       distributed_sampler_op->SetNumPaddedSamples(num_padded);
152       stack_ops.push(distributed_sampler_op);
153     } else {
154       stack_ops.push(op);
155     }
156     op = op->GetChildOp();
157   }
158   while (!stack_ops.empty()) {
159     operators_->push_back(stack_ops.top());
160     stack_ops.pop();
161   }
162   return Status::OK();
163 }
164 
165 // Helper function to set sample_bytes from py::byte type
SetSampleBytes(std::map<std::string,std::string> * sample_bytes)166 void MindDataNode::SetSampleBytes(std::map<std::string, std::string> *sample_bytes) { sample_bytes_ = *sample_bytes; }
167 
Build(std::vector<std::shared_ptr<DatasetOp>> * const node_ops)168 Status MindDataNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
169   RETURN_IF_NOT_OK(BuildMindDatasetSamplerChain(input_sampler_, &operators_, num_padded_, shuffle_mode_));
170 
171   std::shared_ptr<SamplerRT> sampler_rt = nullptr;
172   // Build the sampler IR into a runtime sampler.
173   // This will also create a shard reader object, saved in this node's sampler_.
174   RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
175 
176   // Now we need to acquire the newly created shard reader from this node's sampler_.
177   // There are two cases:
178   // 1. If this node is cached, now after cache transform pass, its sampler_ has already been replaced by cache lookup
179   // node, and we should find the shard reader from cache lookup node's sampler_.
180   // 2. If this node is not cached, just acquire the shard reader from this node's sampler_.
181   std::unique_ptr<ShardReader> shard_reader;
182   if (IsDescendantOfCache()) {
183     auto cache_lookup_sampler = std::dynamic_pointer_cast<CacheLookupNode>(sampler_);
184     CHECK_FAIL_RETURN_UNEXPECTED(cache_lookup_sampler != nullptr,
185                                  "Internal error. MindDataNode is cached, its sampler should be cache lookup node");
186     auto mr_sampler = std::dynamic_pointer_cast<MindRecordSamplerObj>(cache_lookup_sampler->Sampler());
187     CHECK_FAIL_RETURN_UNEXPECTED(mr_sampler != nullptr,
188                                  "Internal error. CacheLookupNode's sampler should be a MindRecordSamplerObj object");
189     RETURN_IF_NOT_OK(mr_sampler->GetShardReader(&shard_reader));
190   } else {
191     auto mindrecord_sampler = sampler_;
192     // In checkpoint recovery, a SkipFirstEpochSampler will be inserted, so MindRecordSampler will be its child
193     if (std::dynamic_pointer_cast<SkipFirstEpochSamplerObj>(mindrecord_sampler) != nullptr) {
194       mindrecord_sampler = mindrecord_sampler->GetChild()[0];
195     }
196     auto mr_sampler = std::dynamic_pointer_cast<MindRecordSamplerObj>(mindrecord_sampler);
197     CHECK_FAIL_RETURN_UNEXPECTED(mr_sampler != nullptr,
198                                  "Internal error. MindDataNode's sampler should be a MindRecordSamplerObj object");
199     RETURN_IF_NOT_OK(mr_sampler->GetShardReader(&shard_reader));
200   }
201 
202   std::shared_ptr<MindRecordOp> mindrecord_op;
203   // If pass a string to MindData(), it will be treated as a pattern to search for matched files,
204   // else if pass a vector to MindData(), it will be treated as specified files to be read
205   if (search_for_pattern_) {
206     std::vector<std::string> dataset_file_vec_ = {dataset_file_};
207     mindrecord_op = std::make_shared<MindRecordOp>(
208       num_workers_, dataset_file_vec_, search_for_pattern_, connector_que_size_, columns_list_, operators_, num_padded_,
209       padded_sample_, sample_bytes_, shuffle_mode_, std::move(shard_reader), std::move(sampler_rt));
210   } else {
211     mindrecord_op = std::make_shared<MindRecordOp>(
212       num_workers_, dataset_files_, search_for_pattern_, connector_que_size_, columns_list_, operators_, num_padded_,
213       padded_sample_, sample_bytes_, shuffle_mode_, std::move(shard_reader), std::move(sampler_rt));
214   }
215 
216   RETURN_IF_NOT_OK(mindrecord_op->Init());
217   mindrecord_op->SetTotalRepeats(GetTotalRepeats());
218   mindrecord_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
219   node_ops->push_back(mindrecord_op);
220 
221   return Status::OK();
222 }
223 
224 // Get the shard id of node
GetShardId(int32_t * shard_id)225 Status MindDataNode::GetShardId(int32_t *shard_id) {
226   *shard_id = input_sampler_->ShardId();
227 
228   return Status::OK();
229 }
230 
231 // Get Dataset size
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)232 Status MindDataNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
233                                     int64_t *dataset_size) {
234   if (dataset_size_ > 0) {
235     *dataset_size = dataset_size_;
236     return Status::OK();
237   }
238   int64_t num_rows = -1;
239   std::vector<std::shared_ptr<ShardOperator>> operators;
240   RETURN_IF_NOT_OK(BuildMindDatasetSamplerChain(input_sampler_, &operators, num_padded_, shuffle_mode_));
241 
242   if (search_for_pattern_) {
243     dataset_files_ = {dataset_file_};
244   }
245 
246   // The last operator is parent sampler
247   std::shared_ptr<ShardOperator> op = operators.back();
248   RETURN_IF_NOT_OK(MindRecordOp::CountTotalRows(dataset_files_, search_for_pattern_, op, &num_rows, num_padded_));
249   *dataset_size = num_rows;
250   dataset_size_ = *dataset_size;
251   return Status::OK();
252 }
253 
254 // Visitor accepting method for IRNodePass
Accept(IRNodePass * const p,bool * const modified)255 Status MindDataNode::Accept(IRNodePass *const p, bool *const modified) {
256   // Downcast shared pointer then call visitor
257   return p->Visit(shared_from_base<MindDataNode>(), modified);
258 }
259 
260 // Visitor accepting method for IRNodePass
AcceptAfter(IRNodePass * const p,bool * const modified)261 Status MindDataNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
262   // Downcast shared pointer then call visitor
263   return p->VisitAfter(shared_from_base<MindDataNode>(), modified);
264 }
265 }  // namespace dataset
266 }  // namespace mindspore
267