• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/ir/datasetops/source/multi30k_node.h"
18 
19 #include "minddata/dataset/engine/datasetops/source/multi30k_op.h"
20 
21 namespace mindspore {
22 namespace dataset {
Multi30kNode(const std::string & dataset_dir,const std::string & usage,const std::vector<std::string> & language_pair,int32_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,std::shared_ptr<DatasetCache> cache)23 Multi30kNode::Multi30kNode(const std::string &dataset_dir, const std::string &usage,
24                            const std::vector<std::string> &language_pair, int32_t num_samples, ShuffleMode shuffle,
25                            int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
26     : NonMappableSourceNode(std::move(cache)),
27       dataset_dir_(dataset_dir),
28       usage_(usage),
29       language_pair_(language_pair),
30       num_samples_(num_samples),
31       shuffle_(shuffle),
32       num_shards_(num_shards),
33       shard_id_(shard_id),
34       multi30k_files_list_(WalkAllFiles(usage, dataset_dir)) {
35   GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
36 }
37 
Print(std::ostream & out) const38 void Multi30kNode::Print(std::ostream &out) const {
39   out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") +
40           ", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")");
41 }
42 
Copy()43 std::shared_ptr<DatasetNode> Multi30kNode::Copy() {
44   auto node = std::make_shared<Multi30kNode>(dataset_dir_, usage_, language_pair_, num_samples_, shuffle_, num_shards_,
45                                              shard_id_, cache_);
46   (void)node->SetNumWorkers(num_workers_);
47   (void)node->SetConnectorQueueSize(connector_que_size_);
48   return node;
49 }
50 
51 // Function to build Multi30kNode
Build(std::vector<std::shared_ptr<DatasetOp>> * const node_ops)52 Status Multi30kNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
53   bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
54 
55   std::vector<std::string> sorted_dataset_files = multi30k_files_list_;
56   std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
57 
58   auto schema = std::make_unique<DataSchema>();
59   RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 1)));
60   RETURN_IF_NOT_OK(
61     schema->AddColumn(ColDescriptor("translation", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 1)));
62 
63   std::shared_ptr<Multi30kOp> multi30k_op =
64     std::make_shared<Multi30kOp>(num_workers_, num_samples_, language_pair_, worker_connector_size_, std::move(schema),
65                                  sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_);
66   RETURN_IF_NOT_OK(multi30k_op->Init());
67 
68   if (shuffle_ == ShuffleMode::kGlobal) {
69     // Inject ShuffleOp
70     std::shared_ptr<ShuffleOp> shuffle_op = nullptr;
71     int64_t num_rows = 0;
72 
73     // First, get the number of rows in the dataset
74     RETURN_IF_NOT_OK(Multi30kOp::CountAllFileRows(sorted_dataset_files, &num_rows));
75 
76     // Add the shuffle op after this op
77     RETURN_IF_NOT_OK(
78       AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
79     shuffle_op->SetTotalRepeats(GetTotalRepeats());
80     shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
81     shuffle_op->Skip(skip_steps_);
82     node_ops->push_back(shuffle_op);
83   }
84   multi30k_op->SetTotalRepeats(GetTotalRepeats());
85   multi30k_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
86   // Add Multi30kOp
87   node_ops->push_back(multi30k_op);
88 
89   return Status::OK();
90 }
91 
ValidateParams()92 Status Multi30kNode::ValidateParams() {
93   RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
94   RETURN_IF_NOT_OK(ValidateDatasetDirParam("Multi30kDataset", dataset_dir_));
95   RETURN_IF_NOT_OK(ValidateDatasetFilesParam("Multi30kDataset", multi30k_files_list_));
96   RETURN_IF_NOT_OK(ValidateStringValue("Multi30kDataset", usage_, {"train", "valid", "test", "all"}));
97   RETURN_IF_NOT_OK(ValidateEnum("Multi30kDataset", "ShuffleMode", shuffle_,
98                                 {ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal}));
99 
100   const int kLanguagePairSize = 2;
101   if (language_pair_.size() != kLanguagePairSize) {
102     std::string err_msg =
103       "Multi30kDataset: language_pair expecting size 2, but got: " + std::to_string(language_pair_.size());
104     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
105   }
106 
107   const std::vector<std::vector<std::string>> support_language_pair = {{"en", "de"}, {"de", "en"}};
108   if (language_pair_ != support_language_pair[0] && language_pair_ != support_language_pair[1]) {
109     std::string err_msg = R"(Multi30kDataset: language_pair must be {"en", "de"} or {"de", "en"}.)";
110     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
111   }
112 
113   RETURN_IF_NOT_OK(ValidateScalar("Multi30kDataset", "num_samples", num_samples_, {0}, false));
114   RETURN_IF_NOT_OK(ValidateDatasetShardParams("Multi30kDataset", num_shards_, shard_id_));
115   return Status::OK();
116 }
117 
GetShardId(int32_t * shard_id)118 Status Multi30kNode::GetShardId(int32_t *shard_id) {
119   *shard_id = shard_id_;
120   return Status::OK();
121 }
122 
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)123 Status Multi30kNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
124                                     int64_t *dataset_size) {
125   if (dataset_size_ > 0) {
126     *dataset_size = dataset_size_;
127     return Status::OK();
128   }
129   int64_t num_rows, sample_size = num_samples_;
130   RETURN_IF_NOT_OK(Multi30kOp::CountAllFileRows(multi30k_files_list_, &num_rows));
131   num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
132   *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
133   dataset_size_ = *dataset_size;
134   return Status::OK();
135 }
136 
to_json(nlohmann::json * out_json)137 Status Multi30kNode::to_json(nlohmann::json *out_json) {
138   nlohmann::json args;
139   args["num_parallel_workers"] = num_workers_;
140   args["connector_queue_size"] = connector_que_size_;
141   args["dataset_dir"] = dataset_dir_;
142   args["num_samples"] = num_samples_;
143   args["shuffle"] = shuffle_;
144   args["num_shards"] = num_shards_;
145   args["shard_id"] = shard_id_;
146   args["language_pair"] = language_pair_;
147   if (cache_ != nullptr) {
148     nlohmann::json cache_args;
149     RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
150     args["cache"] = cache_args;
151   }
152   *out_json = args;
153   return Status::OK();
154 }
155 
SetupSamplerForCache(std::shared_ptr<SamplerObj> * sampler)156 Status Multi30kNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
157   bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
158   *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
159   return Status::OK();
160 }
161 
MakeSimpleProducer()162 Status Multi30kNode::MakeSimpleProducer() {
163   shard_id_ = 0;
164   num_shards_ = 1;
165   shuffle_ = ShuffleMode::kFalse;
166   num_samples_ = 0;
167   return Status::OK();
168 }
169 
WalkAllFiles(const std::string & usage,const std::string & dataset_dir)170 std::vector<std::string> Multi30kNode::WalkAllFiles(const std::string &usage, const std::string &dataset_dir) {
171   std::vector<std::string> multi30k_files_list;
172   Path train_en("training/train.en");
173   Path test_en("mmt16_task1_test/test.en");
174   Path valid_en("validation/val.en");
175   Path dir(dataset_dir);
176 
177   if (usage == "train") {
178     Path temp_path = dir / train_en;
179     multi30k_files_list.push_back(temp_path.ToString());
180   } else if (usage == "test") {
181     Path temp_path = dir / test_en;
182     multi30k_files_list.push_back(temp_path.ToString());
183   } else if (usage == "valid") {
184     Path temp_path = dir / valid_en;
185     multi30k_files_list.push_back(temp_path.ToString());
186   } else {
187     Path temp_path = dir / train_en;
188     multi30k_files_list.push_back(temp_path.ToString());
189     Path temp_path1 = dir / test_en;
190     multi30k_files_list.push_back(temp_path1.ToString());
191     Path temp_path2 = dir / valid_en;
192     multi30k_files_list.push_back(temp_path2.ToString());
193   }
194   return multi30k_files_list;
195 }
196 }  // namespace dataset
197 }  // namespace mindspore
198