• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_HEADER_H_
18 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_HEADER_H_
19 
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 #include "minddata/mindrecord/include/common/shard_utils.h"
26 #include "minddata/mindrecord/include/shard_error.h"
27 #include "minddata/mindrecord/include/shard_index.h"
28 #include "minddata/mindrecord/include/shard_page.h"
29 #include "minddata/mindrecord/include/shard_schema.h"
30 #include "minddata/mindrecord/include/shard_statistics.h"
31 
32 namespace mindspore {
33 namespace mindrecord {
34 class __attribute__((visibility("default"))) ShardHeader {
35  public:
36   ShardHeader();
37 
38   ~ShardHeader() = default;
39 
40   Status BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset = true);
41 
42   static Status BuildSingleHeader(const std::string &file_path, std::shared_ptr<json> *header_ptr);
43   /// \brief add the schema and save it
44   /// \param[in] schema the schema needs to be added
45   /// \return the last schema's id
46   int AddSchema(std::shared_ptr<Schema> schema);
47 
48   /// \brief add the statistic and save it
49   /// \param[in] statistic the statistic needs to be added
50   /// \return the last statistic's id
51   void AddStatistic(std::shared_ptr<Statistics> statistic);
52 
53   /// \brief create index and add fields which from schema for each schema
54   /// \param[in] fields the index fields needs to be added
55   /// \return SUCCESS if add successfully, FAILED if not
56   Status AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields);
57 
58   Status AddIndexFields(const std::vector<std::string> &fields);
59 
60   /// \brief get the schema
61   /// \return the schema
62   std::vector<std::shared_ptr<Schema>> GetSchemas();
63 
64   /// \brief get Statistics
65   /// \return the Statistic
66   std::vector<std::shared_ptr<Statistics>> GetStatistics();
67 
68   /// \brief add the statistic and save it
69   /// \param[in] statistic info of slim size
70   /// \return null
71   int64_t GetSlimSizeStatistic(const json &slim_size_json);
72 
73   /// \brief get the fields of the index
74   /// \return the fields of the index
75   std::vector<std::pair<uint64_t, std::string>> GetFields();
76 
77   /// \brief get the index
78   /// \return the index
79   std::shared_ptr<Index> GetIndex();
80 
81   /// \brief get the schema by schemaid
82   /// \param[in] schema_id the id of schema needs to be got
83   /// \param[in] schema_ptr the schema obtained by schemaId
84   /// \return Status
85   Status GetSchemaByID(int64_t schema_id, std::shared_ptr<Schema> *schema_ptr);
86 
87   /// \brief get the filepath to shard by shardID
88   /// \param[in] shardID the id of shard which filepath needs to be obtained
89   /// \return the filepath obtained by shardID
90   std::string GetShardAddressByID(int64_t shard_id);
91 
92   /// \brief get the statistic by statistic id
93   /// \param[in] statistic_id the id of statistic needs to be get
94   /// \param[in] statistics_ptr the statistics obtained by statistic id
95   /// \return Status
96   Status GetStatisticByID(int64_t statistic_id, std::shared_ptr<Statistics> *statistics_ptr);
97 
98   Status InitByFiles(const std::vector<std::string> &file_paths);
99 
SetIndex(Index index)100   void SetIndex(Index index) { index_ = std::make_shared<Index>(index); }
101 
102   Status GetPage(const int &shard_id, const int &page_id, std::shared_ptr<Page> *page_ptr);
103 
104   Status SetPage(const std::shared_ptr<Page> &new_page);
105 
106   Status AddPage(const std::shared_ptr<Page> &new_page);
107 
108   int64_t GetLastPageId(const int &shard_id);
109 
110   int GetLastPageIdByType(const int &shard_id, const std::string &page_type);
111 
112   Status GetPageByGroupId(const int &group_id, const int &shard_id, std::shared_ptr<Page> *page_ptr);
113 
GetShardAddresses()114   std::vector<std::string> GetShardAddresses() const { return shard_addresses_; }
115 
GetShardCount()116   int GetShardCount() const { return shard_count_; }
117 
GetSchemaCount()118   int GetSchemaCount() const { return schema_.size(); }
119 
GetHeaderSize()120   uint64_t GetHeaderSize() const { return header_size_; }
121 
GetPageSize()122   uint64_t GetPageSize() const { return page_size_; }
123 
GetCompressionSize()124   uint64_t GetCompressionSize() const { return compression_size_; }
125 
SetHeaderSize(const uint64_t & header_size)126   void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; }
127 
SetPageSize(const uint64_t & page_size)128   void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; }
129 
SetCompressionSize(const uint64_t & compression_size)130   void SetCompressionSize(const uint64_t &compression_size) { compression_size_ = compression_size; }
131 
132   std::vector<std::string> SerializeHeader();
133 
134   Status PagesToFile(const std::string dump_file_name);
135 
136   Status FileToPages(const std::string dump_file_name);
137 
138   static Status Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
139                            const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
140                            uint64_t &schema_id);
141 
142  private:
143   Status InitializeHeader(const std::vector<json> &headers, bool load_dataset);
144 
145   /// \brief get the headers from all the shard data
146   /// \param[in] the shard data real path
147   /// \param[in] the headers which read from the shard data
148   /// \return SUCCESS/FAILED
149   Status GetHeaders(const vector<string> &real_addresses, std::vector<json> &headers);
150 
151   Status ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id);
152 
153   /// \brief check the binary file status
154   static Status CheckFileStatus(const std::string &path);
155 
156   static Status ValidateHeader(const std::string &path, std::shared_ptr<json> *header_ptr);
157 
158   void GetHeadersOneTask(int start, int end, std::vector<json> &headers, const vector<string> &realAddresses);
159 
160   Status ParseIndexFields(const json &index_fields);
161 
162   Status CheckIndexField(const std::string &field, const json &schema);
163 
164   Status ParsePage(const json &page, int shard_index, bool load_dataset);
165 
166   Status ParseStatistics(const json &statistics);
167 
168   Status ParseSchema(const json &schema);
169 
170   void ParseShardAddress(const json &address);
171 
172   std::string SerializeIndexFields();
173 
174   std::vector<std::string> SerializePage();
175 
176   std::string SerializeStatistics();
177 
178   std::string SerializeSchema();
179 
180   std::string SerializeShardAddress();
181 
182   std::shared_ptr<Index> InitIndexPtr();
183 
184   Status GetAllSchemaID(std::set<uint64_t> &bucket_count);
185 
186   uint32_t shard_count_;
187   uint64_t header_size_;
188   uint64_t page_size_;
189   uint64_t compression_size_;
190 
191   std::shared_ptr<Index> index_;
192   std::vector<std::string> shard_addresses_;
193   std::vector<std::shared_ptr<Schema>> schema_;
194   std::vector<std::shared_ptr<Statistics>> statistics_;
195   std::vector<std::vector<std::shared_ptr<Page>>> pages_;
196 };
197 }  // namespace mindrecord
198 }  // namespace mindspore
199 
200 #endif  // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_HEADER_H_
201