1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ 19 20 #include <map> 21 #include <memory> 22 #include <string> 23 #include <tuple> 24 #include <utility> 25 #include <vector> 26 #include "minddata/mindrecord/include/shard_reader.h" 27 28 namespace mindspore { 29 namespace mindrecord { 30 using CATEGORY_INFO = std::vector<std::tuple<int, std::string, int>>; 31 using PAGES = std::vector<std::tuple<std::vector<uint8_t>, json>>; 32 using PAGES_LOAD = std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>; 33 using PAGES_WITH_BLOBS = std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, json>>; 34 using PAGES_LOAD_WITH_BLOBS = std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, pybind11::object>>; 35 36 class MINDRECORD_API ShardSegment : public ShardReader { 37 public: 38 ShardSegment(); 39 40 ~ShardSegment() override = default; 41 42 /// \brief Get candidate category fields 43 /// \return a list of fields names which are the candidates of category 44 Status GetCategoryFields(std::shared_ptr<vector<std::string>> *fields_ptr); 45 46 /// \brief Set category field 47 /// \param[in] category_field category name 48 /// \return true if category name is existed 49 Status SetCategoryField(std::string category_field); 50 51 /// \brief Thread-safe implementation of ReadCategoryInfo 52 /// \return statistics data in json format with 2 field: "key" and "categories". 53 /// The value of "categories" is a list. Each Element in list is {count, id, name} 54 /// count: count of images in category 55 /// id: internal unique identification, persistent 56 /// name: category name 57 /// example: 58 /// { "key": "label", 59 /// "categories": [ { "count": 3, "id": 0, "name": "sport", }, 60 /// { "count": 3, "id": 1, "name": "finance", } ] } 61 Status ReadCategoryInfo(std::shared_ptr<std::string> *category_ptr); 62 63 /// \brief Thread-safe implementation of ReadAtPageById 64 /// \param[in] category_id category ID 65 /// \param[in] page_no page number 66 /// \param[in] n_rows_of_page rows number in one page 67 /// \return images array, image is a vector of uint8_t 68 Status ReadAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page, 69 std::shared_ptr<std::vector<std::vector<uint8_t>>> *page_ptr); 70 71 /// \brief Thread-safe implementation of ReadAtPageByName 72 /// \param[in] category_name category Name 73 /// \param[in] page_no page number 74 /// \param[in] n_rows_of_page rows number in one page 75 /// \return images array, image is a vector of uint8_t 76 Status ReadAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page, 77 std::shared_ptr<std::vector<std::vector<uint8_t>>> *pages_ptr); 78 79 Status ReadAllAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page, 80 std::shared_ptr<PAGES_WITH_BLOBS> *pages_ptr); 81 82 Status ReadAllAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page, 83 std::shared_ptr<PAGES_WITH_BLOBS> *pages_ptr); 84 85 std::pair<ShardType, std::vector<std::string>> GetBlobFields(); 86 87 private: 88 Status WrapCategoryInfo(std::shared_ptr<CATEGORY_INFO> *category_info_ptr); 89 90 std::string ToJsonForCategory(const std::vector<std::tuple<int, std::string, int>> &tri_vec); 91 92 std::string CleanUp(std::string fieldName); 93 94 Status PackImages(int group_id, int shard_id, std::vector<uint64_t> offset, 95 std::shared_ptr<std::vector<uint8_t>> *images_ptr); 96 97 std::vector<std::string> candidate_category_fields_; 98 std::string current_category_field_; 99 const uint32_t kStartFieldId = 9; 100 }; 101 } // namespace mindrecord 102 } // namespace mindspore 103 104 #endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ 105