1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/mindrecord/include/shard_shuffle.h"
18
19 #include <algorithm>
20
21 namespace mindspore {
22 namespace mindrecord {
ShardShuffle(uint32_t seed,ShuffleType shuffle_type)23 ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type)
24 : shuffle_seed_(seed),
25 no_of_samples_(0),
26 replacement_(false),
27 reshuffle_each_epoch_(true),
28 shuffle_type_(shuffle_type) {}
29
ShardShuffle(uint32_t seed,int64_t no_of_samples,bool replacement,bool reshuffle_each_epoch,ShuffleType shuffle_type)30 ShardShuffle::ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch,
31 ShuffleType shuffle_type)
32 : shuffle_seed_(seed),
33 no_of_samples_(no_of_samples),
34 replacement_(replacement),
35 reshuffle_each_epoch_(reshuffle_each_epoch),
36 shuffle_type_(shuffle_type) {}
37
GetNumSamples(int64_t dataset_size,int64_t num_classes)38 int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
39 if (replacement_) {
40 return no_of_samples_ == 0 ? dataset_size : no_of_samples_;
41 }
42 return no_of_samples_ == 0 ? dataset_size : std::min(dataset_size, no_of_samples_);
43 }
44
CategoryShuffle(ShardTaskList & tasks)45 Status ShardShuffle::CategoryShuffle(ShardTaskList &tasks) {
46 uint32_t individual_size = tasks.sample_ids_.size() / tasks.categories;
47 std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size));
48 for (uint32_t i = 0; i < tasks.categories; i++) {
49 for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast<int>(j);
50 std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_));
51 }
52 tasks.permutation_.clear();
53 for (uint32_t j = 0; j < individual_size; j++) {
54 for (uint32_t i = 0; i < tasks.categories; i++) {
55 tasks.permutation_.push_back(new_permutations[i][j] * static_cast<int>(tasks.categories) + static_cast<int>(i));
56 }
57 }
58
59 ShardTaskList new_tasks;
60 for (size_t i = 0; i < individual_size; ++i) {
61 new_tasks.AssignTask(tasks, tasks.permutation_[i]);
62 }
63 ShardTaskList::TaskListSwap(tasks, new_tasks);
64
65 return Status::OK();
66 }
67
ShuffleFiles(ShardTaskList & tasks)68 Status ShardShuffle::ShuffleFiles(ShardTaskList &tasks) {
69 if (no_of_samples_ == 0) {
70 no_of_samples_ = static_cast<int>(tasks.Size());
71 }
72 CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Invalid input, Number of samples [" +
73 std::to_string(no_of_samples_) + "] need to be positive.");
74 auto shard_sample_cout = GetShardSampleCount();
75
76 // shuffle the files index
77 std::vector<uint32_t> shuffle_files;
78 for (uint32_t i = 0; i < shard_sample_cout.size(); i++) {
79 shuffle_files.push_back(i);
80 }
81 std::shuffle(shuffle_files.begin(), shuffle_files.end(), std::default_random_engine(shuffle_seed_));
82
83 // reconstruct the permutation between files
84 // -- before --
85 // file1: [0, 1, 2]
86 // file2: [3, 4, 5, 6]
87 // file3: [7, 8]
88 // file4: [9, 10]
89 // files: [file1, file2, file3, file4]
90 // permutation: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
91 // -- after --
92 // files: [file4, file1, file3, file2]
93 // permutation : [9, 10, 0, 1, 2, 7, 8, 3, 4, 5, 6]
94 auto original_permutation = tasks.permutation_;
95 uint32_t whole_index = 0;
96 for (uint32_t i = 0; i < shuffle_files.size(); i++) {
97 uint32_t start_index = 0;
98 uint32_t current_size = 0;
99 if (shuffle_files[i] == 0) {
100 start_index = 0;
101 current_size = shard_sample_cout[shuffle_files[i]];
102 } else {
103 start_index = shard_sample_cout[shuffle_files[i] - 1];
104 current_size = shard_sample_cout[shuffle_files[i]] - start_index;
105 }
106 std::copy(original_permutation.begin() + start_index, original_permutation.begin() + start_index + current_size,
107 tasks.permutation_.begin() + whole_index);
108 whole_index += current_size;
109 }
110
111 auto total_no = static_cast<int64_t>(tasks.Size());
112 size_t samples_to_assign =
113 (no_of_samples_ > 0 && no_of_samples_ < total_no) ? no_of_samples_ : tasks.sample_ids_.size();
114 ShardTaskList new_tasks;
115 for (size_t i = 0; i < samples_to_assign; ++i) {
116 new_tasks.AssignTask(tasks, tasks.permutation_[i]);
117 }
118 ShardTaskList::TaskListSwap(tasks, new_tasks);
119 return Status::OK();
120 }
121
ShuffleInfile(ShardTaskList & tasks)122 Status ShardShuffle::ShuffleInfile(ShardTaskList &tasks) {
123 if (no_of_samples_ == 0) {
124 no_of_samples_ = static_cast<int>(tasks.Size());
125 }
126 CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Invalid input, Number of samples [" +
127 std::to_string(no_of_samples_) + "] need to be positive.");
128 // reconstruct the permutation in file
129 // -- before --
130 // file1: [0, 1, 2]
131 // file2: [3, 4, 5, 6]
132 // file3: [7, 8]
133 // file4: [9, 10]
134 // files: [file1, file2, file3, file4]
135 // permutation: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
136 // -- after --
137 // permutation: [2, 0, 1, 4, 6, 3, 5, 8, 7, 9, 10]
138 auto shard_sample_cout = GetShardSampleCount();
139 uint32_t start_index = 0;
140 for (uint32_t i = 0; i < shard_sample_cout.size(); i++) {
141 auto current_size = shard_sample_cout[i] - start_index;
142 std::shuffle(tasks.permutation_.begin() + start_index, tasks.permutation_.begin() + start_index + current_size,
143 std::default_random_engine(shuffle_seed_));
144 start_index = shard_sample_cout[i];
145 }
146 auto total_no = static_cast<int64_t>(tasks.Size());
147 ShardTaskList new_tasks;
148 size_t samples_to_assign =
149 (no_of_samples_ > 0 && no_of_samples_ < total_no) ? no_of_samples_ : tasks.sample_ids_.size();
150 for (size_t i = 0; i < samples_to_assign; ++i) {
151 new_tasks.AssignTask(tasks, tasks.permutation_[i]);
152 }
153 ShardTaskList::TaskListSwap(tasks, new_tasks);
154 return Status::OK();
155 }
156
Execute(ShardTaskList & tasks)157 Status ShardShuffle::Execute(ShardTaskList &tasks) {
158 if (reshuffle_each_epoch_) {
159 shuffle_seed_++;
160 }
161 CHECK_FAIL_RETURN_UNEXPECTED(
162 tasks.categories >= 1,
163 "Invalid data, task categories [" + std::to_string(tasks.categories) + "] need to be larger than 1.");
164 if (shuffle_type_ == kShuffleSample) { // shuffle each sample
165 if (tasks.permutation_.empty() == true) {
166 tasks.MakePerm();
167 }
168 if (GetShuffleMode() == dataset::ShuffleMode::kGlobal) {
169 if (replacement_ == true) {
170 ShardTaskList new_tasks;
171 if (no_of_samples_ == 0) no_of_samples_ = static_cast<int>(tasks.sample_ids_.size());
172 CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Invalid input, Number of samples [" +
173 std::to_string(no_of_samples_) + "] need to be positive.");
174 for (uint32_t i = 0; i < no_of_samples_; ++i) {
175 new_tasks.AssignTask(tasks, tasks.GetRandomTaskID());
176 }
177
178 ShardTaskList::TaskListSwap(tasks, new_tasks);
179 } else {
180 std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
181 auto total_no = static_cast<int64_t>(tasks.Size());
182 ShardTaskList new_tasks;
183 size_t samples_to_assign =
184 (no_of_samples_ > 0 && no_of_samples_ < total_no) ? no_of_samples_ : tasks.sample_ids_.size();
185 for (size_t i = 0; i < samples_to_assign; ++i) {
186 new_tasks.AssignTask(tasks, tasks.permutation_[i]);
187 }
188 ShardTaskList::TaskListSwap(tasks, new_tasks);
189 }
190 } else if (GetShuffleMode() == dataset::ShuffleMode::kInfile) {
191 RETURN_IF_NOT_OK(ShuffleInfile(tasks));
192 } else if (GetShuffleMode() == dataset::ShuffleMode::kFiles) {
193 RETURN_IF_NOT_OK(ShuffleFiles(tasks));
194 }
195 } else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn)
196 return this->CategoryShuffle(tasks);
197 }
198 return Status::OK();
199 }
200 } // namespace mindrecord
201 } // namespace mindspore
202