• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/ir/datasetops/source/udpos_node.h"
18 
19 #include "minddata/dataset/engine/datasetops/source/udpos_op.h"
20 #include "minddata/dataset/util/status.h"
21 
22 namespace mindspore {
23 namespace dataset {
24 // Constructor for UDPOSNode.
UDPOSNode(const std::string & dataset_dir,const std::string & usage,int32_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,std::shared_ptr<DatasetCache> cache)25 UDPOSNode::UDPOSNode(const std::string &dataset_dir, const std::string &usage, int32_t num_samples, ShuffleMode shuffle,
26                      int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
27     : NonMappableSourceNode(std::move(cache)),
28       dataset_dir_(dataset_dir),
29       usage_(usage),
30       num_samples_(num_samples),
31       shuffle_(shuffle),
32       num_shards_(num_shards),
33       shard_id_(shard_id),
34       udpos_files_list_(WalkAllFiles(usage, dataset_dir)) {
35   // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion
36   // is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't
37   // 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once
38   // PreBuildSampler is phased out, this can be cleaned up.
39   GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
40 }
41 
Copy()42 std::shared_ptr<DatasetNode> UDPOSNode::Copy() {
43   auto node = std::make_shared<UDPOSNode>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
44   (void)node->SetNumWorkers(num_workers_);
45   (void)node->SetConnectorQueueSize(connector_que_size_);
46   return node;
47 }
48 
Print(std::ostream & out) const49 void UDPOSNode::Print(std::ostream &out) const {
50   out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") +
51           ", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")");
52 }
53 
ValidateParams()54 Status UDPOSNode::ValidateParams() {
55   RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
56   RETURN_IF_NOT_OK(ValidateDatasetDirParam("UDPOSNode", dataset_dir_));
57   RETURN_IF_NOT_OK(ValidateStringValue("UDPOSNode", usage_, {"train", "test", "valid", "all"}));
58   RETURN_IF_NOT_OK(ValidateScalar("UDPOSNode", "num_samples", num_samples_, {0}, false));
59   RETURN_IF_NOT_OK(ValidateDatasetShardParams("UDPOSNode", num_shards_, shard_id_));
60   RETURN_IF_NOT_OK(ValidateEnum("UDPOSNode", "ShuffleMode", shuffle_,
61                                 {ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal}));
62   return Status::OK();
63 }
64 
65 // Function to build UDPOSNode.
Build(std::vector<std::shared_ptr<DatasetOp>> * const node_ops)66 Status UDPOSNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
67   bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
68 
69   // Sort the dataset files in a lexicographical order.
70   std::vector<std::string> sorted_dataset_files = udpos_files_list_;
71   std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
72 
73   // Do internal Schema generation.
74   auto schema = std::make_unique<DataSchema>();
75   RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("word", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
76   TensorShape scalar = TensorShape::CreateScalar();
77   RETURN_IF_NOT_OK(
78     schema->AddColumn(ColDescriptor("universal", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
79   RETURN_IF_NOT_OK(
80     schema->AddColumn(ColDescriptor("stanford", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
81 
82   // Create and initialize UDPOSOp.
83   std::shared_ptr<UDPOSOp> udpos_op =
84     std::make_shared<UDPOSOp>(num_workers_, num_samples_, worker_connector_size_, std::move(schema),
85                               sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_);
86   RETURN_IF_NOT_OK(udpos_op->Init());
87 
88   // If a global shuffle is used for UDPOS, it will inject a shuffle op over the UDPOS.
89   // But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be built.
90   // This is achieved in the cache transform pass where we call MakeSimpleProducer to reset UDPOS's shuffle
91   // option to false.
92   if (shuffle_ == ShuffleMode::kGlobal) {
93     // Inject ShuffleOp.
94     std::shared_ptr<ShuffleOp> shuffle_op = nullptr;
95     int64_t num_rows = 0;
96 
97     // First, get the number of rows in the dataset.
98     RETURN_IF_NOT_OK(UDPOSOp::CountAllFileRows(sorted_dataset_files, &num_rows));
99 
100     // Add the shuffle op after this op.
101     RETURN_IF_NOT_OK(
102       AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
103     shuffle_op->SetTotalRepeats(GetTotalRepeats());
104     shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
105     shuffle_op->Skip(skip_steps_);
106     node_ops->push_back(shuffle_op);
107   }
108   udpos_op->SetTotalRepeats(GetTotalRepeats());
109   udpos_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
110   // Add UDPOSOp.
111   node_ops->push_back(udpos_op);
112 
113   return Status::OK();
114 }
115 
116 // Get the shard id of node.
GetShardId(int32_t * shard_id)117 Status UDPOSNode::GetShardId(int32_t *shard_id) {
118   *shard_id = shard_id_;
119 
120   return Status::OK();
121 }
122 
123 // Get Dataset size.
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)124 Status UDPOSNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
125                                  int64_t *dataset_size) {
126   if (dataset_size_ > 0) {
127     *dataset_size = dataset_size_;
128     return Status::OK();
129   }
130   int64_t num_rows, sample_size = num_samples_;
131   RETURN_IF_NOT_OK(UDPOSOp::CountAllFileRows(udpos_files_list_, &num_rows));
132   num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
133   *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
134   dataset_size_ = *dataset_size;
135   return Status::OK();
136 }
137 
to_json(nlohmann::json * out_json)138 Status UDPOSNode::to_json(nlohmann::json *out_json) {
139   nlohmann::json args;
140   args["num_parallel_workers"] = num_workers_;
141   args["connector_queue_size"] = connector_que_size_;
142   args["dataset_dir"] = dataset_dir_;
143   args["usage"] = usage_;
144   args["num_samples"] = num_samples_;
145   args["shuffle"] = shuffle_;
146   args["num_shards"] = num_shards_;
147   args["shard_id"] = shard_id_;
148   if (cache_ != nullptr) {
149     nlohmann::json cache_args;
150     RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
151     args["cache"] = cache_args;
152   }
153   *out_json = args;
154   return Status::OK();
155 }
156 
157 // Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
158 // UDPOS by itself is a non-mappable dataset that does not support sampling.
159 // However, if a cache operator is injected at some other place higher in the tree, that cache can
160 // inherit this sampler from the leaf, providing sampling support from the caching layer.
161 // That is why we setup the sampler for a leaf node that does not use sampling.
SetupSamplerForCache(std::shared_ptr<SamplerObj> * sampler)162 Status UDPOSNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
163   bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
164   *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
165   return Status::OK();
166 }
167 
168 // If a cache has been added into the ascendant tree over this UDPOS node, then the cache will be executing
169 // a sampler for fetching the data. As such, any options in the UDPOS node need to be reset to its defaults so
170 // that this UDPOS node will produce the full set of data into the cache.
MakeSimpleProducer()171 Status UDPOSNode::MakeSimpleProducer() {
172   shard_id_ = 0;
173   num_shards_ = 1;
174   shuffle_ = ShuffleMode::kFalse;
175   num_samples_ = 0;
176   return Status::OK();
177 }
178 
WalkAllFiles(const std::string & usage,const std::string & dataset_dir)179 std::vector<std::string> UDPOSNode::WalkAllFiles(const std::string &usage, const std::string &dataset_dir) {
180   std::vector<std::string> udpos_files_list;
181   Path train_prefix("en-ud-tag.v2.train.txt");
182   Path test_prefix("en-ud-tag.v2.test.txt");
183   Path valid_prefix("en-ud-tag.v2.dev.txt");
184   Path dir(dataset_dir);
185 
186   if (usage == "train") {
187     Path temp_path = dir / train_prefix;
188     udpos_files_list.push_back(temp_path.ToString());
189   } else if (usage == "test") {
190     Path temp_path = dir / test_prefix;
191     udpos_files_list.push_back(temp_path.ToString());
192   } else if (usage == "valid") {
193     Path temp_path = dir / valid_prefix;
194     udpos_files_list.push_back(temp_path.ToString());
195   } else {
196     Path temp_path = dir / train_prefix;
197     udpos_files_list.push_back(temp_path.ToString());
198     Path temp_path1 = dir / test_prefix;
199     udpos_files_list.push_back(temp_path1.ToString());
200     Path temp_path2 = dir / valid_prefix;
201     udpos_files_list.push_back(temp_path2.ToString());
202   }
203   return udpos_files_list;
204 }
205 }  // namespace dataset
206 }  // namespace mindspore
207