• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "minddata/dataset/engine/ir/datasetops/source/amazon_review_node.h"
17 
18 namespace mindspore {
19 namespace dataset {
20 // Constructor for AmazonReviewNode
AmazonReviewNode(const std::string & dataset_dir,const std::string & usage,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)21 AmazonReviewNode::AmazonReviewNode(const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
22                                    ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
23                                    const std::shared_ptr<DatasetCache> &cache)
24     : NonMappableSourceNode(std::move(cache)),
25       dataset_dir_(dataset_dir),
26       num_samples_(num_samples),
27       shuffle_(shuffle),
28       num_shards_(num_shards),
29       shard_id_(shard_id),
30       usage_(usage),
31       amazon_review_files_list_(WalkAllFiles(usage, dataset_dir)) {
32   // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass.
33   // User discretion is advised. Auto_num_worker_pass is currently an experimental feature which can still work
34   // if the num_shards_ isn't 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to
35   // return num_shards. Once PreBuildSampler is phased out, this can be cleaned up.
36   GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
37 }
38 
Copy()39 std::shared_ptr<DatasetNode> AmazonReviewNode::Copy() {
40   auto node =
41     std::make_shared<AmazonReviewNode>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
42   (void)node->SetNumWorkers(num_workers_);
43   (void)node->SetConnectorQueueSize(connector_que_size_);
44   return node;
45 }
46 
Print(std::ostream & out) const47 void AmazonReviewNode::Print(std::ostream &out) const {
48   out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") +
49           ", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")");
50 }
51 
ValidateParams()52 Status AmazonReviewNode::ValidateParams() {
53   RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
54   RETURN_IF_NOT_OK(ValidateDatasetDirParam("AmazonReviewDataset", dataset_dir_));
55   RETURN_IF_NOT_OK(ValidateStringValue("AmazonReviewDataset", usage_, {"train", "test", "all"}));
56   RETURN_IF_NOT_OK(ValidateDatasetFilesParam("AmazonReviewDataset", amazon_review_files_list_));
57   RETURN_IF_NOT_OK(ValidateScalar("AmazonReviewDataset", "num_samples", num_samples_, {0}, false));
58   RETURN_IF_NOT_OK(ValidateEnum("AmazonReviewDataset", "ShuffleMode", shuffle_,
59                                 {ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal}));
60 
61   RETURN_IF_NOT_OK(ValidateDatasetShardParams("AmazonReviewDataset", num_shards_, shard_id_));
62   return Status::OK();
63 }
64 
Build(std::vector<std::shared_ptr<DatasetOp>> * const node_ops)65 Status AmazonReviewNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
66   bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
67 
68   // Sort the dataset files in a lexicographical order.
69   std::vector<std::string> sorted_dataset_files = amazon_review_files_list_;
70   std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
71 
72   std::vector<std::shared_ptr<AmazonReviewOp::BaseRecord>> column_default;
73   column_default.push_back(std::make_shared<AmazonReviewOp::Record<std::string>>(AmazonReviewOp::STRING, ""));
74   column_default.push_back(std::make_shared<AmazonReviewOp::Record<std::string>>(AmazonReviewOp::STRING, ""));
75   column_default.push_back(std::make_shared<AmazonReviewOp::Record<std::string>>(AmazonReviewOp::STRING, ""));
76 
77   std::vector<std::string> column_name = {"label", "title", "content"};
78   char field_delim = ',';
79   std::shared_ptr<AmazonReviewOp> amazon_review_op = std::make_shared<AmazonReviewOp>(
80     num_workers_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_,
81     field_delim, column_default, column_name, sorted_dataset_files);
82   RETURN_IF_NOT_OK(amazon_review_op->Init());
83 
84   // If a global shuffle is used for AmazonReview, it will inject a shuffle op over the AmazonReview.
85   // But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be
86   // built.This is achieved in the cache transform pass where we call MakeSimpleProducer to reset AmazonReview's
87   // shuffle option to false.
88   if (shuffle_ == ShuffleMode::kGlobal) {
89     // Inject ShuffleOp.
90     std::shared_ptr<ShuffleOp> shuffle_op = nullptr;
91     int64_t num_rows = 0;
92 
93     // First, get the number of rows in the dataset.
94     RETURN_IF_NOT_OK(AmazonReviewOp::CountAllFileRows(sorted_dataset_files, false, &num_rows));
95     // Add the shuffle op after this op.
96     RETURN_IF_NOT_OK(
97       AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
98     shuffle_op->SetTotalRepeats(GetTotalRepeats());
99     shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
100     shuffle_op->Skip(skip_steps_);
101     node_ops->push_back(shuffle_op);
102   }
103   amazon_review_op->SetTotalRepeats(GetTotalRepeats());
104   amazon_review_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
105   node_ops->push_back(amazon_review_op);
106   return Status::OK();
107 }
108 
GetShardId(int32_t * shard_id)109 Status AmazonReviewNode::GetShardId(int32_t *shard_id) {
110   *shard_id = shard_id_;
111   return Status::OK();
112 }
113 
114 // Get Dataset size
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)115 Status AmazonReviewNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
116                                         int64_t *dataset_size) {
117   if (dataset_size_ > 0) {
118     *dataset_size = dataset_size_;
119     return Status::OK();
120   }
121 
122   int64_t num_rows, sample_size;
123   RETURN_IF_NOT_OK(AmazonReviewOp::CountAllFileRows(amazon_review_files_list_, false, &num_rows));
124   sample_size = num_samples_;
125   num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
126   *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
127   dataset_size_ = *dataset_size;
128   return Status::OK();
129 }
130 
to_json(nlohmann::json * out_json)131 Status AmazonReviewNode::to_json(nlohmann::json *out_json) {
132   nlohmann::json args;
133   args["num_parallel_workers"] = num_workers_;
134   args["connector_queue_size"] = connector_que_size_;
135   args["dataset_dir"] = dataset_dir_;
136   args["usage"] = usage_;
137   args["num_samples"] = num_samples_;
138   args["shuffle"] = shuffle_;
139   args["num_shards"] = num_shards_;
140   args["shard_id"] = shard_id_;
141   if (cache_ != nullptr) {
142     nlohmann::json cache_args;
143     RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
144     args["cache"] = cache_args;
145   }
146   *out_json = args;
147   return Status::OK();
148 }
149 
150 // Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent
151 // class. AmazonReview by itself is a non-mappable dataset that does not support sampling. However, if a cache
152 // operator is injected at some other place higher in the tree, that cache can inherit this sampler from the leaf,
153 // providing sampling support from the caching layer. That is why we setup the sampler for a leaf node that does not
154 // use sampling.
SetupSamplerForCache(std::shared_ptr<SamplerObj> * sampler)155 Status AmazonReviewNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
156   bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
157   *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
158   return Status::OK();
159 }
160 
161 // If a cache has been added into the ascendant tree over this AmazonReview node, then the cache will be executing
162 // a sampler for fetching the data. As such, any options in the AmazonReview node need to be reset to its defaults so
163 // If a cache has been added into the ascendant tree over this AmazonReview node, then the cache will be executing
MakeSimpleProducer()164 Status AmazonReviewNode::MakeSimpleProducer() {
165   shard_id_ = 0;
166   num_shards_ = 1;
167   shuffle_ = ShuffleMode::kFalse;
168   num_samples_ = 0;
169   return Status::OK();
170 }
171 
WalkAllFiles(const std::string & usage,const std::string & dataset_dir)172 std::vector<std::string> AmazonReviewNode::WalkAllFiles(const std::string &usage, const std::string &dataset_dir) {
173   std::vector<std::string> amazon_review_files_list;
174   Path train_prefix("train.csv");
175   Path test_prefix("test.csv");
176   Path dir(dataset_dir);
177 
178   if (usage == "train") {
179     Path temp_path = dir / train_prefix;
180     amazon_review_files_list.push_back(temp_path.ToString());
181   } else if (usage == "test") {
182     Path temp_path = dir / test_prefix;
183     amazon_review_files_list.push_back(temp_path.ToString());
184   } else {
185     Path temp_path = dir / train_prefix;
186     amazon_review_files_list.push_back(temp_path.ToString());
187     Path temp_path1 = dir / test_prefix;
188     amazon_review_files_list.push_back(temp_path1.ToString());
189   }
190   return amazon_review_files_list;
191 }
192 }  // namespace dataset
193 }  // namespace mindspore
194