• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
18 
19 #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
20 #include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h"
21 #include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h"
22 #include "minddata/dataset/engine/ir/datasetops/source/samplers/mindrecord_sampler_ir.h"
23 #include "minddata/dataset/engine/opt/pass.h"
24 #include "minddata/dataset/util/status.h"
25 
26 namespace mindspore {
27 namespace dataset {
28 
MindDataNode(const std::vector<std::string> & dataset_files,const std::vector<std::string> & columns_list,const std::shared_ptr<SamplerObj> & sampler,nlohmann::json padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,std::shared_ptr<DatasetCache> cache)29 MindDataNode::MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list,
30                            const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded,
31                            ShuffleMode shuffle_mode, std::shared_ptr<DatasetCache> cache)
32     : MappableSourceNode(std::move(cache)),
33       dataset_file_(std::string()),
34       dataset_files_(dataset_files),
35       search_for_pattern_(false),
36       columns_list_(columns_list),
37       input_sampler_(sampler),
38       sampler_(std::make_shared<MindRecordSamplerObj>()),
39       padded_sample_(padded_sample),
40       sample_bytes_({}),
41       num_padded_(num_padded),
42       shuffle_mode_(shuffle_mode) {}
43 
MindDataNode(const std::string & dataset_file,const std::vector<std::string> & columns_list,const std::shared_ptr<SamplerObj> & sampler,nlohmann::json padded_sample,int64_t num_padded,ShuffleMode shuffle_mode,std::shared_ptr<DatasetCache> cache)44 MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<std::string> &columns_list,
45                            const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded,
46                            ShuffleMode shuffle_mode, std::shared_ptr<DatasetCache> cache)
47     : MappableSourceNode(std::move(cache)),
48       dataset_file_(dataset_file),
49       dataset_files_({}),
50       search_for_pattern_(true),
51       columns_list_(columns_list),
52       input_sampler_(sampler),
53       sampler_(std::make_shared<MindRecordSamplerObj>()),
54       padded_sample_(padded_sample),
55       sample_bytes_({}),
56       num_padded_(num_padded),
57       shuffle_mode_(shuffle_mode) {}
58 
Copy()59 std::shared_ptr<DatasetNode> MindDataNode::Copy() {
60   std::shared_ptr<MindDataNode> node;
61   std::shared_ptr<SamplerObj> sampler = (input_sampler_ == nullptr) ? nullptr : input_sampler_->SamplerCopy();
62   if (dataset_files_.empty()) {
63     node = std::make_shared<MindDataNode>(dataset_file_, columns_list_, sampler, padded_sample_, num_padded_,
64                                           shuffle_mode_, cache_);
65   } else {
66     node = std::make_shared<MindDataNode>(dataset_files_, columns_list_, sampler, padded_sample_, num_padded_,
67                                           shuffle_mode_, cache_);
68   }
69   node->SetSampleBytes(&sample_bytes_);
70   return node;
71 }
72 
Print(std::ostream & out) const73 void MindDataNode::Print(std::ostream &out) const { out << (Name() + "(file:" + dataset_file_ + ",...)"); }
74 
ValidateParams()75 Status MindDataNode::ValidateParams() {
76   RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
77   constexpr size_t max_len = 4096;
78   if (!search_for_pattern_ && dataset_files_.size() > max_len) {
79     std::string err_msg =
80       "MindDataNode: length of dataset_file must be less than or equal to 4096, dataset_file length: " +
81       std::to_string(dataset_file_.size());
82     MS_LOG(ERROR) << err_msg;
83     RETURN_STATUS_SYNTAX_ERROR(err_msg);
84   }
85 
86   if (shuffle_mode_ != ShuffleMode::kFalse && shuffle_mode_ != ShuffleMode::kFiles &&
87       shuffle_mode_ != ShuffleMode::kGlobal && shuffle_mode_ != ShuffleMode::kInfile) {
88     std::string err_msg = "TFRecordNode: Invalid ShuffleMode, check input value of enum.";
89     MS_LOG(ERROR) << err_msg;
90     RETURN_STATUS_SYNTAX_ERROR(err_msg);
91   }
92 
93   std::vector<std::string> dataset_file_vec =
94     search_for_pattern_ ? std::vector<std::string>{dataset_file_} : dataset_files_;
95   RETURN_IF_NOT_OK(ValidateDatasetFilesParam("MindDataNode", dataset_file_vec));
96 
97   RETURN_IF_NOT_OK(ValidateDatasetSampler("MindDataNode", input_sampler_));
98 
99   if (!columns_list_.empty()) {
100     RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MindDataNode", "columns_list", columns_list_));
101   }
102 
103   if (padded_sample_ != nullptr) {
104     if (num_padded_ < 0 || num_padded_ > INT_MAX) {
105       std::string err_msg =
106         "MindDataNode: num_padded must to be between 0 and INT32_MAX, but got: " + std::to_string(num_padded_);
107       MS_LOG(ERROR) << err_msg;
108       RETURN_STATUS_SYNTAX_ERROR(err_msg);
109     }
110     if (columns_list_.empty()) {
111       std::string err_msg = "MindDataNode: padded_sample is specified and requires columns_list as well";
112       MS_LOG(ERROR) << err_msg;
113       RETURN_STATUS_SYNTAX_ERROR(err_msg);
114     }
115     for (std::string &column : columns_list_) {
116       if (padded_sample_.find(column) == padded_sample_.end()) {
117         std::string err_msg = "MindDataNode: " + column + " in columns_list does not match any column in padded_sample";
118         MS_LOG(ERROR) << err_msg << ", padded_sample: " << padded_sample_;
119         RETURN_STATUS_SYNTAX_ERROR(err_msg);
120       }
121     }
122   }
123   if (num_padded_ > 0) {
124     if (padded_sample_ == nullptr) {
125       std::string err_msg = "MindDataNode: num_padded is specified but padded_sample is not";
126       MS_LOG(ERROR) << err_msg;
127       RETURN_STATUS_SYNTAX_ERROR(err_msg);
128     }
129   }
130 
131   return Status::OK();
132 }
133 
134 // Helper function to create runtime sampler for minddata dataset
BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerObj> & sampler,std::vector<std::shared_ptr<mindrecord::ShardOperator>> * operators_,int64_t num_padded,ShuffleMode shuffle_mode)135 Status MindDataNode::BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerObj> &sampler,
136                                                   std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators_,
137                                                   int64_t num_padded, ShuffleMode shuffle_mode) {
138   std::shared_ptr<mindrecord::ShardOperator> op = sampler->BuildForMindDataset();
139   if (op == nullptr) {
140     std::string err_msg =
141       "MindDataNode: Unsupported sampler is supplied for MindDataset. Supported sampler list: "
142       "SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler and DistributedSampler";
143     MS_LOG(ERROR) << err_msg;
144     RETURN_STATUS_SYNTAX_ERROR(err_msg);
145   }
146   std::stack<std::shared_ptr<mindrecord::ShardOperator>> stack_ops;
147   while (op != nullptr) {
148     // update the shuffle mode for sampler op or shuffle op
149     if (shuffle_mode != ShuffleMode::kFalse) {
150       op->UpdateShuffleMode(shuffle_mode);
151     }
152 
153     auto distributed_sampler_op = std::dynamic_pointer_cast<mindrecord::ShardDistributedSample>(op);
154     if (distributed_sampler_op && num_padded > 0) {
155       distributed_sampler_op->SetNumPaddedSamples(num_padded);
156       stack_ops.push(distributed_sampler_op);
157     } else {
158       stack_ops.push(op);
159     }
160     op = op->GetChildOp();
161   }
162   while (!stack_ops.empty()) {
163     operators_->push_back(stack_ops.top());
164     stack_ops.pop();
165   }
166   return Status::OK();
167 }
168 
169 // Helper function to set sample_bytes from py::byte type
SetSampleBytes(std::map<std::string,std::string> * sample_bytes)170 void MindDataNode::SetSampleBytes(std::map<std::string, std::string> *sample_bytes) { sample_bytes_ = *sample_bytes; }
171 
Build(std::vector<std::shared_ptr<DatasetOp>> * const node_ops)172 Status MindDataNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
173   RETURN_IF_NOT_OK(BuildMindDatasetSamplerChain(input_sampler_, &operators_, num_padded_, shuffle_mode_));
174 
175   std::shared_ptr<SamplerRT> sampler_rt = nullptr;
176   // Build the sampler IR into a runtime sampler.
177   // This will also create a shard reader object, saved in this node's sampler_.
178   RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
179 
180   // Now we need to acquire the newly created shard reader from this node's sampler_.
181   // There are two cases:
182   // 1. If this node is cached, now after cache transform pass, its sampler_ has already been replaced by cache lookup
183   // node, and we should find the shard reader from cache lookup node's sampler_.
184   // 2. If this node is not cached, just acquire the shard reader from this node's sampler_.
185   std::unique_ptr<ShardReader> shard_reader;
186   if (IsDescendantOfCache()) {
187     auto cache_lookup_sampler = std::dynamic_pointer_cast<CacheLookupNode>(sampler_);
188     CHECK_FAIL_RETURN_UNEXPECTED(cache_lookup_sampler != nullptr,
189                                  "Internal error. MindDataNode is cached, its sampler should be cache lookup node");
190     auto mr_sampler = std::dynamic_pointer_cast<MindRecordSamplerObj>(cache_lookup_sampler->Sampler());
191     CHECK_FAIL_RETURN_UNEXPECTED(mr_sampler != nullptr,
192                                  "Internal error. CacheLookupNode's sampler should be a MindRecordSamplerObj object");
193     RETURN_IF_NOT_OK(mr_sampler->GetShardReader(&shard_reader));
194   } else {
195     auto mr_sampler = std::dynamic_pointer_cast<MindRecordSamplerObj>(sampler_);
196     CHECK_FAIL_RETURN_UNEXPECTED(mr_sampler != nullptr,
197                                  "Internal error. MindDataNode's sampler should be a MindRecordSamplerObj object");
198     RETURN_IF_NOT_OK(mr_sampler->GetShardReader(&shard_reader));
199   }
200 
201   std::shared_ptr<MindRecordOp> mindrecord_op;
202   // If pass a string to MindData(), it will be treated as a pattern to search for matched files,
203   // else if pass a vector to MindData(), it will be treated as specified files to be read
204   if (search_for_pattern_) {
205     std::vector<std::string> dataset_file_vec_ = {dataset_file_};
206     mindrecord_op = std::make_shared<MindRecordOp>(
207       num_workers_, dataset_file_vec_, search_for_pattern_, connector_que_size_, columns_list_, operators_, num_padded_,
208       padded_sample_, sample_bytes_, shuffle_mode_, std::move(shard_reader), std::move(sampler_rt));
209   } else {
210     mindrecord_op = std::make_shared<MindRecordOp>(
211       num_workers_, dataset_files_, search_for_pattern_, connector_que_size_, columns_list_, operators_, num_padded_,
212       padded_sample_, sample_bytes_, shuffle_mode_, std::move(shard_reader), std::move(sampler_rt));
213   }
214 
215   RETURN_IF_NOT_OK(mindrecord_op->Init());
216   mindrecord_op->SetTotalRepeats(GetTotalRepeats());
217   mindrecord_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
218   node_ops->push_back(mindrecord_op);
219 
220   return Status::OK();
221 }
222 
223 // Get the shard id of node
GetShardId(int32_t * shard_id)224 Status MindDataNode::GetShardId(int32_t *shard_id) {
225   *shard_id = input_sampler_->ShardId();
226 
227   return Status::OK();
228 }
229 
230 // Get Dataset size
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)231 Status MindDataNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
232                                     int64_t *dataset_size) {
233   if (dataset_size_ > 0) {
234     *dataset_size = dataset_size_;
235     return Status::OK();
236   }
237   int64_t num_rows = -1;
238   std::vector<std::shared_ptr<ShardOperator>> operators;
239   RETURN_IF_NOT_OK(BuildMindDatasetSamplerChain(input_sampler_, &operators, num_padded_, shuffle_mode_));
240 
241   if (search_for_pattern_) {
242     dataset_files_ = {dataset_file_};
243   }
244 
245   // The last operator is parent sampler
246   std::shared_ptr<ShardOperator> op = operators.back();
247   RETURN_IF_NOT_OK(MindRecordOp::CountTotalRows(dataset_files_, search_for_pattern_, op, &num_rows, num_padded_));
248   *dataset_size = num_rows;
249   dataset_size_ = *dataset_size;
250   return Status::OK();
251 }
252 
253 // Visitor accepting method for IRNodePass
Accept(IRNodePass * const p,bool * const modified)254 Status MindDataNode::Accept(IRNodePass *const p, bool *const modified) {
255   // Downcast shared pointer then call visitor
256   return p->Visit(shared_from_base<MindDataNode>(), modified);
257 }
258 
259 // Visitor accepting method for IRNodePass
AcceptAfter(IRNodePass * const p,bool * const modified)260 Status MindDataNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
261   // Downcast shared pointer then call visitor
262   return p->VisitAfter(shared_from_base<MindDataNode>(), modified);
263 }
264 }  // namespace dataset
265 }  // namespace mindspore
266