• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/ir/datasetops/source/en_wik9_node.h"
18 
19 #include "minddata/dataset/engine/datasetops/source/en_wik9_op.h"
20 #include "minddata/dataset/util/status.h"
21 
22 namespace mindspore {
23 namespace dataset {
24 // Constructor for EnWik9Node
EnWik9Node(const std::string & dataset_dir,int32_t num_samples,ShuffleMode shuffle,int32_t num_shards,int32_t shard_id,const std::shared_ptr<DatasetCache> & cache)25 EnWik9Node::EnWik9Node(const std::string &dataset_dir, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
26                        int32_t shard_id, const std::shared_ptr<DatasetCache> &cache)
27     : NonMappableSourceNode(cache),
28       num_samples_(num_samples),
29       shuffle_(shuffle),
30       num_shards_(num_shards),
31       shard_id_(shard_id),
32       dataset_dir_(dataset_dir) {
33   // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion
34   // is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't
35   // 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once
36   // PreBuildSampler is phased out, this can be cleaned up.
37   GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
38   DirToPath(dataset_dir_);
39 }
40 
Copy()41 std::shared_ptr<DatasetNode> EnWik9Node::Copy() {
42   auto node = std::make_shared<EnWik9Node>(dataset_dir_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
43   (void)node->SetNumWorkers(num_workers_);
44   (void)node->SetConnectorQueueSize(connector_que_size_);
45   return node;
46 }
47 
Print(std::ostream & out) const48 void EnWik9Node::Print(std::ostream &out) const {
49   out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") +
50           ", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")");
51 }
52 
ValidateParams()53 Status EnWik9Node::ValidateParams() {
54   RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
55   RETURN_IF_NOT_OK(ValidateDatasetDirParam("EnWik9Dataset", dataset_dir_));
56   RETURN_IF_NOT_OK(ValidateEnum("EnWik9Dataset", "ShuffleMode", shuffle_,
57                                 {ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal}));
58   RETURN_IF_NOT_OK(ValidateScalar("EnWik9Dataset", "num_samples", num_samples_, {0}, false));
59   RETURN_IF_NOT_OK(ValidateDatasetShardParams("EnWik9Dataset", num_shards_, shard_id_));
60 
61   return Status::OK();
62 }
63 
DirToPath(const std::string & dataset_dir)64 void EnWik9Node::DirToPath(const std::string &dataset_dir) {
65   Path train_prefix("enwik9");
66   Path dir(dataset_dir);
67   Path temp_path = dir / train_prefix;
68   src_target_file_list_.push_back(temp_path.ToString());
69 }
70 
71 // Function to build EnWik9Node
Build(std::vector<std::shared_ptr<DatasetOp>> * const node_ops)72 Status EnWik9Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
73   bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
74   // Do internal Schema generation.
75   auto schema = std::make_unique<DataSchema>();
76   RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
77 
78   // Create and initialize EnWik9Op
79   std::shared_ptr<EnWik9Op> en_wik9_op =
80     std::make_shared<EnWik9Op>(num_workers_, num_samples_, worker_connector_size_, std::move(schema),
81                                src_target_file_list_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
82   RETURN_IF_NOT_OK(en_wik9_op->Init());
83 
84   // If a global shuffle is used for EnWik9, it will inject a shuffle op over the EnWik9.
85   // But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be built.
86   // This is achieved in the cache transform pass where we call MakeSimpleProducer to reset EnWik9's shuffle
87   // option to false.
88   if (shuffle_ == ShuffleMode::kGlobal) {
89     // Inject ShuffleOp
90     std::shared_ptr<ShuffleOp> shuffle_op = nullptr;
91     int64_t num_rows = 0;
92 
93     // First, get the number of rows in the dataset
94     RETURN_IF_NOT_OK(EnWik9Op::CountAllFileRows(src_target_file_list_, &num_rows));
95 
96     // Add the shuffle op after this op
97     RETURN_IF_NOT_OK(
98       AddShuffleOp(src_target_file_list_.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
99     shuffle_op->SetTotalRepeats(GetTotalRepeats());
100     shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
101     shuffle_op->Skip(skip_steps_);
102     node_ops->push_back(shuffle_op);
103   }
104   en_wik9_op->SetTotalRepeats(GetTotalRepeats());
105   en_wik9_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
106   // Add EnWik9Op
107   node_ops->push_back(en_wik9_op);
108   return Status::OK();
109 }
110 
111 // Get the shard id of node
GetShardId(int32_t * shard_id)112 Status EnWik9Node::GetShardId(int32_t *shard_id) {
113   *shard_id = shard_id_;
114   return Status::OK();
115 }
116 
117 // Get Dataset size
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)118 Status EnWik9Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
119                                   int64_t *dataset_size) {
120   if (dataset_size_ > 0) {
121     *dataset_size = dataset_size_;
122     return Status::OK();
123   }
124   int64_t num_rows, sample_size = num_samples_;
125   RETURN_IF_NOT_OK(EnWik9Op::CountAllFileRows(src_target_file_list_, &num_rows));
126   num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
127   *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
128   dataset_size_ = *dataset_size;
129   return Status::OK();
130 }
131 
to_json(nlohmann::json * out_json)132 Status EnWik9Node::to_json(nlohmann::json *out_json) {
133   nlohmann::json args;
134   args["num_parallel_workers"] = num_workers_;
135   args["connector_queue_size"] = connector_que_size_;
136   args["dataset_dir"] = dataset_dir_;
137   args["num_samples"] = num_samples_;
138   args["shuffle"] = shuffle_;
139   args["num_shards"] = num_shards_;
140   args["shard_id"] = shard_id_;
141   if (cache_ != nullptr) {
142     nlohmann::json cache_args;
143     RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
144     args["cache"] = cache_args;
145   }
146   *out_json = args;
147   return Status::OK();
148 }
149 
150 // Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
151 // EnWik9 by itself is a non-mappable dataset that does not support sampling.
152 // However, if a cache operator is injected at some other place higher in the tree, that cache can
153 // inherit this sampler from the leaf, providing sampling support from the caching layer.
154 // That is why we setup the sampler for a leaf node that does not use sampling.
SetupSamplerForCache(std::shared_ptr<SamplerObj> * sampler)155 Status EnWik9Node::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
156   bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
157   *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
158   return Status::OK();
159 }
160 
161 // If a cache has been added into the ascendant tree over this EnWik9 node, then the cache will be executing
162 // a sampler for fetching the data. As such, any options in the EnWik9 node need to be reset to its defaults so
163 // that this EnWik9 node will produce the full set of data into the cache.
MakeSimpleProducer()164 Status EnWik9Node::MakeSimpleProducer() {
165   shard_id_ = 0;
166   num_shards_ = 1;
167   shuffle_ = ShuffleMode::kFalse;
168   num_samples_ = 0;
169   return Status::OK();
170 }
171 }  // namespace dataset
172 }  // namespace mindspore
173