• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_
18 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_
19 
20 #include <memory>
21 #include <string>
22 #include <tuple>
23 #include <utility>
24 #include <vector>
25 #include "minddata/mindrecord/include/shard_reader.h"
26 
27 namespace mindspore {
28 namespace mindrecord {
29 using CATEGORY_INFO = std::vector<std::tuple<int, std::string, int>>;
30 using PAGES = std::vector<std::tuple<std::vector<uint8_t>, json>>;
31 using PAGES_LOAD = std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>;
32 
33 class __attribute__((visibility("default"))) ShardSegment : public ShardReader {
34  public:
35   ShardSegment();
36 
37   ~ShardSegment() override = default;
38 
39   /// \brief Get candidate category fields
40   /// \return a list of fields names which are the candidates of category
41   Status GetCategoryFields(std::shared_ptr<vector<std::string>> *fields_ptr);
42 
43   /// \brief Set category field
44   /// \param[in] category_field category name
45   /// \return true if category name is existed
46   Status SetCategoryField(std::string category_field);
47 
48   /// \brief Thread-safe implementation of ReadCategoryInfo
49   /// \return statistics data in json format with 2 field: "key" and "categories".
50   ///         The value of "categories" is a list. Each Element in list is {count, id, name}
51   ///              count: count of images in category
52   ///              id:    internal unique identification, persistent
53   ///              name:  category name
54   ///  example:
55   /// { "key": "label",
56   ///    "categories": [ { "count": 3, "id": 0, "name": "sport", },
57   ///                    { "count": 3, "id": 1, "name": "finance", } ] }
58   Status ReadCategoryInfo(std::shared_ptr<std::string> *category_ptr);
59 
60   /// \brief Thread-safe implementation of ReadAtPageById
61   /// \param[in] category_id category ID
62   /// \param[in] page_no page number
63   /// \param[in] n_rows_of_page rows number in one page
64   /// \return images array, image is a vector of uint8_t
65   Status ReadAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page,
66                         std::shared_ptr<std::vector<std::vector<uint8_t>>> *page_ptr);
67 
68   /// \brief Thread-safe implementation of ReadAtPageByName
69   /// \param[in] category_name category Name
70   /// \param[in] page_no page number
71   /// \param[in] n_rows_of_page rows number in one page
72   /// \return images array, image is a vector of uint8_t
73   Status ReadAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page,
74                           std::shared_ptr<std::vector<std::vector<uint8_t>>> *pages_ptr);
75 
76   Status ReadAllAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page,
77                            std::shared_ptr<PAGES> *pages_ptr);
78 
79   Status ReadAllAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page,
80                              std::shared_ptr<PAGES> *pages_ptr);
81 
82   std::pair<ShardType, std::vector<std::string>> GetBlobFields();
83 
84  private:
85   Status WrapCategoryInfo(std::shared_ptr<CATEGORY_INFO> *category_info_ptr);
86 
87   std::string ToJsonForCategory(const std::vector<std::tuple<int, std::string, int>> &tri_vec);
88 
89   std::string CleanUp(std::string fieldName);
90 
91   Status PackImages(int group_id, int shard_id, std::vector<uint64_t> offset,
92                     std::shared_ptr<std::vector<uint8_t>> *images_ptr);
93 
94   std::vector<std::string> candidate_category_fields_;
95   std::string current_category_field_;
96   const uint32_t kStartFieldId = 9;
97 };
98 }  // namespace mindrecord
99 }  // namespace mindspore
100 
101 #endif  // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_
102