1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "minddata/mindrecord/include/shard_task_list.h"

#include <algorithm>
#include <cstdint>
#include <limits>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

#include "minddata/dataset/util/random.h"
#include "minddata/mindrecord/include/common/shard_utils.h"
#include "utils/ms_utils.h"
21
22 using mindspore::LogStream;
23 using mindspore::ExceptionType::NoExceptionType;
24 using mindspore::MsLogLevel::DEBUG;
25
26 namespace mindspore {
27 namespace mindrecord {
ShardTaskList()28 ShardTaskList::ShardTaskList() : categories(1) {}
29
// Copy constructor: member-wise copy of the category count, the permutation,
// the sample ids and the task list of `other`.
ShardTaskList::ShardTaskList(const ShardTaskList &other)
    : categories(other.categories),
      permutation_(other.permutation_),
      sample_ids_(other.sample_ids_),
      task_list_(other.task_list_) {
  // All work is done in the member initializer list above.
}
35
// Copy assignment implemented with copy-and-swap: a complete copy of `other` is
// built first, so if copying throws, *this is left untouched. The copy's members
// are then exchanged into *this and the old state is released with the copy.
ShardTaskList &ShardTaskList::operator=(const ShardTaskList &other) {
  ShardTaskList staged(other);
  std::swap(task_list_, staged.task_list_);
  std::swap(sample_ids_, staged.sample_ids_);
  std::swap(permutation_, staged.permutation_);
  std::swap(categories, staged.categories);
  return *this;
}
44
InitSampleIds()45 void ShardTaskList::InitSampleIds() {
46 // no-op if there already exists sample ids. Do not clobber previous list
47 if (sample_ids_.empty()) {
48 sample_ids_ = std::vector<int>(task_list_.size());
49 for (int i = 0; i < task_list_.size(); i++) sample_ids_[i] = i;
50 }
51 }
52
MakePerm()53 void ShardTaskList::MakePerm() {
54 size_t perm_size = sample_ids_.size();
55 permutation_ = std::vector<int>(perm_size);
56 for (uint32_t i = 0; i < perm_size; i++) {
57 permutation_[i] = static_cast<int>(i);
58 }
59 }
60
61 // Swap the new_tasks with orig_tasks
TaskListSwap(ShardTaskList & orig_tasks,ShardTaskList & new_tasks)62 void ShardTaskList::TaskListSwap(ShardTaskList &orig_tasks, ShardTaskList &new_tasks) {
63 // When swapping, if the orig_tasks contains fields that need to be preserved after the swap, then swapping with a
64 // new_tasks that does not have those fields will result in clobbering/losing the data after the swap.
65 // The task_list_ should not be lost/clobbered.
66 // This function can be called in the middle of mindrecord's epoch, when orig_tasks.task_list_ is still being
67 // used by mindrecord op's worker threads. So don't touch its task_list_ since this field should be preserved anyways.
68
69 std::swap(orig_tasks.categories, new_tasks.categories);
70 std::swap(orig_tasks.permutation_, new_tasks.permutation_);
71 std::swap(orig_tasks.sample_ids_, new_tasks.sample_ids_);
72 }
73
PopBack()74 void ShardTaskList::PopBack() { task_list_.pop_back(); }
75
Size() const76 uint32_t ShardTaskList::Size() const { return static_cast<uint32_t>(task_list_.size()); }
77
SizeOfRows() const78 uint32_t ShardTaskList::SizeOfRows() const {
79 if (task_list_.size() == 0) return static_cast<uint32_t>(0);
80
81 // 1 task is 1 page
82 const size_t kBlobInfoIndex = 2;
83 auto sum_num_rows = [](int x, ShardTask y) { return x + std::get<kBlobInfoIndex>(y)[0]; };
84 uint32_t nRows = std::accumulate(task_list_.begin(), task_list_.end(), 0, sum_num_rows);
85 return nRows;
86 }
87
GetTaskByID(size_t id)88 ShardTask &ShardTaskList::GetTaskByID(size_t id) { return task_list_[id]; }
89
GetTaskSampleByID(size_t id)90 int ShardTaskList::GetTaskSampleByID(size_t id) { return sample_ids_[id]; }
91
GetRandomTaskID()92 int ShardTaskList::GetRandomTaskID() {
93 std::mt19937 gen = mindspore::dataset::GetRandomDevice();
94 std::uniform_int_distribution<> dis(0, sample_ids_.size() - 1);
95 return dis(gen);
96 }
97
GetRandomTask()98 ShardTask &ShardTaskList::GetRandomTask() {
99 std::mt19937 gen = mindspore::dataset::GetRandomDevice();
100 std::uniform_int_distribution<> dis(0, task_list_.size() - 1);
101 return task_list_[dis(gen)];
102 }
103
Combine(std::vector<ShardTaskList> & category_tasks,bool replacement,int64_t num_elements,int64_t num_samples)104 ShardTaskList ShardTaskList::Combine(std::vector<ShardTaskList> &category_tasks, bool replacement, int64_t num_elements,
105 int64_t num_samples) {
106 ShardTaskList res;
107 if (category_tasks.empty()) return res;
108 auto total_categories = category_tasks.size();
109 res.categories = static_cast<uint32_t>(total_categories);
110 if (replacement == false) {
111 auto minTasks = category_tasks[0].Size();
112 for (uint32_t i = 1; i < total_categories; i++) {
113 minTasks = std::min(minTasks, category_tasks[i].Size());
114 }
115 int64_t count = 0;
116 for (uint32_t task_no = 0; task_no < minTasks; task_no++) {
117 for (uint32_t i = 0; i < total_categories; i++) {
118 if (num_samples != 0 && count == num_samples) break;
119 res.InsertTask(std::move(category_tasks[i].GetTaskByID(static_cast<int>(task_no))));
120 count++;
121 }
122 }
123 } else {
124 auto maxTasks = category_tasks[0].Size();
125 for (uint32_t i = 1; i < total_categories; i++) {
126 maxTasks = std::max(maxTasks, category_tasks[i].Size());
127 }
128 if (num_elements != std::numeric_limits<int64_t>::max()) {
129 maxTasks = static_cast<decltype(maxTasks)>(num_elements);
130 }
131 int64_t count = 0;
132 for (uint32_t i = 0; i < total_categories; i++) {
133 for (uint32_t j = 0; j < maxTasks; j++) {
134 if (num_samples != 0 && count == num_samples) break;
135 res.InsertTask(category_tasks[i].GetRandomTask());
136 count++;
137 }
138 }
139 }
140
141 return res;
142 }
143 } // namespace mindrecord
144 } // namespace mindspore
145