• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_
18 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <tuple>
24 #include <utility>
25 #include <vector>
26 #include "minddata/mindrecord/include/shard_reader.h"
27 
28 namespace mindspore {
29 namespace mindrecord {
30 using CATEGORY_INFO = std::vector<std::tuple<int, std::string, int>>;  // presumably one tuple per category
31 using PAGES = std::vector<std::tuple<std::vector<uint8_t>, json>>;  // tuples of (byte buffer, json)
32 using PAGES_LOAD = std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>;  // PAGES variant: pybind11::object
33 using PAGES_WITH_BLOBS = std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, json>>;  // map of buffers
34 using PAGES_LOAD_WITH_BLOBS = std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, pybind11::object>>;
35 
36 class MINDRECORD_API ShardSegment : public ShardReader {
37  public:
38   ShardSegment();
39 
40   ~ShardSegment() override = default;
41 
42   /// \brief Get candidate category fields
43   /// \return a list of fields names which are the candidates of category
44   Status GetCategoryFields(std::shared_ptr<vector<std::string>> *fields_ptr);
45 
46   /// \brief Set category field
47   /// \param[in] category_field category name
48   /// \return true if category name is existed
49   Status SetCategoryField(std::string category_field);
50 
51   /// \brief Thread-safe implementation of ReadCategoryInfo
52   /// \return statistics data in json format with 2 field: "key" and "categories".
53   ///         The value of "categories" is a list. Each Element in list is {count, id, name}
54   ///              count: count of images in category
55   ///              id:    internal unique identification, persistent
56   ///              name:  category name
57   ///  example:
58   /// { "key": "label",
59   ///    "categories": [ { "count": 3, "id": 0, "name": "sport", },
60   ///                    { "count": 3, "id": 1, "name": "finance", } ] }
61   Status ReadCategoryInfo(std::shared_ptr<std::string> *category_ptr);
62 
63   /// \brief Thread-safe implementation of ReadAtPageById
64   /// \param[in] category_id category ID
65   /// \param[in] page_no page number
66   /// \param[in] n_rows_of_page rows number in one page
67   /// \return images array, image is a vector of uint8_t
68   Status ReadAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page,
69                         std::shared_ptr<std::vector<std::vector<uint8_t>>> *page_ptr);
70 
71   /// \brief Thread-safe implementation of ReadAtPageByName
72   /// \param[in] category_name category Name
73   /// \param[in] page_no page number
74   /// \param[in] n_rows_of_page rows number in one page
75   /// \return images array, image is a vector of uint8_t
76   Status ReadAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page,
77                           std::shared_ptr<std::vector<std::vector<uint8_t>>> *pages_ptr);
78 
79   Status ReadAllAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page,
80                            std::shared_ptr<PAGES_WITH_BLOBS> *pages_ptr);
81 
82   Status ReadAllAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page,
83                              std::shared_ptr<PAGES_WITH_BLOBS> *pages_ptr);
84 
85   std::pair<ShardType, std::vector<std::string>> GetBlobFields();
86 
87  private:
88   Status WrapCategoryInfo(std::shared_ptr<CATEGORY_INFO> *category_info_ptr);
89 
90   std::string ToJsonForCategory(const std::vector<std::tuple<int, std::string, int>> &tri_vec);
91 
92   std::string CleanUp(std::string fieldName);
93 
94   Status PackImages(int group_id, int shard_id, std::vector<uint64_t> offset,
95                     std::shared_ptr<std::vector<uint8_t>> *images_ptr);
96 
97   std::vector<std::string> candidate_category_fields_;
98   std::string current_category_field_;
99   const uint32_t kStartFieldId = 9;
100 };
101 }  // namespace mindrecord
102 }  // namespace mindspore
103 
104 #endif  // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_
105