1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
18 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
19
20 #include <algorithm>
21 #include <iostream>
22 #include <string>
23 #include <tuple>
24 #include <utility>
25 #include <vector>
26 #include "minddata/mindrecord/include/common/shard_utils.h"
27 #include "minddata/mindrecord/include/mindrecord_macro.h"
28
29 namespace mindspore {
30 namespace mindrecord {
31
// Full per-sample task record, used in fast load mode:
// 1. TaskType: kCommonTask / kPaddedTask
// 2. std::tuple<int, int> : shard_id, group_id(fast load) / sample_id(lazy load)
// 3. std::vector<uint64_t> : [blob_start, blob_end]
// 4. json : scalar_variable_fields
using ShardTask = std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>;

// Lightweight task record (fields 1 and 2 of ShardTask), used in lazy load mode:
// 1. TaskType: kCommonTask / kPaddedTask
// 2. std::tuple<int, int> : shard_id, group_id(fast load) / sample_id(lazy load)
using TaskInfo = std::tuple<TaskType, std::tuple<int, int>>;

// Per-sample meta info (fields 3 and 4 of ShardTask), cached only in fast load mode:
// 3. std::vector<uint64_t> : [blob_start, blob_end]
// 4. json : scalar_variable_fields
using SampleMeta = std::tuple<std::vector<uint64_t>, json>;
48
// The data struct is used to cache meta info when the load mode is slow load.
// It describes one contiguous range of global sample indexes that all come
// from the same shard (see the partitioning example in ShardTaskList below).
// task_type: kCommonTask / kPaddedTask
// shard_id: the index of mindrecord files
// start: the global index of all the samples
// end: the global index of all the samples
struct PartitionedShardSampleCount {
  TaskType task_type;  // kCommonTask for real data, kPaddedTask for padded samples
  int32_t shard_id;    // which mindrecord file this range comes from (-1 for a padded range)
  int64_t start;       // first global sample index covered by this range
  int64_t end;         // one past the last global sample index (ranges are contiguous: next start == this end)
};
60
// Produces batches of sample ids from a list of partitioned shard sample
// ranges; used by ShardTaskList in slow load mode. Method bodies are defined
// in the corresponding .cpp file.
class MINDRECORD_API GeneratorIds {
 public:
  GeneratorIds();
  // Replace the cached partition ranges that ids will be generated from.
  void SetShardSampleCount(const std::vector<PartitionedShardSampleCount> &partitioned_shard_sample_count);
  // Reset partition_index_ and partition_sample_index_ so generation restarts
  // from the first partition.
  void ResetShardIndexAndID();
  // Return the next batch of sample ids; presumably shuffled with `seed` when
  // need_shuffle is true — implementation not visible here.
  std::vector<int64_t> GetNextSampleIds(const bool &need_shuffle, const uint32_t &seed);

 private:
  // The ranges to draw ids from, example:
  // kCommonTask, 4, 100, 250
  // kCommonTask, 5, 250, 700
  // kCommonTask, 0, 0, 15
  std::vector<PartitionedShardSampleCount> partitioned_shard_sample_count_;
  int32_t partition_index_;         // index of the current range in partitioned_shard_sample_count_
  int64_t partition_sample_index_;  // position within the current range
};
77
// There are three load modes:
// fast mode: use ShardTask to cache meta data
// lazy mode: use TaskInfo to cache meta data
// slow mode: just cache shard_id : sample_count
class MINDRECORD_API ShardTaskList {
 public:
  ShardTaskList();

  ShardTaskList(const ShardTaskList &task);  // copy construction

  ShardTaskList &operator=(const ShardTaskList &task);  // assignment operator

  ~ShardTaskList() = default;

  // Initialize sample_ids_ (implementation in the .cpp file).
  void InitSampleIds();

  // Swap the contents of the two task lists.
  static void TaskListSwap(ShardTaskList &orig_tasks, ShardTaskList &new_tasks);

  // Assigns the task based on task id: copies the sample id found at position
  // `id` of sourceTasks into this list's sample_ids_.
  inline void AssignTask(ShardTaskList &sourceTasks, int64_t id);  // NOLINT

  // Append one task; offset/label are cached only in fast load mode.
  inline void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
                         const json &label);

  // Overwrite the task at position i (the slot must already exist).
  inline void InsertTask(const int64_t &i, TaskType task_type, int shard_id, int group_id,
                         const std::vector<uint64_t> &offset, const json &label);

  // Append one pre-built ShardTask.
  inline void InsertTask(ShardTask task);

  // Overwrite position i with a pre-built ShardTask.
  inline void InsertTask(const int64_t &i, ShardTask task);

  // Build permutation_ for shuffling (implementation not visible here).
  void MakePerm();

  inline void InsertSampleId(int64_t id);

  void PopBack();

  int64_t Size() const;

  int64_t SizeAfterSampling() const;

  int64_t SizeOfRows() const;

  ShardTask GetTaskByID(int64_t id);

  ShardTask GetRandomTask();

  int64_t GetTaskSampleByID(int64_t id);

  int64_t GetRandomTaskID();

  // Merge the per-category task lists into one list; semantics of
  // replacement/num_elements/num_samples are defined in the .cpp file.
  static ShardTaskList Combine(std::vector<ShardTaskList> &category_tasks, bool replacement,  // NOLINT
                               int64_t num_elements, int64_t num_samples);

  // Resize task_list_ (and, in fast load mode, sample_meta_list_) to `size`.
  inline void ResizeTask(const int64_t &size);

  // used for slow load mode
  void SetShardSampleCount(const std::vector<int64_t> &shard_sample_count);
  void SetPaddedSample(const int32_t &padded_sample);
  void SetFileIds(const std::vector<int32_t> &file_ids);
  void SetShuffledShardSampleCount(const std::vector<int64_t> &shuffled_shard_sample_count);
  void SetPartitionedShardSampleCount(const std::vector<PartitionedShardSampleCount> &partitioned_shard_sample_count);
  void UpdatePartitionedShardSampleCountByNumSamples(const int64_t &num_samples);
  std::vector<int64_t> GetNextSampleIds();

  uint32_t categories;

  // >>>> fast load meta data & lazy load meta data >>>>
  std::vector<int64_t> permutation_;  // A list of ints used for shuffling sample ids

  std::vector<int64_t> sample_ids_;  // The list of actual ids that were sampled

  // fast mode: [{TaskType, (shard_id, group_id(fast load))}, ...]
  // lazy mode: [{TaskType, (shard_id, sample_id(lazy load))}, ...]
  std::vector<TaskInfo> task_list_;

  // fast mode: [{[blob_start, blob_end], json}, ...]
  // lazy mode: none
  std::vector<SampleMeta> sample_meta_list_;

  // >>>> slow load meta data >>>>
  // cumulative sample counts, shard_id : inc_count
  // 0 : 15 - shard 0 has 15 samples
  // 1 : 41 - shard 1 has 26 samples
  // 2 : 58 - shard 2 has 17 samples
  std::vector<int64_t> shard_sample_count_;
  int32_t padded_sample_;  // the padded sample
  // shuffle shard indexes from 0,1,2 to 2,0,1
  std::vector<int32_t> file_ids_;  // shuffle file names in each epoch
  // cumulative counts after shuffling the shards
  // 0 : 17 - shard 0 has 17 samples - pre shard 2
  // 1 : 32 - shard 1 has 15 samples - pre shard 0
  // 2 : 58 - shard 2 has 26 samples - pre shard 1
  std::vector<int64_t> shuffled_shard_sample_count_;
  // Per-device partitions of the global sample range.
  // Assuming this is an 8-card training
  // card 0 : kCommonTask, 0, 0, 8
  // card 1 : kCommonTask, 0, 8, 16
  // card 2 : kCommonTask, 0, 16, 17
  // card 2 : kCommonTask, 1, 17, 24
  // card 3 : kCommonTask, 1, 24, 32
  // card 4 : kCommonTask, 2, 32, 40
  // card 5 : kCommonTask, 2, 40, 48
  // card 6 : kCommonTask, 2, 48, 56
  // card 7 : kCommonTask, 2, 56, 58
  // card 7 : kPaddedTask, -1, 58, 64
  std::vector<PartitionedShardSampleCount> partitioned_shard_sample_count_;
  // whether the samples need to be shuffled
  bool need_shuffle_;
  // the shuffle seed is from ShuffleOperator which is changed in every epoch
  uint32_t shuffle_seed_;
  // generates sample ids from partitioned_shard_sample_count_
  GeneratorIds generator_ids_;

  // load type: fast mode, lazy mode or slow mode
  LoadMode load_mode_;
};
194
AssignTask(ShardTaskList & sourceTasks,int64_t id)195 inline void ShardTaskList::AssignTask(ShardTaskList &sourceTasks, int64_t id) {
196 // Insert the sample id from the source into ourself by indexing at id position.
197 // Important: The task list itself does not change.
198 int64_t sample_id = sourceTasks.GetTaskSampleByID(id);
199 MS_LOG(DEBUG) << "Insert sample id (" << sample_id << ") into task list from source task position: " << id;
200 sample_ids_.push_back(sample_id);
201 }
202
InsertTask(TaskType task_type,int shard_id,int group_id,const std::vector<uint64_t> & offset,const json & label)203 inline void ShardTaskList::InsertTask(TaskType task_type, int shard_id, int group_id,
204 const std::vector<uint64_t> &offset, const json &label) {
205 MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id
206 << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
207 (void)task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id));
208 if (load_mode_ == LoadMode::kFast) {
209 (void)sample_meta_list_.emplace_back(offset, label);
210 }
211 }
212
InsertTask(const int64_t & i,TaskType task_type,int shard_id,int group_id,const std::vector<uint64_t> & offset,const json & label)213 inline void ShardTaskList::InsertTask(const int64_t &i, TaskType task_type, int shard_id, int group_id,
214 const std::vector<uint64_t> &offset, const json &label) {
215 MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id
216 << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
217 task_list_[i] = {task_type, std::make_tuple(shard_id, group_id)};
218 if (load_mode_ == LoadMode::kFast) {
219 sample_meta_list_[i] = {offset, label};
220 }
221 }
222
InsertTask(ShardTask task)223 inline void ShardTaskList::InsertTask(ShardTask task) {
224 MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << std::get<0>(std::get<1>(task))
225 << ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump()
226 << ", size of task_list_: " << task_list_.size() << ".";
227 task_list_.push_back({std::get<0>(task), std::get<1>(task)});
228 if (load_mode_ == LoadMode::kFast) {
229 sample_meta_list_.push_back({std::get<2>(task), std::get<3>(task)});
230 }
231 }
232
InsertTask(const int64_t & i,ShardTask task)233 inline void ShardTaskList::InsertTask(const int64_t &i, ShardTask task) {
234 task_list_[i] = {std::get<0>(task), std::get<1>(task)};
235 if (load_mode_ == kFast) {
236 sample_meta_list_[i] = {std::get<2>(task), std::get<3>(task)};
237 }
238 }
239
ResizeTask(const int64_t & size)240 inline void ShardTaskList::ResizeTask(const int64_t &size) {
241 task_list_.resize(size);
242 if (load_mode_ == kFast) {
243 sample_meta_list_.resize(size);
244 }
245 }
246 } // namespace mindrecord
247 } // namespace mindspore
248
249 #endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
250