• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "minddata/dataset/engine/datasetops/concat_op.h"
17 
18 #include <algorithm>
19 #include <iomanip>
20 #include <utility>
21 
22 #include "minddata/dataset/core/config_manager.h"
23 #include "minddata/dataset/util/random.h"
24 #include "utils/ms_utils.h"
25 
26 namespace mindspore {
27 namespace dataset {
28 // Constructor of the ConcatOp.
ConcatOp(const std::shared_ptr<SamplerRT> & sampler,const std::vector<std::pair<int,int>> & children_flag_and_nums,const std::vector<std::pair<int,int>> & children_start_end_index,const std::vector<int64_t> & children_sizes)29 ConcatOp::ConcatOp(const std::shared_ptr<SamplerRT> &sampler,
30                    const std::vector<std::pair<int, int>> &children_flag_and_nums,
31                    const std::vector<std::pair<int, int>> &children_start_end_index,
32                    const std::vector<int64_t> &children_sizes)
33     : ConcatOp() {
34   children_flag_and_nums_ = children_flag_and_nums;
35   children_start_end_index_ = children_start_end_index;
36   children_sizes_ = children_sizes;
37   children_sizes_ori_ = children_sizes;
38   std::shared_ptr<RandomSamplerRT> random_sampler = std::dynamic_pointer_cast<RandomSamplerRT>(sampler);
39   std::shared_ptr<DistributedSamplerRT> distribute_sampler = std::dynamic_pointer_cast<DistributedSamplerRT>(sampler);
40 
41   if (random_sampler != nullptr && distribute_sampler == nullptr) {
42     // global sample mode
43     global_shuffle_ = true;
44     discrete_random_ = std::make_unique<std::discrete_distribution<>>(children_sizes_.begin(), children_sizes_.end());
45     rnd_.seed(seed_);
46     children_exhausted_ = std::vector<bool>(children_sizes_.size(), false);
47   } else if (random_sampler == nullptr && distribute_sampler != nullptr) {
48     // distributed sample mode
49     num_shard_ = static_cast<int32_t>(distribute_sampler->GetDeviceNum());
50     shard_index_ = static_cast<int32_t>(distribute_sampler->GetDeviceID());
51   }
52 }
53 
// Default constructor: sequential (non-shuffled, single-shard) mode.
ConcatOp::ConcatOp()
    : PipelineOp(0),            // inlined op: no worker connector
      cur_child_(0),            // begin consuming from the first child
      verified_(false),         // column types/ranks not yet checked against child[0]
      sample_number_(0),        // running counter used for round-robin sharding
      num_shard_(1),            // single shard unless a DistributedSamplerRT is supplied
      shard_index_(0),
      discrete_random_(null
ptr),  // only created in global-shuffle mode
      global_shuffle_(false),
      seed_(GetSeed()) {}
64 
65 // A function that prints info about the Operator
Print(std::ostream & out,bool show_all) const66 void ConcatOp::Print(std::ostream &out, bool show_all) const {
67   if (!show_all) {
68     // Call the super class for displaying any common 1-liner info
69     PipelineOp::Print(out, show_all);
70     // Then show any custom derived-internal 1-liner info for this op
71     out << "\n";
72   } else {
73     // Call the super class for displaying any common detailed info
74     PipelineOp::Print(out, show_all);
75     // Then show any custom derived-internal stuff
76     out << "\nDatasets: " << child_.size() << "\n\n";
77   }
78 }
79 
// This definition is added to pass the cyclomatic complexity rule of <= 20 units
// The NOLINT directive is to disable cpplint check.
// Clang format and cpplint give conflicting recommendations on this line below.
// f(fv, sv, shard_index) evaluates to true when `shard_index` falls in the
// child's valid window: (-1, -1) means "all shards"; fv < sv is the plain
// half-open window [fv, sv); fv > sv is a window that wraps around.
#define f(fv, sv, shard_index)                                                                     \
  (((fv) == -1 && (sv) == -1) || ((fv) < (sv) && (shard_index) >= (fv) && (shard_index) < (sv)) || \
   ((fv) > (sv) && ((shard_index) >= (fv) || (shard_index) < (sv))))  // NOLINT
86 
Verify(int32_t id,const TensorRow & new_row)87 Status ConcatOp::Verify(int32_t id, const TensorRow &new_row) {
88   if (id == 0) {
89     // Obtain the data type and data rank in child[0]
90     for (auto item : new_row) {
91       data_type_.push_back(item->type());
92       data_rank_.push_back(item->Rank());
93     }
94   } else {
95     // Compare the data type and data rank with these in child[0]
96     int32_t index = 0;
97     for (auto item : new_row) {
98       if (item->type() != data_type_[index]) {
99         RETURN_STATUS_UNEXPECTED(
100           "Concat: the data types of the two datasets to be concatenated should be the same, but got: " +
101           data_type_[index].ToString() + " and " + item->type().ToString() + ".");
102       }
103       if (item->Rank() != data_rank_[index]) {
104         RETURN_STATUS_UNEXPECTED(
105           "Concat: the data tensor rank of the two datasets to be concatenated should be the same, but got: " +
106           std::to_string(data_rank_[index]) + " and " + std::to_string(item->Rank()) + ".");
107       }
108       index++;
109     }
110   }
111   verified_ = true;
112   return Status::OK();
113 }
114 
115 // We need to overwrite the super class ComputeColMap here because the number of children is more than 1.
ComputeColMap()116 Status ConcatOp::ComputeColMap() {
117   if (column_name_id_map_.empty()) {
118     // Obtain columns_name_id_map from child_[0]
119     column_name_id_map_ = child_[0]->column_name_id_map();
120     if (column_name_id_map_.empty()) {
121       RETURN_STATUS_UNEXPECTED("[Internal ERROR] Child column name map cannot be empty!");
122     }
123     // Verify all children have the same column name map
124     for (size_t i = 0; i < child_.size(); ++i) {
125       if (child_[i]->column_name_id_map() != column_name_id_map_) {
126         RETURN_STATUS_UNEXPECTED(
127           "Invalid columns, 'column name' or 'column order' of concat datasets should be the same.");
128       }
129     }
130   } else {
131     MS_LOG(WARNING) << "Column name map is already set!";
132   }
133   return Status::OK();
134 }
135 
136 // Gets the number of classes
GetNumClasses(int64_t * num_classes)137 Status ConcatOp::GetNumClasses(int64_t *num_classes) {
138   RETURN_UNEXPECTED_IF_NULL(num_classes);
139   int64_t max_num_classes = -1;
140   for (const auto &child : child_) {
141     // Choose a dataset which can get valid num_classes
142     int64_t tmp_num_classes = -1;
143     RETURN_IF_NOT_OK(child->GetNumClasses(&tmp_num_classes));
144     if (tmp_num_classes > max_num_classes) {
145       max_num_classes = tmp_num_classes;
146     }
147   }
148   *num_classes = max_num_classes;
149   return Status::OK();
150 }
operator ()()151 Status ConcatOp::operator()() { RETURN_STATUS_UNEXPECTED("[Internal ERROR] ConcatOp is an inlined operator."); }
152 
IgnoreSample()153 bool ConcatOp::IgnoreSample() {
154   bool is_not_mappable_or_second_ne_zero = true;
155 
156   if (!children_flag_and_nums_.empty()) {
157     const bool is_not_mappable = children_flag_and_nums_[cur_child_].first != 0 ? true : false;
158     const bool second_ne_zero = children_flag_and_nums_[cur_child_].second == 0 ? true : false;
159     is_not_mappable_or_second_ne_zero = is_not_mappable || second_ne_zero;
160   }
161   bool ret = true;
162   if (sample_number_ % num_shard_ == shard_index_ && is_not_mappable_or_second_ne_zero) {
163     ret = false;
164   } else if (!is_not_mappable_or_second_ne_zero) {
165     // if dataset is mappable or generator dataset which source is not yield,
166     // get the start and end subscripts of valid values
167     int fv = children_start_end_index_[cur_child_].first, sv = children_start_end_index_[cur_child_].second;
168 
169     // determine whether the data allocated to the current shard id is false data
170     if (f(fv, sv, shard_index_)) {
171       ret = false;
172     }
173   }
174 
175   if (is_not_mappable_or_second_ne_zero) {
176     sample_number_++;
177   }
178   return ret;
179 }
180 
// Sequential sampling: pull the next row from the current child, advancing to
// the next child on EOE and draining remaining children on EOF.
// `is_pipeline_mode` selects between the threaded GetNextRow path and the
// pull-mode path when fetching from children.
Status ConcatOp::SampleInSequence(TensorRow *row, bool is_pipeline_mode) {
  row->reset();
  bool is_not_mappable_or_second_ne_zero = true;

  if (!children_flag_and_nums_.empty()) {
    const bool is_not_mappable = children_flag_and_nums_[cur_child_].first != 0 ? true : false;
    const bool second_ne_zero = children_flag_and_nums_[cur_child_].second == 0 ? true : false;
    // unmappable or iterable Generator
    is_not_mappable_or_second_ne_zero = is_not_mappable || second_ne_zero;
  }

  RETURN_IF_NOT_OK(CollectOpInfoStart(this->NameWithID(), "GetFromPreviousOp"));
  Status s = is_pipeline_mode ? child_[cur_child_]->GetNextRow(row) : child_[cur_child_]->GetNextRowPullMode(row);
  RETURN_IF_NOT_OK(s);
  RETURN_IF_NOT_OK(CollectOpInfoEnd(this->NameWithID(), "GetFromPreviousOp", {{"TensorRowFlags", row->FlagName()}}));

  if (!row->eoe() && !row->eof()) {
    // Data row: check column types/ranks against child[0] once per child.
    if (!verified_) {
      RETURN_IF_NOT_OK(Verify(static_cast<int32_t>(cur_child_), *row));
    }

    // Row belongs to another shard: recurse to fetch the next candidate row.
    if (IgnoreSample()) {
      s = is_pipeline_mode ? GetNextRow(row) : GetNextRowPullMode(row);
      RETURN_IF_NOT_OK(s);
    }
  } else if (row->eoe()) {
    // if last child, send out eoe and reset epoch
    if (cur_child_ == child_.size() - 1) {
      // reset
      cur_child_ = 0;
      verified_ = false;
      UpdateRepeatAndEpochCounter();
    } else {
      // mappable and mappable Generator
      if (!is_not_mappable_or_second_ne_zero) {
        // Sized child: advance the shard counter past all its samples at once.
        sample_number_ += children_flag_and_nums_[cur_child_].second;
      }
      cur_child_++;
      verified_ = false;
      // Fetch the first row of the next child via the public entry point.
      s = is_pipeline_mode ? GetNextRow(row) : GetNextRowPullMode(row);
      RETURN_IF_NOT_OK(s);
    }
  } else if (row->eof()) {
    // EOF must arrive while still on child 0; drain the EOF from the remaining
    // children so every child sees shutdown.
    CHECK_FAIL_RETURN_UNEXPECTED(cur_child_ == 0, "[Internal ERROR] Received an unexpected EOF.");
    for (size_t i = cur_child_ + 1; i < child_.size(); i++) {
      RETURN_IF_NOT_OK(CollectOpInfoStart(this->NameWithID(), "GetFromPreviousOp"));
      s = is_pipeline_mode ? child_[i]->GetNextRow(row) : child_[i]->GetNextRowPullMode(row);
      RETURN_IF_NOT_OK(s);
      RETURN_IF_NOT_OK(
        CollectOpInfoEnd(this->NameWithID(), "GetFromPreviousOp", {{"TensorRowFlags", row->FlagName()}}));
      CHECK_FAIL_RETURN_UNEXPECTED(row->eof(), "[Internal ERROR] Row must be an EOF.");
    }
  }
  return Status::OK();
}
236 
SampleInGlobal(TensorRow * row,bool is_pipeline_mode)237 Status ConcatOp::SampleInGlobal(TensorRow *row, bool is_pipeline_mode) {
238   row->reset();
239   // select child id
240   auto child_id = (*discrete_random_)(rnd_);
241   MS_LOG(INFO) << "sample from child " << child_id;
242 
243   RETURN_IF_NOT_OK(CollectOpInfoStart(this->NameWithID(), "GetFromPreviousOp"));
244   Status s = is_pipeline_mode ? child_[child_id]->GetNextRow(row) : child_[child_id]->GetNextRowPullMode(row);
245   RETURN_IF_NOT_OK(s);
246   RETURN_IF_NOT_OK(CollectOpInfoEnd(this->NameWithID(), "GetFromPreviousOp", {{"TensorRowFlags", row->FlagName()}}));
247 
248   if (!row->eoe() && !row->eof()) {
249     // normal case, but reduce total sample numbers for without replacement sampling
250     children_sizes_[child_id] = std::max(--children_sizes_[child_id], static_cast<int64_t>(0));
251     // change distribution
252     discrete_random_->param({children_sizes_.begin(), children_sizes_.end()});
253     return Status::OK();
254   } else if (row->eoe()) {
255     // if one child has drained, sample from other children
256     children_exhausted_[child_id] = true;
257     children_sizes_[child_id] = 0;
258 
259     // check all children have been exhausted
260     MS_LOG(INFO) << "child " << child_id << " node has been drained, check all children status (next row is eoe).";
261     // always get eoe from child 0, since random {0, 0, 0} return 0, exhaust other children
262     for (auto c = 0; c < children_exhausted_.size(); c++) {
263       TensorRow eoe;
264       if (!children_exhausted_[c]) {
265         RETURN_IF_NOT_OK(CollectOpInfoStart(this->NameWithID(), "GetFromPreviousOp"));
266         s = is_pipeline_mode ? child_[c]->GetNextRow(&eoe) : child_[c]->GetNextRowPullMode(&eoe);
267         RETURN_IF_NOT_OK(s);
268         RETURN_IF_NOT_OK(
269           CollectOpInfoEnd(this->NameWithID(), "GetFromPreviousOp", {{"TensorRowFlags", eoe.FlagName()}}));
270         // for those variable dataset size, we cannot support currently
271         CHECK_FAIL_RETURN_UNEXPECTED(
272           eoe.eoe(),
273           "The actual size of dataset " + std::to_string(c) +
274             " does not match its defined size, maybe the dataset size is variable or `__len__` is incorrect.");
275         children_exhausted_[c] = true;
276       }
277     }
278 
279     // reset distribution
280     MS_LOG(INFO) << "reset all children.";
281     children_sizes_ = children_sizes_ori_;
282     children_exhausted_ = std::vector<bool>(children_sizes_.size(), false);
283     discrete_random_->param({children_sizes_.begin(), children_sizes_.end()});
284     UpdateRepeatAndEpochCounter();
285   } else if (row->eof()) {
286     // check all children have been drained
287     MS_LOG(INFO) << "Get eof from child " << child_id << ", drain eof of other children";
288     for (size_t i = 0; i < child_.size(); i++) {
289       if (i != child_id) {
290         RETURN_IF_NOT_OK(CollectOpInfoStart(this->NameWithID(), "GetFromPreviousOp"));
291         s = is_pipeline_mode ? child_[i]->GetNextRow(row) : child_[i]->GetNextRowPullMode(row);
292         RETURN_IF_NOT_OK(s);
293         RETURN_IF_NOT_OK(
294           CollectOpInfoEnd(this->NameWithID(), "GetFromPreviousOp", {{"TensorRowFlags", row->FlagName()}}));
295         CHECK_FAIL_RETURN_UNEXPECTED(row->eof(), "[Internal ERROR] Row must be an EOF.");
296       }
297     }
298   }
299   return Status::OK();
300 }
301 
GetNextRow(TensorRow * row)302 Status ConcatOp::GetNextRow(TensorRow *row) {
303   RETURN_UNEXPECTED_IF_NULL(row);
304   if (global_shuffle_) {
305     RETURN_IF_NOT_OK(SampleInGlobal(row));
306   } else {
307     RETURN_IF_NOT_OK(SampleInSequence(row));
308   }
309   return Status::OK();
310 }
311 
GetNextRowPullMode(TensorRow * const row)312 Status ConcatOp::GetNextRowPullMode(TensorRow *const row) {
313   RETURN_UNEXPECTED_IF_NULL(row);
314   if (global_shuffle_) {
315     RETURN_IF_NOT_OK(SampleInGlobal(row, false));
316   } else {
317     RETURN_IF_NOT_OK(SampleInSequence(row, false));
318   }
319   return Status::OK();
320 }
321 }  // namespace dataset
322 }  // namespace mindspore
323