• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/mindrecord/include/shard_shuffle.h"
18 
19 #include <algorithm>
20 
21 namespace mindspore {
22 namespace mindrecord {
ShardShuffle(uint32_t seed,ShuffleType shuffle_type)23 ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type)
24     : shuffle_seed_(seed),
25       no_of_samples_(0),
26       replacement_(false),
27       reshuffle_each_epoch_(true),
28       shuffle_type_(shuffle_type) {}
29 
ShardShuffle(uint32_t seed,int64_t no_of_samples,bool replacement,bool reshuffle_each_epoch,ShuffleType shuffle_type)30 ShardShuffle::ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch,
31                            ShuffleType shuffle_type)
32     : shuffle_seed_(seed),
33       no_of_samples_(no_of_samples),
34       replacement_(replacement),
35       reshuffle_each_epoch_(reshuffle_each_epoch),
36       shuffle_type_(shuffle_type) {}
37 
GetNumSamples(int64_t dataset_size,int64_t num_classes)38 int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
39   if (replacement_) {
40     return no_of_samples_ == 0 ? dataset_size : no_of_samples_;
41   }
42   return no_of_samples_ == 0 ? dataset_size : std::min(dataset_size, no_of_samples_);
43 }
44 
CategoryShuffle(ShardTaskList & tasks)45 Status ShardShuffle::CategoryShuffle(ShardTaskList &tasks) {
46   uint32_t individual_size = tasks.sample_ids_.size() / tasks.categories;
47   std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size));
48   for (uint32_t i = 0; i < tasks.categories; i++) {
49     for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast<int>(j);
50     std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_));
51   }
52   tasks.permutation_.clear();
53   for (uint32_t j = 0; j < individual_size; j++) {
54     for (uint32_t i = 0; i < tasks.categories; i++) {
55       tasks.permutation_.push_back(new_permutations[i][j] * static_cast<int>(tasks.categories) + static_cast<int>(i));
56     }
57   }
58 
59   ShardTaskList new_tasks;
60   for (size_t i = 0; i < individual_size; ++i) {
61     new_tasks.AssignTask(tasks, tasks.permutation_[i]);
62   }
63   ShardTaskList::TaskListSwap(tasks, new_tasks);
64 
65   return Status::OK();
66 }
67 
ShuffleFiles(ShardTaskList & tasks)68 Status ShardShuffle::ShuffleFiles(ShardTaskList &tasks) {
69   if (no_of_samples_ == 0) {
70     no_of_samples_ = static_cast<int>(tasks.Size());
71   }
72   CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Invalid input, Number of samples [" +
73                                                      std::to_string(no_of_samples_) + "] need to be positive.");
74   auto shard_sample_cout = GetShardSampleCount();
75 
76   // shuffle the files index
77   std::vector<uint32_t> shuffle_files;
78   for (uint32_t i = 0; i < shard_sample_cout.size(); i++) {
79     shuffle_files.push_back(i);
80   }
81   std::shuffle(shuffle_files.begin(), shuffle_files.end(), std::default_random_engine(shuffle_seed_));
82 
83   // reconstruct the permutation between files
84   // -- before --
85   // file1: [0, 1, 2]
86   // file2: [3, 4, 5, 6]
87   // file3: [7, 8]
88   // file4: [9, 10]
89   // files: [file1, file2, file3, file4]
90   // permutation: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
91   // -- after --
92   // files: [file4, file1, file3, file2]
93   // permutation : [9, 10, 0, 1, 2, 7, 8, 3, 4, 5, 6]
94   auto original_permutation = tasks.permutation_;
95   uint32_t whole_index = 0;
96   for (uint32_t i = 0; i < shuffle_files.size(); i++) {
97     uint32_t start_index = 0;
98     uint32_t current_size = 0;
99     if (shuffle_files[i] == 0) {
100       start_index = 0;
101       current_size = shard_sample_cout[shuffle_files[i]];
102     } else {
103       start_index = shard_sample_cout[shuffle_files[i] - 1];
104       current_size = shard_sample_cout[shuffle_files[i]] - start_index;
105     }
106     std::copy(original_permutation.begin() + start_index, original_permutation.begin() + start_index + current_size,
107               tasks.permutation_.begin() + whole_index);
108     whole_index += current_size;
109   }
110 
111   auto total_no = static_cast<int64_t>(tasks.Size());
112   size_t samples_to_assign =
113     (no_of_samples_ > 0 && no_of_samples_ < total_no) ? no_of_samples_ : tasks.sample_ids_.size();
114   ShardTaskList new_tasks;
115   for (size_t i = 0; i < samples_to_assign; ++i) {
116     new_tasks.AssignTask(tasks, tasks.permutation_[i]);
117   }
118   ShardTaskList::TaskListSwap(tasks, new_tasks);
119   return Status::OK();
120 }
121 
ShuffleInfile(ShardTaskList & tasks)122 Status ShardShuffle::ShuffleInfile(ShardTaskList &tasks) {
123   if (no_of_samples_ == 0) {
124     no_of_samples_ = static_cast<int>(tasks.Size());
125   }
126   CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Invalid input, Number of samples [" +
127                                                      std::to_string(no_of_samples_) + "] need to be positive.");
128   // reconstruct the permutation in file
129   // -- before --
130   // file1: [0, 1, 2]
131   // file2: [3, 4, 5, 6]
132   // file3: [7, 8]
133   // file4: [9, 10]
134   // files: [file1, file2, file3, file4]
135   // permutation: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
136   // -- after --
137   // permutation: [2, 0, 1, 4, 6, 3, 5, 8, 7, 9, 10]
138   auto shard_sample_cout = GetShardSampleCount();
139   uint32_t start_index = 0;
140   for (uint32_t i = 0; i < shard_sample_cout.size(); i++) {
141     auto current_size = shard_sample_cout[i] - start_index;
142     std::shuffle(tasks.permutation_.begin() + start_index, tasks.permutation_.begin() + start_index + current_size,
143                  std::default_random_engine(shuffle_seed_));
144     start_index = shard_sample_cout[i];
145   }
146   auto total_no = static_cast<int64_t>(tasks.Size());
147   ShardTaskList new_tasks;
148   size_t samples_to_assign =
149     (no_of_samples_ > 0 && no_of_samples_ < total_no) ? no_of_samples_ : tasks.sample_ids_.size();
150   for (size_t i = 0; i < samples_to_assign; ++i) {
151     new_tasks.AssignTask(tasks, tasks.permutation_[i]);
152   }
153   ShardTaskList::TaskListSwap(tasks, new_tasks);
154   return Status::OK();
155 }
156 
Execute(ShardTaskList & tasks)157 Status ShardShuffle::Execute(ShardTaskList &tasks) {
158   if (reshuffle_each_epoch_) {
159     shuffle_seed_++;
160   }
161   CHECK_FAIL_RETURN_UNEXPECTED(
162     tasks.categories >= 1,
163     "Invalid data, task categories [" + std::to_string(tasks.categories) + "] need to be larger than 1.");
164   if (shuffle_type_ == kShuffleSample) {  // shuffle each sample
165     if (tasks.permutation_.empty() == true) {
166       tasks.MakePerm();
167     }
168     if (GetShuffleMode() == dataset::ShuffleMode::kGlobal) {
169       if (replacement_ == true) {
170         ShardTaskList new_tasks;
171         if (no_of_samples_ == 0) no_of_samples_ = static_cast<int>(tasks.sample_ids_.size());
172         CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Invalid input, Number of samples [" +
173                                                            std::to_string(no_of_samples_) + "] need to be positive.");
174         for (uint32_t i = 0; i < no_of_samples_; ++i) {
175           new_tasks.AssignTask(tasks, tasks.GetRandomTaskID());
176         }
177 
178         ShardTaskList::TaskListSwap(tasks, new_tasks);
179       } else {
180         std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
181         auto total_no = static_cast<int64_t>(tasks.Size());
182         ShardTaskList new_tasks;
183         size_t samples_to_assign =
184           (no_of_samples_ > 0 && no_of_samples_ < total_no) ? no_of_samples_ : tasks.sample_ids_.size();
185         for (size_t i = 0; i < samples_to_assign; ++i) {
186           new_tasks.AssignTask(tasks, tasks.permutation_[i]);
187         }
188         ShardTaskList::TaskListSwap(tasks, new_tasks);
189       }
190     } else if (GetShuffleMode() == dataset::ShuffleMode::kInfile) {
191       RETURN_IF_NOT_OK(ShuffleInfile(tasks));
192     } else if (GetShuffleMode() == dataset::ShuffleMode::kFiles) {
193       RETURN_IF_NOT_OK(ShuffleFiles(tasks));
194     }
195   } else {  // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn)
196     return this->CategoryShuffle(tasks);
197   }
198   return Status::OK();
199 }
200 }  // namespace mindrecord
201 }  // namespace mindspore
202