• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_HEADER_H_
18 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_HEADER_H_
19 
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 #include "minddata/mindrecord/include/common/shard_utils.h"
26 #include "minddata/mindrecord/include/mindrecord_macro.h"
27 #include "minddata/mindrecord/include/shard_error.h"
28 #include "minddata/mindrecord/include/shard_index.h"
29 #include "minddata/mindrecord/include/shard_page.h"
30 #include "minddata/mindrecord/include/shard_schema.h"
31 #include "minddata/mindrecord/include/shard_statistics.h"
32 
33 namespace mindspore {
34 namespace mindrecord {
35 class MINDRECORD_API ShardHeader {
36  public:
37   ShardHeader();
38 
39   ~ShardHeader() = default;
40 
41   Status BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset = true);
42 
43   static Status BuildSingleHeader(const std::string &file_path, std::shared_ptr<json> *header_ptr);
44   /// \brief add the schema and save it
45   /// \param[in] schema the schema needs to be added
46   /// \return the last schema's id
47   int AddSchema(std::shared_ptr<Schema> schema);
48 
49   /// \brief add the statistic and save it
50   /// \param[in] statistic the statistic needs to be added
51   /// \return the last statistic's id
52   void AddStatistic(std::shared_ptr<Statistics> statistic);
53 
54   /// \brief create index and add fields which from schema for each schema
55   /// \param[in] fields the index fields needs to be added
56   /// \return SUCCESS if add successfully, FAILED if not
57   Status AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields);
58 
59   Status AddIndexFields(const std::vector<std::string> &fields);
60 
61   /// \brief get the schema
62   /// \return the schema
63   std::vector<std::shared_ptr<Schema>> GetSchemas();
64 
65   /// \brief get Statistics
66   /// \return the Statistic
67   std::vector<std::shared_ptr<Statistics>> GetStatistics();
68 
69   /// \brief add the statistic and save it
70   /// \param[in] statistic info of slim size
71   /// \return null
72   int64_t GetSlimSizeStatistic(const json &slim_size_json);
73 
74   /// \brief get the fields of the index
75   /// \return the fields of the index
76   std::vector<std::pair<uint64_t, std::string>> GetFields();
77 
78   /// \brief get the index
79   /// \return the index
80   std::shared_ptr<Index> GetIndex();
81 
82   /// \brief get the schema by schemaid
83   /// \param[in] schema_id the id of schema needs to be got
84   /// \param[in] schema_ptr the schema obtained by schemaId
85   /// \return Status
86   Status GetSchemaByID(int64_t schema_id, std::shared_ptr<Schema> *schema_ptr);
87 
88   /// \brief get the filepath to shard by shardID
89   /// \param[in] shardID the id of shard which filepath needs to be obtained
90   /// \return the filepath obtained by shardID
91   std::string GetShardAddressByID(int64_t shard_id);
92 
93   /// \brief get the statistic by statistic id
94   /// \param[in] statistic_id the id of statistic needs to be get
95   /// \param[in] statistics_ptr the statistics obtained by statistic id
96   /// \return Status
97   Status GetStatisticByID(int64_t statistic_id, std::shared_ptr<Statistics> *statistics_ptr);
98 
99   Status InitByFiles(const std::vector<std::string> &file_paths);
100 
SetIndex(Index index)101   void SetIndex(Index index) { index_ = std::make_shared<Index>(index); }
102 
103   Status GetPage(const int &shard_id, const int &page_id, std::shared_ptr<Page> *page_ptr);
104 
105   Status SetPage(const std::shared_ptr<Page> &new_page);
106 
107   Status AddPage(const std::shared_ptr<Page> &new_page);
108 
109   int64_t GetLastPageId(const int &shard_id);
110 
111   int GetLastPageIdByType(const int &shard_id, const std::string &page_type);
112 
113   Status GetPageByGroupId(const int &group_id, const int &shard_id, std::shared_ptr<Page> *page_ptr);
114 
GetShardAddresses()115   std::vector<std::string> GetShardAddresses() const { return shard_addresses_; }
116 
GetShardCount()117   int GetShardCount() const { return shard_count_; }
118 
GetSchemaCount()119   int GetSchemaCount() const { return schema_.size(); }
120 
GetHeaderSize()121   uint64_t GetHeaderSize() const { return header_size_; }
122 
GetPageSize()123   uint64_t GetPageSize() const { return page_size_; }
124 
GetCompressionSize()125   uint64_t GetCompressionSize() const { return compression_size_; }
126 
SetHeaderSize(const uint64_t & header_size)127   void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; }
128 
SetPageSize(const uint64_t & page_size)129   void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; }
130 
SetCompressionSize(const uint64_t & compression_size)131   void SetCompressionSize(const uint64_t &compression_size) { compression_size_ = compression_size; }
132 
133   std::vector<std::string> SerializeHeader();
134 
135   Status PagesToFile(const std::string dump_file_name);
136 
137   Status FileToPages(const std::string dump_file_name);
138 
139   static Status Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
140                            const std::vector<std::string> &index_fields,
141                            std::vector<std::string> &blob_fields,  // NOLINT
142                            uint64_t &schema_id);                   // NOLINT
143 
144  private:
145   Status InitializeHeader(const std::vector<json> &headers, bool load_dataset);
146 
147   /// \brief get the headers from all the shard data
148   /// \param[in] the shard data real path
149   /// \param[in] the headers which read from the shard data
150   /// \return SUCCESS/FAILED
151   Status GetHeaders(const vector<string> &real_addresses, std::vector<json> &headers);  // NOLINT
152 
153   Status ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id);
154 
155   /// \brief check the binary file status
156   static Status CheckFileStatus(const std::string &path);
157 
158   static Status ValidateHeader(const std::string &path, std::shared_ptr<json> *header_ptr);
159 
160   void GetHeadersOneTask(int start, int end, std::vector<json> &headers,  // NOLINT
161                          const vector<string> &realAddresses);
162 
163   Status ParseIndexFields(const json &index_fields);
164 
165   Status CheckIndexField(const std::string &field, const json &schema);
166 
167   Status ParsePage(const json &page, int shard_index, bool load_dataset);
168 
169   Status ParseStatistics(const json &statistics);
170 
171   Status ParseSchema(const json &schema);
172 
173   void ParseShardAddress(const json &address);
174 
175   std::string SerializeIndexFields();
176 
177   std::vector<std::string> SerializePage();
178 
179   std::string SerializeStatistics();
180 
181   std::string SerializeSchema();
182 
183   std::string SerializeShardAddress();
184 
185   std::shared_ptr<Index> InitIndexPtr();
186 
187   Status GetAllSchemaID(std::set<uint64_t> &bucket_count);  // NOLINT
188 
189   uint32_t shard_count_;
190   uint64_t header_size_;
191   uint64_t page_size_;
192   uint64_t compression_size_;
193 
194   std::shared_ptr<Index> index_;
195   std::vector<std::string> shard_addresses_;
196   std::vector<std::shared_ptr<Schema>> schema_;
197   std::vector<std::shared_ptr<Statistics>> statistics_;
198   std::vector<std::vector<std::shared_ptr<Page>>> pages_;
199 };
200 }  // namespace mindrecord
201 }  // namespace mindspore
202 
203 #endif  // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_HEADER_H_
204