1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/ir/datasetops/source/wiki_text_node.h"
18
19 #include "minddata/dataset/engine/datasetops/source/wiki_text_op.h"
20 #include "minddata/dataset/util/status.h"
21
22 namespace mindspore {
23 namespace dataset {
24 // Constructor for WikiTextNode.
WikiTextNode(const std::string & dataset_dir,const std::string & usage,int64_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)25 WikiTextNode::WikiTextNode(const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
26 ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
27 const std::shared_ptr<DatasetCache> &cache)
28 : NonMappableSourceNode(std::move(cache)),
29 dataset_dir_(dataset_dir),
30 usage_(usage),
31 num_samples_(num_samples),
32 shuffle_(shuffle),
33 num_shards_(num_shards),
34 shard_id_(shard_id),
35 wikitext_files_list_(WalkAllFiles(usage, dataset_dir)) {
36 // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion
37 // is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't
38 // 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once
39 // PreBuildSampler is phased out, this can be cleaned up.
40 GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
41 }
42
Copy()43 std::shared_ptr<DatasetNode> WikiTextNode::Copy() {
44 auto node =
45 std::make_shared<WikiTextNode>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
46 (void)node->SetNumWorkers(num_workers_);
47 (void)node->SetConnectorQueueSize(connector_que_size_);
48 return node;
49 }
50
Print(std::ostream & out) const51 void WikiTextNode::Print(std::ostream &out) const {
52 out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") +
53 ", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")");
54 }
55
ValidateParams()56 Status WikiTextNode::ValidateParams() {
57 RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
58 RETURN_IF_NOT_OK(ValidateDatasetDirParam("WikiTextDataset", dataset_dir_));
59 RETURN_IF_NOT_OK(ValidateStringValue("WikiTextDataset", usage_, {"train", "test", "valid", "all"}));
60 RETURN_IF_NOT_OK(ValidateScalar("WikiTextDataset", "num_samples", num_samples_, {0}, false));
61 RETURN_IF_NOT_OK(ValidateEnum("WikiTextDataset", "ShuffleMode", shuffle_,
62 {ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal}));
63 RETURN_IF_NOT_OK(ValidateDatasetShardParams("WikiTextDataset", num_shards_, shard_id_));
64 return Status::OK();
65 }
66
67 // Function to build WikiTextNode.
Build(std::vector<std::shared_ptr<DatasetOp>> * const node_ops)68 Status WikiTextNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
69 bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
70 // Sort the dataset files in a lexicographical order.
71 std::vector<std::string> sorted_dataset_files = wikitext_files_list_;
72 std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
73 // Do internal Schema generation.
74 auto schema = std::make_unique<DataSchema>();
75 RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
76 // Create and initialize WikiTextNode.
77 std::shared_ptr<WikiTextOp> wikitext_op =
78 std::make_shared<WikiTextOp>(num_workers_, num_samples_, worker_connector_size_, std::move(schema),
79 sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_);
80 RETURN_IF_NOT_OK(wikitext_op->Init());
81 // If a global shuffle is used for WikiText, it will inject a shuffle op over the WikiText.
82 // But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be built.
83 // This is achieved in the cache transform pass where we call MakeSimpleProducer to reset WikiText's shuffle
84 // option to false.
85 if (shuffle_ == ShuffleMode::kGlobal) {
86 // Inject ShuffleOp.
87 std::shared_ptr<ShuffleOp> shuffle_op = nullptr;
88 int64_t num_rows = 0;
89 // First, get the number of rows in the dataset.
90 RETURN_IF_NOT_OK(WikiTextOp::CountAllFileRows(wikitext_files_list_, &num_rows));
91 // Add the shuffle op after this op.
92 RETURN_IF_NOT_OK(
93 AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
94 shuffle_op->SetTotalRepeats(GetTotalRepeats());
95 shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
96 shuffle_op->Skip(skip_steps_);
97 node_ops->push_back(shuffle_op);
98 }
99 wikitext_op->SetTotalRepeats(GetTotalRepeats());
100 wikitext_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
101 // Add WikiTextNode.
102 node_ops->push_back(wikitext_op);
103 return Status::OK();
104 }
105
106 // Get the shard id of node.
GetShardId(int32_t * shard_id)107 Status WikiTextNode::GetShardId(int32_t *shard_id) {
108 *shard_id = shard_id_;
109 return Status::OK();
110 }
111
112 // Get Dataset size.
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)113 Status WikiTextNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
114 int64_t *dataset_size) {
115 if (dataset_size_ > 0) {
116 *dataset_size = dataset_size_;
117 return Status::OK();
118 }
119 int64_t num_rows, sample_size = num_samples_;
120 RETURN_IF_NOT_OK(WikiTextOp::CountAllFileRows(wikitext_files_list_, &num_rows));
121 num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
122 *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
123 dataset_size_ = *dataset_size;
124 return Status::OK();
125 }
126
to_json(nlohmann::json * out_json)127 Status WikiTextNode::to_json(nlohmann::json *out_json) {
128 nlohmann::json args;
129 args["num_parallel_workers"] = num_workers_;
130 args["connector_queue_size"] = connector_que_size_;
131 args["dataset_dir"] = dataset_dir_;
132 args["usage"] = usage_;
133 args["num_samples"] = num_samples_;
134 args["shuffle"] = shuffle_;
135 args["num_shards"] = num_shards_;
136 args["shard_id"] = shard_id_;
137 if (cache_ != nullptr) {
138 nlohmann::json cache_args;
139 RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
140 args["cache"] = cache_args;
141 }
142 *out_json = args;
143 return Status::OK();
144 }
145
// Note: The following two functions are common among NonMappableSourceNode subclasses and should be promoted to the
// parent class.
// WikiText by itself is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we set up the sampler for a leaf node that does not use sampling.
SetupSamplerForCache(std::shared_ptr<SamplerObj> * sampler)151 Status WikiTextNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
152 bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
153 *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
154 return Status::OK();
155 }
156
157 // If a cache has been added into the ascendant tree over this WikiText node, then the cache will be executing
158 // a sampler for fetching the data. As such, any options in the WikiText node need to be reset to its defaults so
159 // that this WikiText node will produce the full set of data into the cache.
MakeSimpleProducer()160 Status WikiTextNode::MakeSimpleProducer() {
161 shard_id_ = 0;
162 num_shards_ = 1;
163 shuffle_ = ShuffleMode::kFalse;
164 num_samples_ = 0;
165 return Status::OK();
166 }
167
WalkAllFiles(const std::string & usage,const std::string & dataset_dir)168 std::vector<std::string> WikiTextNode::WalkAllFiles(const std::string &usage, const std::string &dataset_dir) {
169 std::vector<std::string> wikitext_files_list;
170 Path train_prefix("wiki.train.tokens");
171 Path test_prefix("wiki.test.tokens");
172 Path valid_prefix("wiki.valid.tokens");
173 Path dir(dataset_dir);
174
175 if (usage == "train") {
176 Path temp_path = dir / train_prefix;
177 wikitext_files_list.push_back(temp_path.ToString());
178 } else if (usage == "test") {
179 Path temp_path = dir / test_prefix;
180 wikitext_files_list.push_back(temp_path.ToString());
181 } else if (usage == "valid") {
182 Path temp_path = dir / valid_prefix;
183 wikitext_files_list.push_back(temp_path.ToString());
184 } else {
185 Path temp_path = dir / train_prefix;
186 wikitext_files_list.push_back(temp_path.ToString());
187 Path temp_path1 = dir / test_prefix;
188 wikitext_files_list.push_back(temp_path1.ToString());
189 Path temp_path2 = dir / valid_prefix;
190 wikitext_files_list.push_back(temp_path2.ToString());
191 }
192 return wikitext_files_list;
193 }
194 } // namespace dataset
195 } // namespace mindspore
196