1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
18 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
19
20 #include <algorithm>
21 #include <iostream>
22 #include <string>
23 #include <tuple>
24 #include <utility>
25 #include <vector>
26 #include "minddata/mindrecord/include/common/shard_utils.h"
27
28 namespace mindspore {
29 namespace mindrecord {
30
// Layout of a ShardTask tuple, by element index:
//   0. TaskType                : kCommonTask / kPaddedTask
//   1. std::tuple<int, int>    : shard_id, then group_id (fast load) / sample_id (lazy load)
//   2. std::vector<uint64_t>   : [blob_start, blob_end]
//   3. json                    : scalar_variable_fields
using ShardTask = std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>;
36
// Holds a list of ShardTask entries (task_list_) together with two integer index lists
// (sample_ids_ and permutation_). The class is exported with default symbol visibility.
// Only the inline members defined later in this header are visible here; notes on the
// remaining members are marked as assumptions to be confirmed against their definitions.
class __attribute__((visibility("default"))) ShardTaskList {
 public:
  ShardTaskList();

  ShardTaskList(const ShardTaskList &task);  // copy construction

  ShardTaskList &operator=(const ShardTaskList &task);  // assignment operator

  ~ShardTaskList() = default;

  // Defined out of line. Presumably fills sample_ids_ - TODO confirm in the implementation file.
  void InitSampleIds();

  // Defined out of line. Presumably exchanges the contents of the two lists - TODO confirm.
  static void TaskListSwap(ShardTaskList &orig_tasks, ShardTaskList &new_tasks);

  // Assigns the task based on task id
  // Appends to sample_ids_ the value returned by sourceTasks.GetTaskSampleByID(id).
  inline void AssignTask(ShardTaskList &sourceTasks, size_t id);

  // Appends a new task built from the given fields to the end of task_list_.
  inline void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
                         const json &label);

  // Overwrites task_list_[i] with a task built from the given fields. The element is accessed
  // with operator[], so position i must already exist (see ResizeTask).
  inline void InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id,
                         const std::vector<uint64_t> &offset, const json &label);

  // Appends an already-built task to the end of task_list_ (the argument is moved in).
  inline void InsertTask(ShardTask task);

  // Overwrites task_list_[i] with an already-built task (moved in); position i must already exist.
  inline void InsertTask(const uint32_t &i, ShardTask task);

  // Defined out of line. Presumably builds permutation_ - TODO confirm.
  void MakePerm();

  // NOTE(review): declared inline, but no definition appears in this header - confirm where it
  // is defined and that every translation unit calling it can see that definition.
  inline void InsertSampleId(int id);

  // Defined out of line.
  void PopBack();

  // Defined out of line. NOTE(review): how Size() and SizeOfRows() differ is not visible here.
  uint32_t Size() const;

  uint32_t SizeOfRows() const;

  // Defined out of line. Returns a mutable reference to a task; the meaning of `id`
  // (direct index or index through sample_ids_) is not visible here - TODO confirm.
  ShardTask &GetTaskByID(size_t id);

  // Defined out of line.
  ShardTask &GetRandomTask();

  // Defined out of line. AssignTask stores the returned value in sample_ids_.
  int GetTaskSampleByID(size_t id);

  // Defined out of line.
  int GetRandomTaskID();

  // Defined out of line. Presumably merges the per-category lists into one list - TODO confirm
  // the roles of replacement, num_elements and num_samples.
  static ShardTaskList Combine(std::vector<ShardTaskList> &category_tasks, bool replacement, int64_t num_elements,
                               int64_t num_samples);

  // Resizes task_list_ to exactly `size` elements.
  inline void ResizeTask(const uint32_t &size);

  // Public counter with no in-class initializer; presumably set by the constructors - TODO confirm.
  uint32_t categories;

  std::vector<int> permutation_;  // A list of ints used for shuffling sample ids

  std::vector<int> sample_ids_;  // The list of actual ids that were sampled

  std::vector<ShardTask> task_list_;  // The full list of tasks
};
95
AssignTask(ShardTaskList & sourceTasks,size_t id)96 inline void ShardTaskList::AssignTask(ShardTaskList &sourceTasks, size_t id) {
97 // Insert the sample id from the source into ourself by indexing at id position.
98 // Important: The task list itself does not change.
99 int sample_id = sourceTasks.GetTaskSampleByID(id);
100 MS_LOG(DEBUG) << "Insert sample id (" << sample_id << ") into task list from source task position: " << id;
101 sample_ids_.push_back(sample_id);
102 }
103
InsertTask(TaskType task_type,int shard_id,int group_id,const std::vector<uint64_t> & offset,const json & label)104 inline void ShardTaskList::InsertTask(TaskType task_type, int shard_id, int group_id,
105 const std::vector<uint64_t> &offset, const json &label) {
106 MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id
107 << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
108 task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id), offset, label);
109 }
110
InsertTask(const uint32_t & i,TaskType task_type,int shard_id,int group_id,const std::vector<uint64_t> & offset,const json & label)111 inline void ShardTaskList::InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id,
112 const std::vector<uint64_t> &offset, const json &label) {
113 MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id
114 << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
115 task_list_[i] = {task_type, std::make_tuple(shard_id, group_id), offset, label};
116 }
117
InsertTask(ShardTask task)118 inline void ShardTaskList::InsertTask(ShardTask task) {
119 MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << std::get<0>(std::get<1>(task))
120 << ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump()
121 << ", size of task_list_: " << task_list_.size() << ".";
122
123 task_list_.push_back(std::move(task));
124 }
125
InsertTask(const uint32_t & i,ShardTask task)126 inline void ShardTaskList::InsertTask(const uint32_t &i, ShardTask task) { task_list_[i] = std::move(task); }
127
ResizeTask(const uint32_t & size)128 inline void ShardTaskList::ResizeTask(const uint32_t &size) { task_list_.resize(size); }
129 } // namespace mindrecord
130 } // namespace mindspore
131
132 #endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
133