• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/util/random.h"
18 #include "minddata/mindrecord/include/shard_task_list.h"
19 #include "utils/ms_utils.h"
20 #include "minddata/mindrecord/include/common/shard_utils.h"
21 
22 using mindspore::LogStream;
23 using mindspore::ExceptionType::NoExceptionType;
24 using mindspore::MsLogLevel::DEBUG;
25 
26 namespace mindspore {
27 namespace mindrecord {
ShardTaskList()28 ShardTaskList::ShardTaskList() : categories(1) {}
29 
// Copy constructor: duplicates the category count, the permutation, the sample ids and the task list of `other`.
ShardTaskList::ShardTaskList(const ShardTaskList &other)
    : categories(other.categories),      // number of categories
      permutation_(other.permutation_),  // index permutation
      sample_ids_(other.sample_ids_),    // sample id list
      task_list_(other.task_list_) {     // the tasks themselves
}
35 
// Copy assignment implemented as copy-and-swap.
// The full copy of `other` is built before any member of *this is modified, so *this is left untouched
// if the copy fails. The previous state of *this is released when `snapshot` goes out of scope.
ShardTaskList &ShardTaskList::operator=(const ShardTaskList &other) {
  ShardTaskList snapshot(other);

  std::swap(task_list_, snapshot.task_list_);
  std::swap(sample_ids_, snapshot.sample_ids_);
  std::swap(permutation_, snapshot.permutation_);
  std::swap(categories, snapshot.categories);

  return *this;
}
44 
InitSampleIds()45 void ShardTaskList::InitSampleIds() {
46   // no-op if there already exists sample ids.  Do not clobber previous list
47   if (sample_ids_.empty()) {
48     sample_ids_ = std::vector<int>(task_list_.size());
49     for (int i = 0; i < task_list_.size(); i++) sample_ids_[i] = i;
50   }
51 }
52 
MakePerm()53 void ShardTaskList::MakePerm() {
54   size_t perm_size = sample_ids_.size();
55   permutation_ = std::vector<int>(perm_size);
56   for (uint32_t i = 0; i < perm_size; i++) {
57     permutation_[i] = static_cast<int>(i);
58   }
59 }
60 
61 // Swap the new_tasks with orig_tasks
TaskListSwap(ShardTaskList & orig_tasks,ShardTaskList & new_tasks)62 void ShardTaskList::TaskListSwap(ShardTaskList &orig_tasks, ShardTaskList &new_tasks) {
63   // When swapping, if the orig_tasks contains fields that need to be preserved after the swap, then swapping with a
64   // new_tasks that does not have those fields will result in clobbering/losing the data after the swap.
65   // The task_list_ should not be lost/clobbered.
66   // This function can be called in the middle of mindrecord's epoch, when orig_tasks.task_list_ is still being
67   // used by mindrecord op's worker threads. So don't touch its task_list_ since this field should be preserved anyways.
68 
69   std::swap(orig_tasks.categories, new_tasks.categories);
70   std::swap(orig_tasks.permutation_, new_tasks.permutation_);
71   std::swap(orig_tasks.sample_ids_, new_tasks.sample_ids_);
72 }
73 
PopBack()74 void ShardTaskList::PopBack() { task_list_.pop_back(); }
75 
Size() const76 uint32_t ShardTaskList::Size() const { return static_cast<uint32_t>(task_list_.size()); }
77 
SizeOfRows() const78 uint32_t ShardTaskList::SizeOfRows() const {
79   if (task_list_.size() == 0) return static_cast<uint32_t>(0);
80 
81   // 1 task is 1 page
82   const size_t kBlobInfoIndex = 2;
83   auto sum_num_rows = [](int x, ShardTask y) { return x + std::get<kBlobInfoIndex>(y)[0]; };
84   uint32_t nRows = std::accumulate(task_list_.begin(), task_list_.end(), 0, sum_num_rows);
85   return nRows;
86 }
87 
GetTaskByID(size_t id)88 ShardTask &ShardTaskList::GetTaskByID(size_t id) { return task_list_[id]; }
89 
GetTaskSampleByID(size_t id)90 int ShardTaskList::GetTaskSampleByID(size_t id) { return sample_ids_[id]; }
91 
GetRandomTaskID()92 int ShardTaskList::GetRandomTaskID() {
93   std::mt19937 gen = mindspore::dataset::GetRandomDevice();
94   std::uniform_int_distribution<> dis(0, sample_ids_.size() - 1);
95   return dis(gen);
96 }
97 
GetRandomTask()98 ShardTask &ShardTaskList::GetRandomTask() {
99   std::mt19937 gen = mindspore::dataset::GetRandomDevice();
100   std::uniform_int_distribution<> dis(0, task_list_.size() - 1);
101   return task_list_[dis(gen)];
102 }
103 
Combine(std::vector<ShardTaskList> & category_tasks,bool replacement,int64_t num_elements,int64_t num_samples)104 ShardTaskList ShardTaskList::Combine(std::vector<ShardTaskList> &category_tasks, bool replacement, int64_t num_elements,
105                                      int64_t num_samples) {
106   ShardTaskList res;
107   if (category_tasks.empty()) return res;
108   auto total_categories = category_tasks.size();
109   res.categories = static_cast<uint32_t>(total_categories);
110   if (replacement == false) {
111     auto minTasks = category_tasks[0].Size();
112     for (uint32_t i = 1; i < total_categories; i++) {
113       minTasks = std::min(minTasks, category_tasks[i].Size());
114     }
115     int64_t count = 0;
116     for (uint32_t task_no = 0; task_no < minTasks; task_no++) {
117       for (uint32_t i = 0; i < total_categories; i++) {
118         if (num_samples != 0 && count == num_samples) break;
119         res.InsertTask(std::move(category_tasks[i].GetTaskByID(static_cast<int>(task_no))));
120         count++;
121       }
122     }
123   } else {
124     auto maxTasks = category_tasks[0].Size();
125     for (uint32_t i = 1; i < total_categories; i++) {
126       maxTasks = std::max(maxTasks, category_tasks[i].Size());
127     }
128     if (num_elements != std::numeric_limits<int64_t>::max()) {
129       maxTasks = static_cast<decltype(maxTasks)>(num_elements);
130     }
131     int64_t count = 0;
132     for (uint32_t i = 0; i < total_categories; i++) {
133       for (uint32_t j = 0; j < maxTasks; j++) {
134         if (num_samples != 0 && count == num_samples) break;
135         res.InsertTask(category_tasks[i].GetRandomTask());
136         count++;
137       }
138     }
139   }
140 
141   return res;
142 }
143 }  // namespace mindrecord
144 }  // namespace mindspore
145