• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <string>
18 #include <vector>
19 #include "utils/ms_utils.h"
20 #include "minddata/mindrecord/include/common/shard_utils.h"
21 #include "minddata/mindrecord/include/shard_error.h"
22 #include "minddata/mindrecord/include/shard_index_generator.h"
23 #include "minddata/mindrecord/include/shard_reader.h"
24 #include "minddata/mindrecord/include/shard_segment.h"
25 #include "minddata/mindrecord/include/shard_writer.h"
26 #include "nlohmann/json.hpp"
27 #include "pybind11/pybind11.h"
28 #include "pybind11/stl.h"
29 #include "utils/log_adapter.h"
30 
31 namespace py = pybind11;
32 
33 using mindspore::LogStream;
34 using mindspore::ExceptionType::NoExceptionType;
35 using mindspore::MsLogLevel::ERROR;
36 
37 namespace mindspore {
38 namespace mindrecord {
// Evaluates the expression `s` once, stores the result in a local Status, and throws
// std::runtime_error carrying Status::ToString() when Status::IsError() is true.
// The do/while(false) wrapper makes the expansion behave as a single statement.
#define THROW_IF_ERROR(s)                                        \
  do {                                                           \
    Status throw_if_error_status = std::move(s);                 \
    if (throw_if_error_status.IsError()) {                       \
      throw std::runtime_error(throw_if_error_status.ToString()); \
    }                                                            \
  } while (false)
44 
BindSchema(py::module * m)45 void BindSchema(py::module *m) {
46   (void)py::class_<Schema, std::shared_ptr<Schema>>(*m, "Schema", py::module_local())
47     .def_static("build",
48                 [](const std::string &desc, const pybind11::handle &schema) {
49                   json schema_json = nlohmann::detail::ToJsonImpl(schema);
50                   return Schema::Build(std::move(desc), schema_json);
51                 })
52     .def("get_desc", &Schema::GetDesc)
53     .def("get_schema_content",
54          [](Schema &s) {
55            json schema_json = s.GetSchema();
56            return nlohmann::detail::FromJsonImpl(schema_json);
57          })
58     .def("get_blob_fields", &Schema::GetBlobFields)
59     .def("get_schema_id", &Schema::GetSchemaID);
60 }
61 
BindStatistics(const py::module * m)62 void BindStatistics(const py::module *m) {
63   (void)py::class_<Statistics, std::shared_ptr<Statistics>>(*m, "Statistics", py::module_local())
64     .def_static("build",
65                 [](const std::string desc, const pybind11::handle &statistics) {
66                   json statistics_json = nlohmann::detail::ToJsonImpl(statistics);
67                   return Statistics::Build(std::move(desc), statistics_json);
68                 })
69     .def("get_desc", &Statistics::GetDesc)
70     .def("get_statistics",
71          [](Statistics &s) {
72            json statistics_json = s.GetStatistics();
73            return nlohmann::detail::FromJsonImpl(statistics_json);
74          })
75     .def("get_statistics_id", &Statistics::GetStatisticsID);
76 }
77 
BindShardHeader(const py::module * m)78 void BindShardHeader(const py::module *m) {
79   (void)py::class_<ShardHeader, std::shared_ptr<ShardHeader>>(*m, "ShardHeader", py::module_local())
80     .def(py::init<>())
81     .def("add_schema", &ShardHeader::AddSchema)
82     .def("add_statistics", &ShardHeader::AddStatistic)
83     .def("add_index_fields",
84          [](ShardHeader &s, const std::vector<std::string> &fields) {
85            THROW_IF_ERROR(s.AddIndexFields(fields));
86            return SUCCESS;
87          })
88     .def("get_meta", &ShardHeader::GetSchemas)
89     .def("get_statistics", &ShardHeader::GetStatistics)
90     .def("get_fields", &ShardHeader::GetFields)
91     .def("get_schema_by_id",
92          [](ShardHeader &s, int64_t schema_id) {
93            std::shared_ptr<Schema> schema_ptr;
94            THROW_IF_ERROR(s.GetSchemaByID(schema_id, &schema_ptr));
95            return schema_ptr;
96          })
97     .def("get_statistic_by_id", [](ShardHeader &s, int64_t statistic_id) {
98       std::shared_ptr<Statistics> statistics_ptr;
99       THROW_IF_ERROR(s.GetStatisticByID(statistic_id, &statistics_ptr));
100       return statistics_ptr;
101     });
102 }
103 
BindShardWriter(py::module * m)104 void BindShardWriter(py::module *m) {
105   (void)py::class_<ShardWriter>(*m, "ShardWriter", py::module_local())
106     .def(py::init<>())
107     .def("open",
108          [](ShardWriter &s, const std::vector<std::string> &paths, bool append) {
109            THROW_IF_ERROR(s.Open(paths, append));
110            return SUCCESS;
111          })
112     .def("open_for_append",
113          [](ShardWriter &s, const std::string &path) {
114            THROW_IF_ERROR(s.OpenForAppend(path));
115            return SUCCESS;
116          })
117     .def("set_header_size",
118          [](ShardWriter &s, const uint64_t &header_size) {
119            THROW_IF_ERROR(s.SetHeaderSize(header_size));
120            return SUCCESS;
121          })
122     .def("set_page_size",
123          [](ShardWriter &s, const uint64_t &page_size) {
124            THROW_IF_ERROR(s.SetPageSize(page_size));
125            return SUCCESS;
126          })
127     .def("set_shard_header",
128          [](ShardWriter &s, std::shared_ptr<ShardHeader> header_data) {
129            THROW_IF_ERROR(s.SetShardHeader(header_data));
130            return SUCCESS;
131          })
132     .def("write_raw_data",
133          [](ShardWriter &s, std::map<uint64_t, std::vector<py::handle>> &raw_data, vector<vector<uint8_t>> &blob_data,
134             bool sign, bool parallel_writer) {
135            std::map<uint64_t, std::vector<json>> raw_data_json;
136            (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()),
137                                 [](const std::pair<uint64_t, std::vector<py::handle>> &p) {
138                                   auto &py_raw_data = p.second;
139                                   std::vector<json> json_raw_data;
140                                   (void)std::transform(
141                                     py_raw_data.begin(), py_raw_data.end(), std::back_inserter(json_raw_data),
142                                     [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); });
143                                   return std::make_pair(p.first, std::move(json_raw_data));
144                                 });
145            THROW_IF_ERROR(s.WriteRawData(raw_data_json, blob_data, sign, parallel_writer));
146            return SUCCESS;
147          })
148     .def("commit", [](ShardWriter &s) {
149       THROW_IF_ERROR(s.Commit());
150       return SUCCESS;
151     });
152 }
153 
BindShardReader(const py::module * m)154 void BindShardReader(const py::module *m) {
155   (void)py::class_<ShardReader, std::shared_ptr<ShardReader>>(*m, "ShardReader", py::module_local())
156     .def(py::init<>())
157     .def("open",
158          [](ShardReader &s, const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer,
159             const std::vector<std::string> &selected_columns,
160             const std::vector<std::shared_ptr<ShardOperator>> &operators) {
161            THROW_IF_ERROR(s.Open(file_paths, load_dataset, n_consumer, selected_columns, operators));
162            return SUCCESS;
163          })
164     .def("launch",
165          [](ShardReader &s) {
166            THROW_IF_ERROR(s.Launch(false));
167            return SUCCESS;
168          })
169     .def("get_header", &ShardReader::GetShardHeader)
170     .def("get_blob_fields", &ShardReader::GetBlobFields)
171     .def("get_next",
172          [](ShardReader &s) {
173            auto data = s.GetNext();
174            vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> res;
175            std::transform(data.begin(), data.end(), std::back_inserter(res),
176                           [&s](const std::tuple<std::vector<uint8_t>, json> &item) {
177                             auto &j = std::get<1>(item);
178                             pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
179                             auto blob_data_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>();
180                             (void)s.UnCompressBlob(std::get<0>(item), &blob_data_ptr);
181                             return std::make_tuple(*blob_data_ptr, std::move(obj));
182                           });
183            return res;
184          })
185     .def("close", &ShardReader::Close);
186 }
187 
BindShardIndexGenerator(const py::module * m)188 void BindShardIndexGenerator(const py::module *m) {
189   (void)py::class_<ShardIndexGenerator>(*m, "ShardIndexGenerator", py::module_local())
190     .def(py::init<const std::string &, bool>())
191     .def("build",
192          [](ShardIndexGenerator &s) {
193            THROW_IF_ERROR(s.Build());
194            return SUCCESS;
195          })
196     .def("write_to_db", [](ShardIndexGenerator &s) {
197       THROW_IF_ERROR(s.WriteToDatabase());
198       return SUCCESS;
199     });
200 }
201 
BindShardSegment(py::module * m)202 void BindShardSegment(py::module *m) {
203   (void)py::class_<ShardSegment>(*m, "ShardSegment", py::module_local())
204     .def(py::init<>())
205     .def("open",
206          [](ShardSegment &s, const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer,
207             const std::vector<std::string> &selected_columns,
208             const std::vector<std::shared_ptr<ShardOperator>> &operators) {
209            THROW_IF_ERROR(s.Open(file_paths, load_dataset, n_consumer, selected_columns, operators));
210            return SUCCESS;
211          })
212     .def("get_category_fields",
213          [](ShardSegment &s) {
214            auto fields_ptr = std::make_shared<vector<std::string>>();
215            THROW_IF_ERROR(s.GetCategoryFields(&fields_ptr));
216            return *fields_ptr;
217          })
218     .def("set_category_field",
219          [](ShardSegment &s, const std::string &category_field) {
220            THROW_IF_ERROR(s.SetCategoryField(category_field));
221            return SUCCESS;
222          })
223     .def("read_category_info",
224          [](ShardSegment &s) {
225            std::shared_ptr<std::string> category_ptr;
226            THROW_IF_ERROR(s.ReadCategoryInfo(&category_ptr));
227            return *category_ptr;
228          })
229     .def("read_at_page_by_id",
230          [](ShardSegment &s, int64_t category_id, int64_t page_no, int64_t n_rows_of_page) {
231            auto pages_load_ptr = std::make_shared<PAGES_LOAD>();
232            auto pages_ptr = std::make_shared<PAGES>();
233            THROW_IF_ERROR(s.ReadAllAtPageById(category_id, page_no, n_rows_of_page, &pages_ptr));
234            (void)std::transform(pages_ptr->begin(), pages_ptr->end(), std::back_inserter(*pages_load_ptr),
235                                 [](const std::tuple<std::vector<uint8_t>, json> &item) {
236                                   auto &j = std::get<1>(item);
237                                   pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
238                                   return std::make_tuple(std::get<0>(item), std::move(obj));
239                                 });
240            return *pages_load_ptr;
241          })
242     .def("read_at_page_by_name",
243          [](ShardSegment &s, std::string category_name, int64_t page_no, int64_t n_rows_of_page) {
244            auto pages_load_ptr = std::make_shared<PAGES_LOAD>();
245            auto pages_ptr = std::make_shared<PAGES>();
246            THROW_IF_ERROR(s.ReadAllAtPageByName(category_name, page_no, n_rows_of_page, &pages_ptr));
247            (void)std::transform(pages_ptr->begin(), pages_ptr->end(), std::back_inserter(*pages_load_ptr),
248                                 [](const std::tuple<std::vector<uint8_t>, json> &item) {
249                                   auto &j = std::get<1>(item);
250                                   pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
251                                   return std::make_tuple(std::get<0>(item), std::move(obj));
252                                 });
253            return *pages_load_ptr;
254          })
255     .def("get_header", &ShardSegment::GetShardHeader)
256     .def("get_blob_fields", [](ShardSegment &s) { return s.GetBlobFields(); });
257 }
258 
BindGlobalParams(py::module * m)259 void BindGlobalParams(py::module *m) {
260   (*m).attr("MIN_HEADER_SIZE") = kMinHeaderSize;
261   (*m).attr("MAX_HEADER_SIZE") = kMaxHeaderSize;
262   (*m).attr("MIN_PAGE_SIZE") = kMinPageSize;
263   (*m).attr("MAX_PAGE_SIZE") = kMaxPageSize;
264   (*m).attr("MIN_SHARD_COUNT") = kMinShardCount;
265   (*m).attr("MAX_SHARD_COUNT") = kMaxShardCount;
266   (*m).attr("MAX_FILE_COUNT") = kMaxFileCount;
267   (*m).attr("MIN_CONSUMER_COUNT") = kMinConsumerCount;
268   (void)(*m).def("get_max_thread_num", &GetMaxThreadNum);
269 }
270 
// Entry point of the `_c_mindrecord` extension module: registers the two enums first,
// then the global constants and every bound class.
PYBIND11_MODULE(_c_mindrecord, m) {
  m.doc() = "pybind11 mindrecord plugin";  // optional module docstring

  // export_values() additionally exposes the enumerators directly on the module.
  py::enum_<MSRStatus> status_enum(m, "MSRStatus", py::module_local());
  (void)status_enum.value("SUCCESS", SUCCESS).value("FAILED", FAILED).export_values();

  py::enum_<ShardType> shard_type_enum(m, "ShardType", py::module_local());
  (void)shard_type_enum.value("NLP", kNLP).value("CV", kCV).export_values();

  BindGlobalParams(&m);
  BindSchema(&m);
  BindStatistics(&m);
  BindShardHeader(&m);
  BindShardWriter(&m);
  BindShardReader(&m);
  BindShardIndexGenerator(&m);
  BindShardSegment(&m);
}
287 }  // namespace mindrecord
288 }  // namespace mindspore
289 
290 namespace nlohmann {
291 namespace detail {
FromJsonImpl(const json & j)292 py::object FromJsonImpl(const json &j) {
293   if (j.is_null()) {
294     return py::none();
295   } else if (j.is_boolean()) {
296     return py::bool_(j.get<bool>());
297   } else if (j.is_number()) {
298     double number = j.get<double>();
299     if (fabs(number - std::floor(number)) < mindspore::mindrecord::kEpsilon) {
300       return py::int_(j.get<int64_t>());
301     } else {
302       return py::float_(number);
303     }
304   } else if (j.is_string()) {
305     return py::str(j.get<std::string>());
306   } else if (j.is_array()) {
307     py::list obj;
308     for (const auto &el : j) {
309       (void)obj.attr("append")(FromJsonImpl(el));
310     }
311     return std::move(obj);
312   } else {
313     py::dict obj;
314     for (json::const_iterator it = j.cbegin(); it != j.cend(); ++it) {
315       obj[py::str(it.key())] = FromJsonImpl(it.value());
316     }
317     return std::move(obj);
318   }
319 }
320 
ToJsonImpl(const py::handle & obj)321 json ToJsonImpl(const py::handle &obj) {
322   if (obj.is_none()) {
323     return nullptr;
324   }
325   if (py::isinstance<py::bool_>(obj)) {
326     return obj.cast<bool>();
327   }
328   if (py::isinstance<py::int_>(obj)) {
329     return obj.cast<int64_t>();
330   }
331   if (py::isinstance<py::float_>(obj)) {
332     return obj.cast<double>();
333   }
334   if (py::isinstance<py::str>(obj)) {
335     return obj.cast<std::string>();
336   }
337   if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
338     auto out = json::array();
339     for (const py::handle &value : obj) {
340       out.push_back(ToJsonImpl(value));
341     }
342     return out;
343   }
344   if (py::isinstance<py::dict>(obj)) {
345     auto out = json::object();
346     for (const py::handle &key : obj) {
347       out[py::str(key).cast<std::string>()] = ToJsonImpl(obj[key]);
348     }
349     return out;
350   }
351   MS_LOG(ERROR) << "Failed to convert Python object to json, object is: " << py::cast<std::string>(obj);
352   return json();
353 }
354 }  // namespace detail
355 
FromJson(const json & j)356 py::object adl_serializer<py::object>::FromJson(const json &j) { return detail::FromJsonImpl(j); }
357 
ToJson(json * j,const py::object & obj)358 void adl_serializer<py::object>::ToJson(json *j, const py::object &obj) {
359   *j = detail::ToJsonImpl(obj);
360 }  // namespace detail
361 }  // namespace nlohmann
362