• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
18 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
19 
20 #include <algorithm>
21 #include <iostream>
22 #include <string>
23 #include <tuple>
24 #include <utility>
25 #include <vector>
26 #include "minddata/mindrecord/include/common/shard_utils.h"
27 #include "minddata/mindrecord/include/mindrecord_macro.h"
28 
29 namespace mindspore {
30 namespace mindrecord {
31 
32 // The data struct is as below:
33 // 1. TaskType: kCommonTask / kPaddedTask
34 // 2. std::tuple<int, int> : shard_id, group_id(fast load) / sample_id(lazy load)
35 // 3. std::vector<uint64_t> : [blob_start, blob_end]
36 // 4. json : scalar_variable_fields
37 using ShardTask = std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>;
38 
39 // The data struct is as below:
40 // 1. TaskType: kCommonTask / kPaddedTask
41 // 2. std::tuple<int, int> : shard_id, group_id(fast load) / sample_id(lazy load)
42 using TaskInfo = std::tuple<TaskType, std::tuple<int, int>>;
43 
44 // The data struct is as below: contain the meta info
45 // 3. std::vector<uint64_t> : [blob_start, blob_end]
46 // 4. json : scalar_variable_fields
47 using SampleMeta = std::tuple<std::vector<uint64_t>, json>;
48 
49 // The data struct is used to cache meta info when load mode is slow load
50 // task_type: kCommonTask / kPaddedTask
51 // shard_id: the index of mindrecord files
52 // start: the global index of all the samples
53 // end: the global index of all the samples
54 struct PartitionedShardSampleCount {
55   TaskType task_type;
56   int32_t shard_id;
57   int64_t start;
58   int64_t end;
59 };
60 
61 class MINDRECORD_API GeneratorIds {
62  public:
63   GeneratorIds();
64   void SetShardSampleCount(const std::vector<PartitionedShardSampleCount> &partitioned_shard_sample_count);
65   void ResetShardIndexAndID();
66   std::vector<int64_t> GetNextSampleIds(const bool &need_shuffle, const uint32_t &seed);
67 
68  private:
69   // example:
70   // kCommonTask, 4, 100, 250
71   // kCommonTask, 5, 250, 700
72   // kCommonTask, 0, 0, 15
73   std::vector<PartitionedShardSampleCount> partitioned_shard_sample_count_;
74   int32_t partition_index_;
75   int64_t partition_sample_index_;
76 };
77 
78 // There are three load mode
79 // fast mode: use ShardTask to cache meta data
80 // lazy mode: use TaskInfo to cache meta data
81 // slow mode: just cache shard_id:sample_count
82 class MINDRECORD_API ShardTaskList {
83  public:
84   ShardTaskList();
85 
86   ShardTaskList(const ShardTaskList &task);  // copy construction
87 
88   ShardTaskList &operator=(const ShardTaskList &task);  // assignment operator
89 
90   ~ShardTaskList() = default;
91 
92   void InitSampleIds();
93 
94   static void TaskListSwap(ShardTaskList &orig_tasks, ShardTaskList &new_tasks);
95 
96   // Assigns the task based on task id
97   inline void AssignTask(ShardTaskList &sourceTasks, int64_t id);  // NOLINT
98 
99   inline void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
100                          const json &label);
101 
102   inline void InsertTask(const int64_t &i, TaskType task_type, int shard_id, int group_id,
103                          const std::vector<uint64_t> &offset, const json &label);
104 
105   inline void InsertTask(ShardTask task);
106 
107   inline void InsertTask(const int64_t &i, ShardTask task);
108 
109   void MakePerm();
110 
111   inline void InsertSampleId(int64_t id);
112 
113   void PopBack();
114 
115   int64_t Size() const;
116 
117   int64_t SizeAfterSampling() const;
118 
119   int64_t SizeOfRows() const;
120 
121   ShardTask GetTaskByID(int64_t id);
122 
123   ShardTask GetRandomTask();
124 
125   int64_t GetTaskSampleByID(int64_t id);
126 
127   int64_t GetRandomTaskID();
128 
129   static ShardTaskList Combine(std::vector<ShardTaskList> &category_tasks, bool replacement,  // NOLINT
130                                int64_t num_elements, int64_t num_samples);
131 
132   inline void ResizeTask(const int64_t &size);
133 
134   // used for slow load mode
135   void SetShardSampleCount(const std::vector<int64_t> &shard_sample_count);
136   void SetPaddedSample(const int32_t &padded_sample);
137   void SetFileIds(const std::vector<int32_t> &file_ids);
138   void SetShuffledShardSampleCount(const std::vector<int64_t> &shuffled_shard_sample_count);
139   void SetPartitionedShardSampleCount(const std::vector<PartitionedShardSampleCount> &partitioned_shard_sample_count);
140   void UpdatePartitionedShardSampleCountByNumSamples(const int64_t &num_samples);
141   std::vector<int64_t> GetNextSampleIds();
142 
143   uint32_t categories;
144 
145   // >>>> fast load meta data & lazy load meta data >>>>
146   std::vector<int64_t> permutation_;  // A list of ints used for shuffling sample ids
147 
148   std::vector<int64_t> sample_ids_;  // The list of actual ids that were sampled
149 
150   // fast mode: [{TaskType, (shard_id, group_id(fast load))}, ...]
151   // lazy mode: [{TaskType, (shard_id, sample_id(lazy load))}, ...]
152   std::vector<TaskInfo> task_list_;
153 
154   // fast mode: [{[blob_start, blob_end], json}, ...]
155   // lazy mode: none
156   std::vector<SampleMeta> sample_meta_list_;
157 
158   // >>>> slow load meta data >>>>
159   // indicate shard_id : inc_count
160   // 0 : 15  -  shard 0 has 15 samples
161   // 1 : 41  -  shard 1 has 26 samples
162   // 2 : 58  -  shard 2 has 17 samples
163   std::vector<int64_t> shard_sample_count_;
164   int32_t padded_sample_;  // the padded sample
165   // shuffle shard indexes from 0,1,2 to 2,0,1
166   std::vector<int32_t> file_ids_;  // shuffle file names in each epoch
167   // after shuffle
168   // 0 : 17  -  shard 0 has 17 samples - pre shard 2
169   // 1 : 32  -  shard 1 has 15 samples - pre shard 0
170   // 2 : 58  -  shard 2 has 26 samples - pre shard 1
171   std::vector<int64_t> shuffled_shard_sample_count_;
172   // Assuming this is an 8-card training
173   // card 0 : kCommonTask, 0, 0, 8
174   // card 1 : kCommonTask, 0, 8, 16
175   // card 2 : kCommonTask, 0, 16, 17
176   // card 2 : kCommonTask, 1, 17, 24
177   // card 3 : kCommonTask, 1, 24, 32
178   // card 4 : kCommonTask, 2, 32, 40
179   // card 5 : kCommonTask, 2, 40, 48
180   // card 6 : kCommonTask, 2, 48, 56
181   // card 7 : kCommonTask, 2, 56, 58
182   // card 7 : kPaddedTask, -1, 58, 64
183   std::vector<PartitionedShardSampleCount> partitioned_shard_sample_count_;
184   // need shuffle the samples
185   bool need_shuffle_;
186   // the shuffle seed is from ShuffleOperator which is changed in every epoch
187   uint32_t shuffle_seed_;
188   // this can generator sample ids which are from partitioned_shard_sample_count_
189   GeneratorIds generator_ids_;
190 
191   // load type: fast mode, lazy mode or slow mode
192   LoadMode load_mode_;
193 };
194 
AssignTask(ShardTaskList & sourceTasks,int64_t id)195 inline void ShardTaskList::AssignTask(ShardTaskList &sourceTasks, int64_t id) {
196   // Insert the sample id from the source into ourself by indexing at id position.
197   // Important: The task list itself does not change.
198   int64_t sample_id = sourceTasks.GetTaskSampleByID(id);
199   MS_LOG(DEBUG) << "Insert sample id (" << sample_id << ") into task list from source task position: " << id;
200   sample_ids_.push_back(sample_id);
201 }
202 
InsertTask(TaskType task_type,int shard_id,int group_id,const std::vector<uint64_t> & offset,const json & label)203 inline void ShardTaskList::InsertTask(TaskType task_type, int shard_id, int group_id,
204                                       const std::vector<uint64_t> &offset, const json &label) {
205   MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id
206                 << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
207   (void)task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id));
208   if (load_mode_ == LoadMode::kFast) {
209     (void)sample_meta_list_.emplace_back(offset, label);
210   }
211 }
212 
InsertTask(const int64_t & i,TaskType task_type,int shard_id,int group_id,const std::vector<uint64_t> & offset,const json & label)213 inline void ShardTaskList::InsertTask(const int64_t &i, TaskType task_type, int shard_id, int group_id,
214                                       const std::vector<uint64_t> &offset, const json &label) {
215   MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id
216                 << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
217   task_list_[i] = {task_type, std::make_tuple(shard_id, group_id)};
218   if (load_mode_ == LoadMode::kFast) {
219     sample_meta_list_[i] = {offset, label};
220   }
221 }
222 
InsertTask(ShardTask task)223 inline void ShardTaskList::InsertTask(ShardTask task) {
224   MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << std::get<0>(std::get<1>(task))
225                 << ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump()
226                 << ", size of task_list_: " << task_list_.size() << ".";
227   task_list_.push_back({std::get<0>(task), std::get<1>(task)});
228   if (load_mode_ == LoadMode::kFast) {
229     sample_meta_list_.push_back({std::get<2>(task), std::get<3>(task)});
230   }
231 }
232 
InsertTask(const int64_t & i,ShardTask task)233 inline void ShardTaskList::InsertTask(const int64_t &i, ShardTask task) {
234   task_list_[i] = {std::get<0>(task), std::get<1>(task)};
235   if (load_mode_ == kFast) {
236     sample_meta_list_[i] = {std::get<2>(task), std::get<3>(task)};
237   }
238 }
239 
ResizeTask(const int64_t & size)240 inline void ShardTaskList::ResizeTask(const int64_t &size) {
241   task_list_.resize(size);
242   if (load_mode_ == kFast) {
243     sample_meta_list_.resize(size);
244   }
245 }
246 }  // namespace mindrecord
247 }  // namespace mindspore
248 
249 #endif  // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
250