• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/mindrecord/include/shard_sample.h"
18 
19 using mindspore::LogStream;
20 using mindspore::ExceptionType::NoExceptionType;
21 using mindspore::MsLogLevel::ERROR;
22 
23 namespace mindspore {
24 namespace mindrecord {
ShardSample(int n)25 ShardSample::ShardSample(int n)
26     : numerator_(0),
27       denominator_(0),
28       partition_id_(0),
29       no_of_samples_(n),
30       indices_({}),
31       sampler_type_(kCustomTopNSampler),
32       offset_(-1) {}
33 
ShardSample(int num,int den)34 ShardSample::ShardSample(int num, int den)
35     : numerator_(num),
36       denominator_(den),
37       partition_id_(0),
38       no_of_samples_(0),
39       indices_({}),
40       sampler_type_(kCustomTopPercentSampler),
41       offset_(-1) {}
42 
ShardSample(int num,int den,int par,int no_of_samples,int offset)43 ShardSample::ShardSample(int num, int den, int par, int no_of_samples, int offset)
44     : numerator_(num),
45       denominator_(den),
46       partition_id_(par),
47       no_of_samples_(no_of_samples),
48       indices_({}),
49       sampler_type_(kCustomTopPercentSampler),
50       offset_(offset) {}
51 
ShardSample(const std::vector<int64_t> & indices)52 ShardSample::ShardSample(const std::vector<int64_t> &indices)
53     : numerator_(0),
54       denominator_(0),
55       partition_id_(0),
56       no_of_samples_(0),
57       indices_(indices),
58       sampler_type_(kSubsetSampler) {}
59 
ShardSample(const std::vector<int64_t> & indices,uint32_t seed)60 ShardSample::ShardSample(const std::vector<int64_t> &indices, uint32_t seed) : ShardSample(indices) {
61   sampler_type_ = kSubsetRandomSampler;
62   shuffle_op_ = std::make_shared<ShardShuffle>(seed);
63 }
64 
GetNumSamples(int64_t dataset_size,int64_t num_classes)65 int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
66   if (sampler_type_ == kCustomTopNSampler) {
67     return no_of_samples_;
68   }
69 
70   if (sampler_type_ == kCustomTopPercentSampler) {
71     if (dataset_size % denominator_ == 0) {
72       return dataset_size / denominator_ * numerator_;
73     } else {
74       return dataset_size / denominator_ * numerator_ + 1;
75     }
76   }
77   if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) {
78     return indices_.size();
79   }
80   return 0;
81 }
82 
// Rebuild `tasks` in place so it contains only the rows selected by this
// sampler. `taking` is the per-partition row budget computed in Execute()
// (unused for subset samplers, which copy exactly the rows in indices_).
// All index arithmetic wraps modulo the task count, so a partition that runs
// past the end restarts from row 0.
Status ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) {
  if (tasks.permutation_.empty()) {
    ShardTaskList new_tasks;
    int total_no = static_cast<int>(tasks.sample_ids_.size());
    if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) {
      for (int i = 0; i < indices_.size(); ++i) {
        // Normalize each requested index into [0, total_no): C's % keeps the
        // sign of the dividend, so add total_no before the second % to get
        // Python-style non-negative wrap-around for negative indices.
        int index = ((indices_[i] % total_no) + total_no) % total_no;
        new_tasks.AssignTask(tasks, index);  // different mod result between c and python
      }
    } else {
      int count = 0;
      if (nums_per_shard_.empty()) {
        // Uniform sharding: this partition takes the contiguous slice
        // [partition_id_ * taking, (partition_id_ + 1) * taking).
        for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
          if (no_of_samples_ != 0 && count == no_of_samples_) break;
          new_tasks.AssignTask(tasks, i % total_no);  // rounding up. if overflow, go back to start
          count++;
        }
      } else {
        // Offset-adjusted sharding (nums_per_shard_ filled by Execute()):
        // each entry is the cumulative ending index of a shard, so this
        // shard's range is [nums_per_shard_[partition_id_ - 1],
        // nums_per_shard_[partition_id_]), starting at 0 for shard 0.
        size_t i = partition_id_ - 1 >= 0 ? nums_per_shard_[partition_id_ - 1] : 0;
        for (; i < nums_per_shard_[partition_id_]; i++) {
          if (no_of_samples_ != 0 && count == no_of_samples_) break;
          new_tasks.AssignTask(tasks, i % total_no);
          count++;
        }
      }
    }
    ShardTaskList::TaskListSwap(tasks, new_tasks);
  } else {
    // A prior shuffle produced permutation_: pick this partition's slice of
    // the permuted order instead of the natural order.
    ShardTaskList new_tasks;
    int total_no = static_cast<int>(tasks.permutation_.size());
    int cnt = 0;
    for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
      if (no_of_samples_ != 0 && cnt == no_of_samples_) break;
      new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]);
      cnt++;
    }
    ShardTaskList::TaskListSwap(tasks, new_tasks);
  }
  return Status::OK();
}
124 
// Apply the sampler to `tasks`: validate the configuration, compute how many
// rows this partition takes, then delegate the rewrite to UpdateTasks().
// Returns an error Status for invalid TopPercent fractions or oversized
// subset index lists.
Status ShardSample::Execute(ShardTaskList &tasks) {
  if (offset_ != -1) {
    // Offset mode (constructor #3 with offset >= 0): precompute each shard's
    // cumulative ending index so UpdateTasks() can slice an exact range.
    // Rows are dealt round-robin across denominator_ shards; the first
    // `remainder` shards get one extra row, and the first offset_ shards
    // give one row back to account for the offset shift.
    int64_t old_v = 0;
    int num_rows_ = static_cast<int>(tasks.sample_ids_.size());
    for (int x = 0; x < denominator_; x++) {
      int samples_per_buffer_ = (num_rows_ + offset_) / denominator_;
      int remainder = (num_rows_ + offset_) % denominator_;
      if (x < remainder) samples_per_buffer_++;
      if (x < offset_) samples_per_buffer_--;
      old_v += samples_per_buffer_;
      // nums_per_shard_ is used to save the current shard's ending index
      nums_per_shard_.push_back(old_v);
    }
  }
  int no_of_categories = static_cast<int>(tasks.categories);
  int total_no = static_cast<int>(tasks.sample_ids_.size());
  int taking = 0;
  if (sampler_type_ == kCustomTopNSampler) {  // non sharding case constructor #1
    // Clamp the requested count to the dataset size, then round down to a
    // multiple of the category count so every category is sampled evenly.
    no_of_samples_ = std::min(no_of_samples_, total_no);
    taking = no_of_samples_ - no_of_samples_ % no_of_categories;
  } else if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) {
    // Subset samplers ignore `taking`; just reject index lists larger than
    // the dataset.
    CHECK_FAIL_RETURN_UNEXPECTED(indices_.size() <= static_cast<size_t>(total_no),
                                 "Invalid input, indices size: " + std::to_string(indices_.size()) +
                                   " need to be less than  dataset size: " + std::to_string(total_no) + ".");
  } else {  // constructor TopPercent
    if (numerator_ > 0 && denominator_ > 0 && numerator_ <= denominator_) {
      if (numerator_ == 1 && denominator_ > 1) {  // sharding
        // One shard out of denominator_: ceiling division so no row is lost.
        taking = (total_no + denominator_ - 1) / denominator_;
      } else {  // non sharding
        taking = total_no * numerator_ / denominator_;
        taking -= (taking % no_of_categories);
      }
    } else {
      RETURN_STATUS_UNEXPECTED("Invalid input, numerator: " + std::to_string(numerator_) +
                               " need to be positive and be less than denominator: " + std::to_string(denominator_) +
                               ".");
    }
  }
  return UpdateTasks(tasks, taking);
}
165 
SufExecute(ShardTaskList & tasks)166 Status ShardSample::SufExecute(ShardTaskList &tasks) {
167   if (sampler_type_ == kSubsetRandomSampler) {
168     RETURN_IF_NOT_OK((*shuffle_op_)(tasks));
169   }
170   return Status::OK();
171 }
172 }  // namespace mindrecord
173 }  // namespace mindspore
174