1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/mindrecord/include/shard_shuffle.h"
18
19 #include <algorithm>
20
21 namespace mindspore {
22 namespace mindrecord {
ShardShuffle(uint32_t seed,ShuffleType shuffle_type)23 ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type)
24 : shuffle_seed_(seed),
25 no_of_samples_(0),
26 replacement_(false),
27 reshuffle_each_epoch_(true),
28 shuffle_type_(shuffle_type) {}
29
ShardShuffle(uint32_t seed,int64_t no_of_samples,bool replacement,bool reshuffle_each_epoch,ShuffleType shuffle_type)30 ShardShuffle::ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch,
31 ShuffleType shuffle_type)
32 : shuffle_seed_(seed),
33 no_of_samples_(no_of_samples),
34 replacement_(replacement),
35 reshuffle_each_epoch_(reshuffle_each_epoch),
36 shuffle_type_(shuffle_type) {}
37
GetNumSamples(int64_t dataset_size,int64_t num_classes)38 int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
39 if (replacement_) {
40 return no_of_samples_ == 0 ? dataset_size : no_of_samples_;
41 }
42 return no_of_samples_ == 0 ? dataset_size : std::min(dataset_size, no_of_samples_);
43 }
44
CategoryShuffle(ShardTaskList & tasks)45 Status ShardShuffle::CategoryShuffle(ShardTaskList &tasks) {
46 int64_t individual_size = tasks.sample_ids_.size() / tasks.categories;
47 std::vector<std::vector<int64_t>> new_permutations(tasks.categories, std::vector<int64_t>(individual_size));
48 for (int64_t i = 0; i < tasks.categories; i++) {
49 for (int64_t j = 0; j < individual_size; j++) {
50 new_permutations[i][j] = j;
51 }
52 std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_));
53 }
54 tasks.permutation_.clear();
55 for (int64_t j = 0; j < individual_size; j++) {
56 for (int64_t i = 0; i < tasks.categories; i++) {
57 tasks.permutation_.push_back(new_permutations[i][j] * tasks.categories + i);
58 }
59 }
60
61 ShardTaskList new_tasks;
62 for (int64_t i = 0; i < individual_size; ++i) {
63 new_tasks.AssignTask(tasks, tasks.permutation_[i]);
64 }
65 ShardTaskList::TaskListSwap(tasks, new_tasks);
66
67 return Status::OK();
68 }
69
ShuffleFiles(ShardTaskList & tasks)70 Status ShardShuffle::ShuffleFiles(ShardTaskList &tasks) {
71 if (no_of_samples_ == 0) {
72 no_of_samples_ = tasks.Size();
73 }
74 CHECK_FAIL_RETURN_UNEXPECTED_MR(
75 no_of_samples_ > 0, "Invalid input, 'num_samples' should be positive but got: " + std::to_string(no_of_samples_));
76 auto shard_sample_cout = GetShardSampleCount();
77
78 // shuffle the files index
79 std::vector<int64_t> shuffle_files;
80 for (int64_t i = 0; i < shard_sample_cout.size(); i++) {
81 shuffle_files.push_back(i);
82 }
83 std::shuffle(shuffle_files.begin(), shuffle_files.end(), std::default_random_engine(shuffle_seed_));
84
85 // reconstruct the permutation between files
86 // -- before --
87 // file1: [0, 1, 2]
88 // file2: [3, 4, 5, 6]
89 // file3: [7, 8]
90 // file4: [9, 10]
91 // files: [file1, file2, file3, file4]
92 // permutation: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
93 // -- after --
94 // files: [file4, file1, file3, file2]
95 // permutation : [9, 10, 0, 1, 2, 7, 8, 3, 4, 5, 6]
96 auto original_permutation = tasks.permutation_;
97 int64_t whole_index = 0;
98 for (int64_t i = 0; i < shuffle_files.size(); i++) {
99 int64_t start_index = 0;
100 int64_t current_size = 0;
101 if (shuffle_files[i] == 0) {
102 start_index = 0;
103 current_size = shard_sample_cout[shuffle_files[i]];
104 } else {
105 start_index = shard_sample_cout[shuffle_files[i] - 1];
106 current_size = shard_sample_cout[shuffle_files[i]] - start_index;
107 }
108 (void)std::copy(original_permutation.begin() + start_index,
109 original_permutation.begin() + start_index + current_size,
110 tasks.permutation_.begin() + whole_index);
111 whole_index += current_size;
112 }
113
114 auto total_no = tasks.Size();
115 int64_t samples_to_assign =
116 (no_of_samples_ > 0 && no_of_samples_ < total_no) ? no_of_samples_ : tasks.sample_ids_.size();
117 ShardTaskList new_tasks;
118 for (int64_t i = 0; i < samples_to_assign; ++i) {
119 new_tasks.AssignTask(tasks, tasks.permutation_[i]);
120 }
121 ShardTaskList::TaskListSwap(tasks, new_tasks);
122 return Status::OK();
123 }
124
ShuffleInfile(ShardTaskList & tasks)125 Status ShardShuffle::ShuffleInfile(ShardTaskList &tasks) {
126 if (no_of_samples_ == 0) {
127 no_of_samples_ = tasks.Size();
128 }
129 CHECK_FAIL_RETURN_UNEXPECTED_MR(
130 no_of_samples_ > 0, "Invalid input, 'num_samples' should be positive but got: " + std::to_string(no_of_samples_));
131 // reconstruct the permutation in file
132 // -- before --
133 // file1: [0, 1, 2]
134 // file2: [3, 4, 5, 6]
135 // file3: [7, 8]
136 // file4: [9, 10]
137 // files: [file1, file2, file3, file4]
138 // permutation: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
139 // -- after --
140 // permutation: [2, 0, 1, 4, 6, 3, 5, 8, 7, 9, 10]
141 auto shard_sample_cout = GetShardSampleCount();
142 int64_t start_index = 0;
143 for (int64_t i = 0; i < shard_sample_cout.size(); i++) {
144 auto current_size = shard_sample_cout[i] - start_index;
145 std::shuffle(tasks.permutation_.begin() + start_index, tasks.permutation_.begin() + start_index + current_size,
146 std::default_random_engine(shuffle_seed_));
147 start_index = shard_sample_cout[i];
148 }
149 auto total_no = tasks.Size();
150 ShardTaskList new_tasks;
151 int64_t samples_to_assign =
152 (no_of_samples_ > 0 && no_of_samples_ < total_no) ? no_of_samples_ : tasks.sample_ids_.size();
153 for (int64_t i = 0; i < samples_to_assign; ++i) {
154 new_tasks.AssignTask(tasks, tasks.permutation_[i]);
155 }
156 ShardTaskList::TaskListSwap(tasks, new_tasks);
157 return Status::OK();
158 }
159
Execute(ShardTaskList & tasks)160 Status ShardShuffle::Execute(ShardTaskList &tasks) {
161 if (reshuffle_each_epoch_) {
162 shuffle_seed_++;
163 }
164 CHECK_FAIL_RETURN_UNEXPECTED_MR(tasks.categories >= 1,
165 "[Internal ERROR] task categories should be greater than or equal to 1 but got: " +
166 std::to_string(tasks.categories));
167 if (tasks.load_mode_ != LoadMode::kSlow) {
168 if (shuffle_type_ == kShuffleSample) { // shuffle each sample
169 if (tasks.permutation_.empty() == true) {
170 tasks.MakePerm();
171 }
172 if (GetShuffleMode() == dataset::ShuffleMode::kGlobal) {
173 if (replacement_ == true) {
174 ShardTaskList new_tasks;
175 if (no_of_samples_ == 0) {
176 no_of_samples_ = tasks.sample_ids_.size();
177 }
178 CHECK_FAIL_RETURN_UNEXPECTED_MR(
179 no_of_samples_ > 0,
180 "Invalid input, 'num_samples' should be positive but got: " + std::to_string(no_of_samples_));
181 for (uint32_t i = 0; i < no_of_samples_; ++i) {
182 new_tasks.AssignTask(tasks, tasks.GetRandomTaskID());
183 }
184
185 ShardTaskList::TaskListSwap(tasks, new_tasks);
186 } else {
187 std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
188 auto total_no = tasks.Size();
189 ShardTaskList new_tasks;
190 int64_t samples_to_assign =
191 (no_of_samples_ > 0 && no_of_samples_ < total_no) ? no_of_samples_ : tasks.sample_ids_.size();
192 for (int64_t i = 0; i < samples_to_assign; ++i) {
193 new_tasks.AssignTask(tasks, tasks.permutation_[i]);
194 }
195 ShardTaskList::TaskListSwap(tasks, new_tasks);
196 }
197 } else if (GetShuffleMode() == dataset::ShuffleMode::kInfile) {
198 RETURN_IF_NOT_OK_MR(ShuffleInfile(tasks));
199 } else if (GetShuffleMode() == dataset::ShuffleMode::kFiles) {
200 RETURN_IF_NOT_OK_MR(ShuffleFiles(tasks));
201 }
202 } else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn)
203 return this->CategoryShuffle(tasks);
204 }
205 } else {
206 // just shuffle file names
207 CHECK_FAIL_RETURN_UNEXPECTED_MR(
208 GetShuffleMode() != dataset::ShuffleMode::kInfile && GetShuffleMode() != dataset::ShuffleMode::kFiles,
209 "The shuffle mode Shuffle.FILES and Shuffle.INFILE are not supported because "
210 "the number of samples in the dataset is greater than " +
211 std::to_string(SLOW_LOAD_THRESHOLD) + ".");
212 auto shard_sample_count = GetShardSampleCount();
213 std::vector<int32_t> file_ids;
214 for (int32_t i = 0; i < shard_sample_count.size(); i++) {
215 file_ids.push_back(i);
216 }
217 std::shuffle(file_ids.begin(), file_ids.end(), std::default_random_engine(shuffle_seed_));
218 tasks.SetFileIds(file_ids);
219 tasks.need_shuffle_ = true;
220 tasks.shuffle_seed_ = shuffle_seed_;
221 int64_t samples_to_assign =
222 (no_of_samples_ > 0 && no_of_samples_ < tasks.SizeAfterSampling()) ? no_of_samples_ : tasks.SizeAfterSampling();
223 tasks.UpdatePartitionedShardSampleCountByNumSamples(samples_to_assign);
224 }
225 return Status::OK();
226 }
227 } // namespace mindrecord
228 } // namespace mindspore
229