1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_HEADER_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_HEADER_H_ 19 20 #include <memory> 21 #include <set> 22 #include <string> 23 #include <utility> 24 #include <vector> 25 #include "minddata/mindrecord/include/common/shard_utils.h" 26 #include "minddata/mindrecord/include/shard_error.h" 27 #include "minddata/mindrecord/include/shard_index.h" 28 #include "minddata/mindrecord/include/shard_page.h" 29 #include "minddata/mindrecord/include/shard_schema.h" 30 #include "minddata/mindrecord/include/shard_statistics.h" 31 32 namespace mindspore { 33 namespace mindrecord { 34 class __attribute__((visibility("default"))) ShardHeader { 35 public: 36 ShardHeader(); 37 38 ~ShardHeader() = default; 39 40 Status BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset = true); 41 42 static Status BuildSingleHeader(const std::string &file_path, std::shared_ptr<json> *header_ptr); 43 /// \brief add the schema and save it 44 /// \param[in] schema the schema needs to be added 45 /// \return the last schema's id 46 int AddSchema(std::shared_ptr<Schema> schema); 47 48 /// \brief add the statistic and save it 49 /// \param[in] statistic the statistic needs to be added 50 /// \return the last statistic's id 51 void AddStatistic(std::shared_ptr<Statistics> statistic); 52 53 /// \brief create index and add fields which from schema for each schema 54 /// \param[in] fields the index fields needs to be added 55 /// \return SUCCESS if add successfully, FAILED if not 56 Status AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields); 57 58 Status AddIndexFields(const std::vector<std::string> &fields); 59 60 /// \brief get the schema 61 /// \return the schema 62 std::vector<std::shared_ptr<Schema>> GetSchemas(); 63 64 /// \brief get Statistics 65 /// \return the Statistic 66 std::vector<std::shared_ptr<Statistics>> GetStatistics(); 67 68 /// \brief add the statistic and save it 69 /// \param[in] statistic info of slim size 70 /// \return null 71 int64_t GetSlimSizeStatistic(const json &slim_size_json); 72 73 /// \brief get the fields of the index 74 /// \return the fields of the index 75 std::vector<std::pair<uint64_t, std::string>> GetFields(); 76 77 /// \brief get the index 78 /// \return the index 79 std::shared_ptr<Index> GetIndex(); 80 81 /// \brief get the schema by schemaid 82 /// \param[in] schema_id the id of schema needs to be got 83 /// \param[in] schema_ptr the schema obtained by schemaId 84 /// \return Status 85 Status GetSchemaByID(int64_t schema_id, std::shared_ptr<Schema> *schema_ptr); 86 87 /// \brief get the filepath to shard by shardID 88 /// \param[in] shardID the id of shard which filepath needs to be obtained 89 /// \return the filepath obtained by shardID 90 std::string GetShardAddressByID(int64_t shard_id); 91 92 /// \brief get the statistic by statistic id 93 /// \param[in] statistic_id the id of statistic needs to be get 94 /// \param[in] statistics_ptr the statistics obtained by statistic id 95 /// \return Status 96 Status GetStatisticByID(int64_t statistic_id, std::shared_ptr<Statistics> *statistics_ptr); 97 98 Status InitByFiles(const std::vector<std::string> &file_paths); 99 SetIndex(Index index)100 void SetIndex(Index index) { index_ = std::make_shared<Index>(index); } 101 102 Status GetPage(const int &shard_id, const int &page_id, std::shared_ptr<Page> *page_ptr); 103 104 Status SetPage(const std::shared_ptr<Page> &new_page); 105 106 Status AddPage(const std::shared_ptr<Page> &new_page); 107 108 int64_t GetLastPageId(const int &shard_id); 109 110 int GetLastPageIdByType(const int &shard_id, const std::string &page_type); 111 112 Status GetPageByGroupId(const int &group_id, const int &shard_id, std::shared_ptr<Page> *page_ptr); 113 GetShardAddresses()114 std::vector<std::string> GetShardAddresses() const { return shard_addresses_; } 115 GetShardCount()116 int GetShardCount() const { return shard_count_; } 117 GetSchemaCount()118 int GetSchemaCount() const { return schema_.size(); } 119 GetHeaderSize()120 uint64_t GetHeaderSize() const { return header_size_; } 121 GetPageSize()122 uint64_t GetPageSize() const { return page_size_; } 123 GetCompressionSize()124 uint64_t GetCompressionSize() const { return compression_size_; } 125 SetHeaderSize(const uint64_t & header_size)126 void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; } 127 SetPageSize(const uint64_t & page_size)128 void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } 129 SetCompressionSize(const uint64_t & compression_size)130 void SetCompressionSize(const uint64_t &compression_size) { compression_size_ = compression_size; } 131 132 std::vector<std::string> SerializeHeader(); 133 134 Status PagesToFile(const std::string dump_file_name); 135 136 Status FileToPages(const std::string dump_file_name); 137 138 static Status Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema, 139 const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields, 140 uint64_t &schema_id); 141 142 private: 143 Status InitializeHeader(const std::vector<json> &headers, bool load_dataset); 144 145 /// \brief get the headers from all the shard data 146 /// \param[in] the shard data real path 147 /// \param[in] the headers which read from the shard data 148 /// \return SUCCESS/FAILED 149 Status GetHeaders(const vector<string> &real_addresses, std::vector<json> &headers); 150 151 Status ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id); 152 153 /// \brief check the binary file status 154 static Status CheckFileStatus(const std::string &path); 155 156 static Status ValidateHeader(const std::string &path, std::shared_ptr<json> *header_ptr); 157 158 void GetHeadersOneTask(int start, int end, std::vector<json> &headers, const vector<string> &realAddresses); 159 160 Status ParseIndexFields(const json &index_fields); 161 162 Status CheckIndexField(const std::string &field, const json &schema); 163 164 Status ParsePage(const json &page, int shard_index, bool load_dataset); 165 166 Status ParseStatistics(const json &statistics); 167 168 Status ParseSchema(const json &schema); 169 170 void ParseShardAddress(const json &address); 171 172 std::string SerializeIndexFields(); 173 174 std::vector<std::string> SerializePage(); 175 176 std::string SerializeStatistics(); 177 178 std::string SerializeSchema(); 179 180 std::string SerializeShardAddress(); 181 182 std::shared_ptr<Index> InitIndexPtr(); 183 184 Status GetAllSchemaID(std::set<uint64_t> &bucket_count); 185 186 uint32_t shard_count_; 187 uint64_t header_size_; 188 uint64_t page_size_; 189 uint64_t compression_size_; 190 191 std::shared_ptr<Index> index_; 192 std::vector<std::string> shard_addresses_; 193 std::vector<std::shared_ptr<Schema>> schema_; 194 std::vector<std::shared_ptr<Statistics>> statistics_; 195 std::vector<std::vector<std::shared_ptr<Page>>> pages_; 196 }; 197 } // namespace mindrecord 198 } // namespace mindspore 199 200 #endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_HEADER_H_ 201