1 /**
2 * Copyright 2020-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "minddata/dataset/engine/datasetops/concat_op.h"
17
18 #include <algorithm>
19 #include <iomanip>
20 #include <utility>
21
22 #include "minddata/dataset/core/config_manager.h"
23 #include "minddata/dataset/util/random.h"
24 #include "utils/ms_utils.h"
25
26 namespace mindspore {
27 namespace dataset {
28 // Constructor of the ConcatOp.
ConcatOp::ConcatOp(const std::shared_ptr<SamplerRT> &sampler,
                   const std::vector<std::pair<int, int>> &children_flag_and_nums,
                   const std::vector<std::pair<int, int>> &children_start_end_index,
                   const std::vector<int64_t> &children_sizes)
    : ConcatOp() {
  children_flag_and_nums_ = children_flag_and_nums;
  children_start_end_index_ = children_start_end_index;
  children_sizes_ = children_sizes;
  // Keep a pristine copy of the per-child sizes so the working copy (which is
  // decremented during global sampling) can be restored at the end of an epoch.
  children_sizes_ori_ = children_sizes;
  // The dynamic type of the sampler selects the sampling strategy below.
  std::shared_ptr<RandomSamplerRT> random_sampler = std::dynamic_pointer_cast<RandomSamplerRT>(sampler);
  std::shared_ptr<DistributedSamplerRT> distribute_sampler = std::dynamic_pointer_cast<DistributedSamplerRT>(sampler);

  if (random_sampler != nullptr && distribute_sampler == nullptr) {
    // global sample mode: interleave rows from all children at random,
    // weighting each child by its (remaining) sample count.
    global_shuffle_ = true;
    discrete_random_ = std::make_unique<std::discrete_distribution<>>(children_sizes_.begin(), children_sizes_.end());
    rnd_.seed(seed_);
    children_exhausted_ = std::vector<bool>(children_sizes_.size(), false);
  } else if (random_sampler == nullptr && distribute_sampler != nullptr) {
    // distributed sample mode: each shard keeps only its own slice of the stream.
    num_shard_ = static_cast<int32_t>(distribute_sampler->GetDeviceNum());
    shard_index_ = static_cast<int32_t>(distribute_sampler->GetDeviceID());
  }
  // Any other sampler (or none) keeps the delegated-constructor defaults:
  // sequential concatenation on a single shard.
}
53
// Default constructor: sequential concatenation on a single shard, no shuffling.
ConcatOp::ConcatOp()
    : PipelineOp(0),            // inlined op: no worker threads or connector queues
      cur_child_(0),            // start draining from the first child
      verified_(false),         // type/rank not yet recorded from child[0]
      sample_number_(0),        // running row counter used for shard selection
      num_shard_(1),            // single shard unless a DistributedSamplerRT is supplied
      shard_index_(0),
      discrete_random_(nullptr),  // only created in global shuffle mode
      global_shuffle_(false),
      seed_(GetSeed()) {}
64
65 // A function that prints info about the Operator
Print(std::ostream & out,bool show_all) const66 void ConcatOp::Print(std::ostream &out, bool show_all) const {
67 if (!show_all) {
68 // Call the super class for displaying any common 1-liner info
69 PipelineOp::Print(out, show_all);
70 // Then show any custom derived-internal 1-liner info for this op
71 out << "\n";
72 } else {
73 // Call the super class for displaying any common detailed info
74 PipelineOp::Print(out, show_all);
75 // Then show any custom derived-internal stuff
76 out << "\nDatasets: " << child_.size() << "\n\n";
77 }
78 }
79
// This definition is added to pass the cyclomatic complexity rule of <= 20 units
// The NOLINT directive is to disable cpplint check.
// Clang format and cpplint give conflicting recommendations on this line below.
// f(fv, sv, shard_index) is true when shard_index lies inside the child's valid
// window [fv, sv): (-1, -1) means every shard is valid; fv < sv is a plain
// window; fv > sv is a window that wraps around past the shard count.
#define f(fv, sv, shard_index) \
  (((fv) == -1 && (sv) == -1) || ((fv) < (sv) && (shard_index) >= (fv) && (shard_index) < (sv)) || \
   ((fv) > (sv) && ((shard_index) >= (fv) || (shard_index) < (sv)))) // NOLINT
86
Verify(int32_t id,const TensorRow & new_row)87 Status ConcatOp::Verify(int32_t id, const TensorRow &new_row) {
88 if (id == 0) {
89 // Obtain the data type and data rank in child[0]
90 for (auto item : new_row) {
91 data_type_.push_back(item->type());
92 data_rank_.push_back(item->Rank());
93 }
94 } else {
95 // Compare the data type and data rank with these in child[0]
96 int32_t index = 0;
97 for (auto item : new_row) {
98 if (item->type() != data_type_[index]) {
99 RETURN_STATUS_UNEXPECTED(
100 "Concat: the data types of the two datasets to be concatenated should be the same, but got: " +
101 data_type_[index].ToString() + " and " + item->type().ToString() + ".");
102 }
103 if (item->Rank() != data_rank_[index]) {
104 RETURN_STATUS_UNEXPECTED(
105 "Concat: the data tensor rank of the two datasets to be concatenated should be the same, but got: " +
106 std::to_string(data_rank_[index]) + " and " + std::to_string(item->Rank()) + ".");
107 }
108 index++;
109 }
110 }
111 verified_ = true;
112 return Status::OK();
113 }
114
115 // We need to overwrite the super class ComputeColMap here because the number of children is more than 1.
ComputeColMap()116 Status ConcatOp::ComputeColMap() {
117 if (column_name_id_map_.empty()) {
118 // Obtain columns_name_id_map from child_[0]
119 column_name_id_map_ = child_[0]->column_name_id_map();
120 if (column_name_id_map_.empty()) {
121 RETURN_STATUS_UNEXPECTED("[Internal ERROR] Child column name map cannot be empty!");
122 }
123 // Verify all children have the same column name map
124 for (size_t i = 0; i < child_.size(); ++i) {
125 if (child_[i]->column_name_id_map() != column_name_id_map_) {
126 RETURN_STATUS_UNEXPECTED(
127 "Invalid columns, 'column name' or 'column order' of concat datasets should be the same.");
128 }
129 }
130 } else {
131 MS_LOG(WARNING) << "Column name map is already set!";
132 }
133 return Status::OK();
134 }
135
136 // Gets the number of classes
GetNumClasses(int64_t * num_classes)137 Status ConcatOp::GetNumClasses(int64_t *num_classes) {
138 RETURN_UNEXPECTED_IF_NULL(num_classes);
139 int64_t max_num_classes = -1;
140 for (const auto &child : child_) {
141 // Choose a dataset which can get valid num_classes
142 int64_t tmp_num_classes = -1;
143 RETURN_IF_NOT_OK(child->GetNumClasses(&tmp_num_classes));
144 if (tmp_num_classes > max_num_classes) {
145 max_num_classes = tmp_num_classes;
146 }
147 }
148 *num_classes = max_num_classes;
149 return Status::OK();
150 }
operator ()()151 Status ConcatOp::operator()() { RETURN_STATUS_UNEXPECTED("[Internal ERROR] ConcatOp is an inlined operator."); }
152
IgnoreSample()153 bool ConcatOp::IgnoreSample() {
154 bool is_not_mappable_or_second_ne_zero = true;
155
156 if (!children_flag_and_nums_.empty()) {
157 const bool is_not_mappable = children_flag_and_nums_[cur_child_].first != 0 ? true : false;
158 const bool second_ne_zero = children_flag_and_nums_[cur_child_].second == 0 ? true : false;
159 is_not_mappable_or_second_ne_zero = is_not_mappable || second_ne_zero;
160 }
161 bool ret = true;
162 if (sample_number_ % num_shard_ == shard_index_ && is_not_mappable_or_second_ne_zero) {
163 ret = false;
164 } else if (!is_not_mappable_or_second_ne_zero) {
165 // if dataset is mappable or generator dataset which source is not yield,
166 // get the start and end subscripts of valid values
167 int fv = children_start_end_index_[cur_child_].first, sv = children_start_end_index_[cur_child_].second;
168
169 // determine whether the data allocated to the current shard id is false data
170 if (f(fv, sv, shard_index_)) {
171 ret = false;
172 }
173 }
174
175 if (is_not_mappable_or_second_ne_zero) {
176 sample_number_++;
177 }
178 return ret;
179 }
180
// Sequential sampling: drain children one after another. Fetches the next row
// from the current child, skipping rows that belong to other shards; on eoe it
// advances to the next child (or ends the epoch); on eof it drains the eof from
// all remaining children. is_pipeline_mode selects GetNextRow vs pull-mode.
// Note: skipping/advancing recurses through this op's own GetNextRow(PullMode).
Status ConcatOp::SampleInSequence(TensorRow *row, bool is_pipeline_mode) {
  row->reset();
  bool is_not_mappable_or_second_ne_zero = true;

  if (!children_flag_and_nums_.empty()) {
    const bool is_not_mappable = children_flag_and_nums_[cur_child_].first != 0 ? true : false;
    const bool second_ne_zero = children_flag_and_nums_[cur_child_].second == 0 ? true : false;
    // unmappable or iterable Generator
    is_not_mappable_or_second_ne_zero = is_not_mappable || second_ne_zero;
  }

  RETURN_IF_NOT_OK(CollectOpInfoStart(this->NameWithID(), "GetFromPreviousOp"));
  Status s = is_pipeline_mode ? child_[cur_child_]->GetNextRow(row) : child_[cur_child_]->GetNextRowPullMode(row);
  RETURN_IF_NOT_OK(s);
  RETURN_IF_NOT_OK(CollectOpInfoEnd(this->NameWithID(), "GetFromPreviousOp", {{"TensorRowFlags", row->FlagName()}}));

  if (!row->eoe() && !row->eof()) {
    // Lazily verify type/rank against child[0] on the first data row per child.
    if (!verified_) {
      RETURN_IF_NOT_OK(Verify(static_cast<int32_t>(cur_child_), *row));
    }

    // Row belongs to another shard: recurse to fetch the next candidate row.
    if (IgnoreSample()) {
      s = is_pipeline_mode ? GetNextRow(row) : GetNextRowPullMode(row);
      RETURN_IF_NOT_OK(s);
    }
  } else if (row->eoe()) {
    // if last child, send out eoe and reset epoch
    if (cur_child_ == child_.size() - 1) {
      // reset
      cur_child_ = 0;
      verified_ = false;
      UpdateRepeatAndEpochCounter();
    } else {
      // mappable and mappable Generator
      if (!is_not_mappable_or_second_ne_zero) {
        // Account for the whole sized child at once so shard bookkeeping stays aligned.
        sample_number_ += children_flag_and_nums_[cur_child_].second;
      }
      cur_child_++;
      verified_ = false;
      // Recurse to start pulling from the next child immediately (eoe is swallowed).
      s = is_pipeline_mode ? GetNextRow(row) : GetNextRowPullMode(row);
      RETURN_IF_NOT_OK(s);
    }
  } else if (row->eof()) {
    CHECK_FAIL_RETURN_UNEXPECTED(cur_child_ == 0, "[Internal ERROR] Received an unexpected EOF.");
    // Drain the eof from every remaining child so none is left mid-stream.
    for (size_t i = cur_child_ + 1; i < child_.size(); i++) {
      RETURN_IF_NOT_OK(CollectOpInfoStart(this->NameWithID(), "GetFromPreviousOp"));
      s = is_pipeline_mode ? child_[i]->GetNextRow(row) : child_[i]->GetNextRowPullMode(row);
      RETURN_IF_NOT_OK(s);
      RETURN_IF_NOT_OK(
        CollectOpInfoEnd(this->NameWithID(), "GetFromPreviousOp", {{"TensorRowFlags", row->FlagName()}}));
      CHECK_FAIL_RETURN_UNEXPECTED(row->eof(), "[Internal ERROR] Row must be an EOF.");
    }
  }
  return Status::OK();
}
236
SampleInGlobal(TensorRow * row,bool is_pipeline_mode)237 Status ConcatOp::SampleInGlobal(TensorRow *row, bool is_pipeline_mode) {
238 row->reset();
239 // select child id
240 auto child_id = (*discrete_random_)(rnd_);
241 MS_LOG(INFO) << "sample from child " << child_id;
242
243 RETURN_IF_NOT_OK(CollectOpInfoStart(this->NameWithID(), "GetFromPreviousOp"));
244 Status s = is_pipeline_mode ? child_[child_id]->GetNextRow(row) : child_[child_id]->GetNextRowPullMode(row);
245 RETURN_IF_NOT_OK(s);
246 RETURN_IF_NOT_OK(CollectOpInfoEnd(this->NameWithID(), "GetFromPreviousOp", {{"TensorRowFlags", row->FlagName()}}));
247
248 if (!row->eoe() && !row->eof()) {
249 // normal case, but reduce total sample numbers for without replacement sampling
250 children_sizes_[child_id] = std::max(--children_sizes_[child_id], static_cast<int64_t>(0));
251 // change distribution
252 discrete_random_->param({children_sizes_.begin(), children_sizes_.end()});
253 return Status::OK();
254 } else if (row->eoe()) {
255 // if one child has drained, sample from other children
256 children_exhausted_[child_id] = true;
257 children_sizes_[child_id] = 0;
258
259 // check all children have been exhausted
260 MS_LOG(INFO) << "child " << child_id << " node has been drained, check all children status (next row is eoe).";
261 // always get eoe from child 0, since random {0, 0, 0} return 0, exhaust other children
262 for (auto c = 0; c < children_exhausted_.size(); c++) {
263 TensorRow eoe;
264 if (!children_exhausted_[c]) {
265 RETURN_IF_NOT_OK(CollectOpInfoStart(this->NameWithID(), "GetFromPreviousOp"));
266 s = is_pipeline_mode ? child_[c]->GetNextRow(&eoe) : child_[c]->GetNextRowPullMode(&eoe);
267 RETURN_IF_NOT_OK(s);
268 RETURN_IF_NOT_OK(
269 CollectOpInfoEnd(this->NameWithID(), "GetFromPreviousOp", {{"TensorRowFlags", eoe.FlagName()}}));
270 // for those variable dataset size, we cannot support currently
271 CHECK_FAIL_RETURN_UNEXPECTED(
272 eoe.eoe(),
273 "The actual size of dataset " + std::to_string(c) +
274 " does not match its defined size, maybe the dataset size is variable or `__len__` is incorrect.");
275 children_exhausted_[c] = true;
276 }
277 }
278
279 // reset distribution
280 MS_LOG(INFO) << "reset all children.";
281 children_sizes_ = children_sizes_ori_;
282 children_exhausted_ = std::vector<bool>(children_sizes_.size(), false);
283 discrete_random_->param({children_sizes_.begin(), children_sizes_.end()});
284 UpdateRepeatAndEpochCounter();
285 } else if (row->eof()) {
286 // check all children have been drained
287 MS_LOG(INFO) << "Get eof from child " << child_id << ", drain eof of other children";
288 for (size_t i = 0; i < child_.size(); i++) {
289 if (i != child_id) {
290 RETURN_IF_NOT_OK(CollectOpInfoStart(this->NameWithID(), "GetFromPreviousOp"));
291 s = is_pipeline_mode ? child_[i]->GetNextRow(row) : child_[i]->GetNextRowPullMode(row);
292 RETURN_IF_NOT_OK(s);
293 RETURN_IF_NOT_OK(
294 CollectOpInfoEnd(this->NameWithID(), "GetFromPreviousOp", {{"TensorRowFlags", row->FlagName()}}));
295 CHECK_FAIL_RETURN_UNEXPECTED(row->eof(), "[Internal ERROR] Row must be an EOF.");
296 }
297 }
298 }
299 return Status::OK();
300 }
301
GetNextRow(TensorRow * row)302 Status ConcatOp::GetNextRow(TensorRow *row) {
303 RETURN_UNEXPECTED_IF_NULL(row);
304 if (global_shuffle_) {
305 RETURN_IF_NOT_OK(SampleInGlobal(row));
306 } else {
307 RETURN_IF_NOT_OK(SampleInSequence(row));
308 }
309 return Status::OK();
310 }
311
GetNextRowPullMode(TensorRow * const row)312 Status ConcatOp::GetNextRowPullMode(TensorRow *const row) {
313 RETURN_UNEXPECTED_IF_NULL(row);
314 if (global_shuffle_) {
315 RETURN_IF_NOT_OK(SampleInGlobal(row, false));
316 } else {
317 RETURN_IF_NOT_OK(SampleInSequence(row, false));
318 }
319 return Status::OK();
320 }
321 } // namespace dataset
322 } // namespace mindspore
323