• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
18 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
19 
20 #include <algorithm>
21 #include <iostream>
22 #include <string>
23 #include <tuple>
24 #include <utility>
25 #include <vector>
26 #include "minddata/mindrecord/include/common/shard_utils.h"
27 
28 namespace mindspore {
29 namespace mindrecord {
30 
31 // The data struct is as below:
32 // 1. TaskType: kCommonTask / kPaddedTask
33 // 2. std::tuple<int, int> : shard_id, group_id(fast load) / sample_id(lazy load)
34 // 3. std::vector<uint64_t>, json>> : [blob_start, blob_end], scalar_variable_fields
35 using ShardTask = std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>;
36 
37 class __attribute__((visibility("default"))) ShardTaskList {
38  public:
39   ShardTaskList();
40 
41   ShardTaskList(const ShardTaskList &task);  // copy construction
42 
43   ShardTaskList &operator=(const ShardTaskList &task);  // assignment operator
44 
45   ~ShardTaskList() = default;
46 
47   void InitSampleIds();
48 
49   static void TaskListSwap(ShardTaskList &orig_tasks, ShardTaskList &new_tasks);
50 
51   // Assigns the task based on task id
52   inline void AssignTask(ShardTaskList &sourceTasks, size_t id);
53 
54   inline void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
55                          const json &label);
56 
57   inline void InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id,
58                          const std::vector<uint64_t> &offset, const json &label);
59 
60   inline void InsertTask(ShardTask task);
61 
62   inline void InsertTask(const uint32_t &i, ShardTask task);
63 
64   void MakePerm();
65 
66   inline void InsertSampleId(int id);
67 
68   void PopBack();
69 
70   uint32_t Size() const;
71 
72   uint32_t SizeOfRows() const;
73 
74   ShardTask &GetTaskByID(size_t id);
75 
76   ShardTask &GetRandomTask();
77 
78   int GetTaskSampleByID(size_t id);
79 
80   int GetRandomTaskID();
81 
82   static ShardTaskList Combine(std::vector<ShardTaskList> &category_tasks, bool replacement, int64_t num_elements,
83                                int64_t num_samples);
84 
85   inline void ResizeTask(const uint32_t &size);
86 
87   uint32_t categories;
88 
89   std::vector<int> permutation_;  // A list of ints used for shuffling sample ids
90 
91   std::vector<int> sample_ids_;  // The list of actual ids that were sampled
92 
93   std::vector<ShardTask> task_list_;  // The full list of tasks
94 };
95 
AssignTask(ShardTaskList & sourceTasks,size_t id)96 inline void ShardTaskList::AssignTask(ShardTaskList &sourceTasks, size_t id) {
97   // Insert the sample id from the source into ourself by indexing at id position.
98   // Important: The task list itself does not change.
99   int sample_id = sourceTasks.GetTaskSampleByID(id);
100   MS_LOG(DEBUG) << "Insert sample id (" << sample_id << ") into task list from source task position: " << id;
101   sample_ids_.push_back(sample_id);
102 }
103 
InsertTask(TaskType task_type,int shard_id,int group_id,const std::vector<uint64_t> & offset,const json & label)104 inline void ShardTaskList::InsertTask(TaskType task_type, int shard_id, int group_id,
105                                       const std::vector<uint64_t> &offset, const json &label) {
106   MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id
107                 << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
108   task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id), offset, label);
109 }
110 
InsertTask(const uint32_t & i,TaskType task_type,int shard_id,int group_id,const std::vector<uint64_t> & offset,const json & label)111 inline void ShardTaskList::InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id,
112                                       const std::vector<uint64_t> &offset, const json &label) {
113   MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id
114                 << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
115   task_list_[i] = {task_type, std::make_tuple(shard_id, group_id), offset, label};
116 }
117 
InsertTask(ShardTask task)118 inline void ShardTaskList::InsertTask(ShardTask task) {
119   MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << std::get<0>(std::get<1>(task))
120                 << ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump()
121                 << ", size of task_list_: " << task_list_.size() << ".";
122 
123   task_list_.push_back(std::move(task));
124 }
125 
InsertTask(const uint32_t & i,ShardTask task)126 inline void ShardTaskList::InsertTask(const uint32_t &i, ShardTask task) { task_list_[i] = std::move(task); }
127 
ResizeTask(const uint32_t & size)128 inline void ShardTaskList::ResizeTask(const uint32_t &size) { task_list_.resize(size); }
129 }  // namespace mindrecord
130 }  // namespace mindspore
131 
132 #endif  // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
133