1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_HEADER_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_HEADER_H_ 19 20 #include <memory> 21 #include <set> 22 #include <string> 23 #include <utility> 24 #include <vector> 25 #include "minddata/mindrecord/include/common/shard_utils.h" 26 #include "minddata/mindrecord/include/mindrecord_macro.h" 27 #include "minddata/mindrecord/include/shard_error.h" 28 #include "minddata/mindrecord/include/shard_index.h" 29 #include "minddata/mindrecord/include/shard_page.h" 30 #include "minddata/mindrecord/include/shard_schema.h" 31 #include "minddata/mindrecord/include/shard_statistics.h" 32 33 namespace mindspore { 34 namespace mindrecord { 35 class MINDRECORD_API ShardHeader { 36 public: 37 ShardHeader(); 38 39 ~ShardHeader() = default; 40 41 Status BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset = true); 42 43 static Status BuildSingleHeader(const std::string &file_path, std::shared_ptr<json> *header_ptr); 44 /// \brief add the schema and save it 45 /// \param[in] schema the schema needs to be added 46 /// \return the last schema's id 47 int AddSchema(std::shared_ptr<Schema> schema); 48 49 /// \brief add the statistic and save it 50 /// \param[in] statistic the statistic needs to be added 51 /// \return the last statistic's id 52 void AddStatistic(std::shared_ptr<Statistics> statistic); 53 54 /// \brief create index and add fields which from schema for each schema 55 /// \param[in] fields the index fields needs to be added 56 /// \return SUCCESS if add successfully, FAILED if not 57 Status AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields); 58 59 Status AddIndexFields(const std::vector<std::string> &fields); 60 61 /// \brief get the schema 62 /// \return the schema 63 std::vector<std::shared_ptr<Schema>> GetSchemas(); 64 65 /// \brief get Statistics 66 /// \return the Statistic 67 std::vector<std::shared_ptr<Statistics>> GetStatistics(); 68 69 /// \brief add the statistic and save it 70 /// \param[in] statistic info of slim size 71 /// \return null 72 int64_t GetSlimSizeStatistic(const json &slim_size_json); 73 74 /// \brief get the fields of the index 75 /// \return the fields of the index 76 std::vector<std::pair<uint64_t, std::string>> GetFields(); 77 78 /// \brief get the index 79 /// \return the index 80 std::shared_ptr<Index> GetIndex(); 81 82 /// \brief get the schema by schemaid 83 /// \param[in] schema_id the id of schema needs to be got 84 /// \param[in] schema_ptr the schema obtained by schemaId 85 /// \return Status 86 Status GetSchemaByID(int64_t schema_id, std::shared_ptr<Schema> *schema_ptr); 87 88 /// \brief get the filepath to shard by shardID 89 /// \param[in] shardID the id of shard which filepath needs to be obtained 90 /// \return the filepath obtained by shardID 91 std::string GetShardAddressByID(int64_t shard_id); 92 93 /// \brief get the statistic by statistic id 94 /// \param[in] statistic_id the id of statistic needs to be get 95 /// \param[in] statistics_ptr the statistics obtained by statistic id 96 /// \return Status 97 Status GetStatisticByID(int64_t statistic_id, std::shared_ptr<Statistics> *statistics_ptr); 98 99 Status InitByFiles(const std::vector<std::string> &file_paths); 100 SetIndex(Index index)101 void SetIndex(Index index) { index_ = std::make_shared<Index>(index); } 102 103 Status GetPage(const int &shard_id, const int &page_id, std::shared_ptr<Page> *page_ptr); 104 105 Status SetPage(const std::shared_ptr<Page> &new_page); 106 107 Status AddPage(const std::shared_ptr<Page> &new_page); 108 109 int64_t GetLastPageId(const int &shard_id); 110 111 int GetLastPageIdByType(const int &shard_id, const std::string &page_type); 112 113 Status GetPageByGroupId(const int &group_id, const int &shard_id, std::shared_ptr<Page> *page_ptr); 114 GetShardAddresses()115 std::vector<std::string> GetShardAddresses() const { return shard_addresses_; } 116 GetShardCount()117 int GetShardCount() const { return shard_count_; } 118 GetSchemaCount()119 int GetSchemaCount() const { return schema_.size(); } 120 GetHeaderSize()121 uint64_t GetHeaderSize() const { return header_size_; } 122 GetPageSize()123 uint64_t GetPageSize() const { return page_size_; } 124 GetCompressionSize()125 uint64_t GetCompressionSize() const { return compression_size_; } 126 SetHeaderSize(const uint64_t & header_size)127 void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; } 128 SetPageSize(const uint64_t & page_size)129 void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } 130 SetCompressionSize(const uint64_t & compression_size)131 void SetCompressionSize(const uint64_t &compression_size) { compression_size_ = compression_size; } 132 133 std::vector<std::string> SerializeHeader(); 134 135 Status PagesToFile(const std::string dump_file_name); 136 137 Status FileToPages(const std::string dump_file_name); 138 139 static Status Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema, 140 const std::vector<std::string> &index_fields, 141 std::vector<std::string> &blob_fields, // NOLINT 142 uint64_t &schema_id); // NOLINT 143 144 private: 145 Status InitializeHeader(const std::vector<json> &headers, bool load_dataset); 146 147 /// \brief get the headers from all the shard data 148 /// \param[in] the shard data real path 149 /// \param[in] the headers which read from the shard data 150 /// \return SUCCESS/FAILED 151 Status GetHeaders(const vector<string> &real_addresses, std::vector<json> &headers); // NOLINT 152 153 Status ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id); 154 155 /// \brief check the binary file status 156 static Status CheckFileStatus(const std::string &path); 157 158 static Status ValidateHeader(const std::string &path, std::shared_ptr<json> *header_ptr); 159 160 void GetHeadersOneTask(int start, int end, std::vector<json> &headers, // NOLINT 161 const vector<string> &realAddresses); 162 163 Status ParseIndexFields(const json &index_fields); 164 165 Status CheckIndexField(const std::string &field, const json &schema); 166 167 Status ParsePage(const json &page, int shard_index, bool load_dataset); 168 169 Status ParseStatistics(const json &statistics); 170 171 Status ParseSchema(const json &schema); 172 173 void ParseShardAddress(const json &address); 174 175 std::string SerializeIndexFields(); 176 177 std::vector<std::string> SerializePage(); 178 179 std::string SerializeStatistics(); 180 181 std::string SerializeSchema(); 182 183 std::string SerializeShardAddress(); 184 185 std::shared_ptr<Index> InitIndexPtr(); 186 187 Status GetAllSchemaID(std::set<uint64_t> &bucket_count); // NOLINT 188 189 uint32_t shard_count_; 190 uint64_t header_size_; 191 uint64_t page_size_; 192 uint64_t compression_size_; 193 194 std::shared_ptr<Index> index_; 195 std::vector<std::string> shard_addresses_; 196 std::vector<std::shared_ptr<Schema>> schema_; 197 std::vector<std::shared_ptr<Statistics>> statistics_; 198 std::vector<std::vector<std::shared_ptr<Page>>> pages_; 199 }; 200 } // namespace mindrecord 201 } // namespace mindspore 202 203 #endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_HEADER_H_ 204