1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/mindrecord/include/shard_sample.h"
18
19 using mindspore::LogStream;
20 using mindspore::ExceptionType::NoExceptionType;
21 using mindspore::MsLogLevel::ERROR;
22
23 namespace mindspore {
24 namespace mindrecord {
ShardSample(int n)25 ShardSample::ShardSample(int n)
26 : numerator_(0),
27 denominator_(0),
28 partition_id_(0),
29 no_of_samples_(n),
30 indices_({}),
31 sampler_type_(kCustomTopNSampler),
32 offset_(-1) {}
33
ShardSample(int num,int den)34 ShardSample::ShardSample(int num, int den)
35 : numerator_(num),
36 denominator_(den),
37 partition_id_(0),
38 no_of_samples_(0),
39 indices_({}),
40 sampler_type_(kCustomTopPercentSampler),
41 offset_(-1) {}
42
ShardSample(int num,int den,int par,int no_of_samples,int offset)43 ShardSample::ShardSample(int num, int den, int par, int no_of_samples, int offset)
44 : numerator_(num),
45 denominator_(den),
46 partition_id_(par),
47 no_of_samples_(no_of_samples),
48 indices_({}),
49 sampler_type_(kCustomTopPercentSampler),
50 offset_(offset) {}
51
ShardSample(const std::vector<int64_t> & indices)52 ShardSample::ShardSample(const std::vector<int64_t> &indices)
53 : numerator_(0),
54 denominator_(0),
55 partition_id_(0),
56 no_of_samples_(0),
57 indices_(indices),
58 sampler_type_(kSubsetSampler) {}
59
ShardSample(const std::vector<int64_t> & indices,uint32_t seed)60 ShardSample::ShardSample(const std::vector<int64_t> &indices, uint32_t seed) : ShardSample(indices) {
61 sampler_type_ = kSubsetRandomSampler;
62 shuffle_op_ = std::make_shared<ShardShuffle>(seed);
63 }
64
GetNumSamples(int64_t dataset_size,int64_t num_classes)65 int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
66 if (sampler_type_ == kCustomTopNSampler) {
67 return no_of_samples_;
68 }
69
70 if (sampler_type_ == kCustomTopPercentSampler) {
71 if (dataset_size % denominator_ == 0) {
72 return dataset_size / denominator_ * numerator_;
73 } else {
74 return dataset_size / denominator_ * numerator_ + 1;
75 }
76 }
77 if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) {
78 return indices_.size();
79 }
80 return 0;
81 }
82
UpdateTasks(ShardTaskList & tasks,int taking)83 Status ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) {
84 if (tasks.permutation_.empty()) {
85 ShardTaskList new_tasks;
86 int total_no = static_cast<int>(tasks.sample_ids_.size());
87 if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) {
88 for (int i = 0; i < indices_.size(); ++i) {
89 int index = ((indices_[i] % total_no) + total_no) % total_no;
90 new_tasks.AssignTask(tasks, index); // different mod result between c and python
91 }
92 } else {
93 int count = 0;
94 if (nums_per_shard_.empty()) {
95 for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
96 if (no_of_samples_ != 0 && count == no_of_samples_) break;
97 new_tasks.AssignTask(tasks, i % total_no); // rounding up. if overflow, go back to start
98 count++;
99 }
100 } else {
101 // Get samples within a specific range
102 size_t i = partition_id_ - 1 >= 0 ? nums_per_shard_[partition_id_ - 1] : 0;
103 for (; i < nums_per_shard_[partition_id_]; i++) {
104 if (no_of_samples_ != 0 && count == no_of_samples_) break;
105 new_tasks.AssignTask(tasks, i % total_no);
106 count++;
107 }
108 }
109 }
110 ShardTaskList::TaskListSwap(tasks, new_tasks);
111 } else {
112 ShardTaskList new_tasks;
113 int total_no = static_cast<int>(tasks.permutation_.size());
114 int cnt = 0;
115 for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
116 if (no_of_samples_ != 0 && cnt == no_of_samples_) break;
117 new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]);
118 cnt++;
119 }
120 ShardTaskList::TaskListSwap(tasks, new_tasks);
121 }
122 return Status::OK();
123 }
124
Execute(ShardTaskList & tasks)125 Status ShardSample::Execute(ShardTaskList &tasks) {
126 if (offset_ != -1) {
127 int64_t old_v = 0;
128 int num_rows_ = static_cast<int>(tasks.sample_ids_.size());
129 for (int x = 0; x < denominator_; x++) {
130 int samples_per_buffer_ = (num_rows_ + offset_) / denominator_;
131 int remainder = (num_rows_ + offset_) % denominator_;
132 if (x < remainder) samples_per_buffer_++;
133 if (x < offset_) samples_per_buffer_--;
134 old_v += samples_per_buffer_;
135 // nums_per_shard_ is used to save the current shard's ending index
136 nums_per_shard_.push_back(old_v);
137 }
138 }
139 int no_of_categories = static_cast<int>(tasks.categories);
140 int total_no = static_cast<int>(tasks.sample_ids_.size());
141 int taking = 0;
142 if (sampler_type_ == kCustomTopNSampler) { // non sharding case constructor #1
143 no_of_samples_ = std::min(no_of_samples_, total_no);
144 taking = no_of_samples_ - no_of_samples_ % no_of_categories;
145 } else if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) {
146 CHECK_FAIL_RETURN_UNEXPECTED(indices_.size() <= static_cast<size_t>(total_no),
147 "Invalid input, indices size: " + std::to_string(indices_.size()) +
148 " need to be less than dataset size: " + std::to_string(total_no) + ".");
149 } else { // constructor TopPercent
150 if (numerator_ > 0 && denominator_ > 0 && numerator_ <= denominator_) {
151 if (numerator_ == 1 && denominator_ > 1) { // sharding
152 taking = (total_no + denominator_ - 1) / denominator_;
153 } else { // non sharding
154 taking = total_no * numerator_ / denominator_;
155 taking -= (taking % no_of_categories);
156 }
157 } else {
158 RETURN_STATUS_UNEXPECTED("Invalid input, numerator: " + std::to_string(numerator_) +
159 " need to be positive and be less than denominator: " + std::to_string(denominator_) +
160 ".");
161 }
162 }
163 return UpdateTasks(tasks, taking);
164 }
165
SufExecute(ShardTaskList & tasks)166 Status ShardSample::SufExecute(ShardTaskList &tasks) {
167 if (sampler_type_ == kSubsetRandomSampler) {
168 RETURN_IF_NOT_OK((*shuffle_op_)(tasks));
169 }
170 return Status::OK();
171 }
172 } // namespace mindrecord
173 } // namespace mindspore
174