1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <cmath>
#include <limits>
#include <string>
#include <vector>
#include "utils/ms_utils.h"
#include "minddata/mindrecord/include/common/shard_utils.h"
#include "minddata/mindrecord/include/shard_error.h"
#include "minddata/mindrecord/include/shard_index_generator.h"
#include "minddata/mindrecord/include/shard_reader.h"
#include "minddata/mindrecord/include/shard_segment.h"
#include "minddata/mindrecord/include/shard_writer.h"
#include "nlohmann/json.hpp"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "utils/log_adapter.h"
30
31 namespace py = pybind11;
32
33 using mindspore::LogStream;
34 using mindspore::ExceptionType::NoExceptionType;
35 using mindspore::MsLogLevel::ERROR;
36
37 namespace mindspore {
38 namespace mindrecord {
// Evaluates the Status expression `s` once. If the resulting Status reports an error,
// throws std::runtime_error carrying Status::ToString(); otherwise does nothing.
// Wrapped in do/while(false) so the macro behaves like a single statement.
#define THROW_IF_ERROR(s)                                              \
  do {                                                                 \
    Status throw_if_error_status_ = std::move((s));                    \
    if (throw_if_error_status_.IsError()) {                            \
      throw std::runtime_error(throw_if_error_status_.ToString());     \
    }                                                                  \
  } while (false)
44
BindSchema(py::module * m)45 void BindSchema(py::module *m) {
46 (void)py::class_<Schema, std::shared_ptr<Schema>>(*m, "Schema", py::module_local())
47 .def_static("build",
48 [](const std::string &desc, const pybind11::handle &schema) {
49 json schema_json = nlohmann::detail::ToJsonImpl(schema);
50 return Schema::Build(std::move(desc), schema_json);
51 })
52 .def("get_desc", &Schema::GetDesc)
53 .def("get_schema_content",
54 [](Schema &s) {
55 json schema_json = s.GetSchema();
56 return nlohmann::detail::FromJsonImpl(schema_json);
57 })
58 .def("get_blob_fields", &Schema::GetBlobFields)
59 .def("get_schema_id", &Schema::GetSchemaID);
60 }
61
BindStatistics(const py::module * m)62 void BindStatistics(const py::module *m) {
63 (void)py::class_<Statistics, std::shared_ptr<Statistics>>(*m, "Statistics", py::module_local())
64 .def_static("build",
65 [](const std::string desc, const pybind11::handle &statistics) {
66 json statistics_json = nlohmann::detail::ToJsonImpl(statistics);
67 return Statistics::Build(std::move(desc), statistics_json);
68 })
69 .def("get_desc", &Statistics::GetDesc)
70 .def("get_statistics",
71 [](Statistics &s) {
72 json statistics_json = s.GetStatistics();
73 return nlohmann::detail::FromJsonImpl(statistics_json);
74 })
75 .def("get_statistics_id", &Statistics::GetStatisticsID);
76 }
77
BindShardHeader(const py::module * m)78 void BindShardHeader(const py::module *m) {
79 (void)py::class_<ShardHeader, std::shared_ptr<ShardHeader>>(*m, "ShardHeader", py::module_local())
80 .def(py::init<>())
81 .def("add_schema", &ShardHeader::AddSchema)
82 .def("add_statistics", &ShardHeader::AddStatistic)
83 .def("add_index_fields",
84 [](ShardHeader &s, const std::vector<std::string> &fields) {
85 THROW_IF_ERROR(s.AddIndexFields(fields));
86 return SUCCESS;
87 })
88 .def("get_meta", &ShardHeader::GetSchemas)
89 .def("get_statistics", &ShardHeader::GetStatistics)
90 .def("get_fields", &ShardHeader::GetFields)
91 .def("get_schema_by_id",
92 [](ShardHeader &s, int64_t schema_id) {
93 std::shared_ptr<Schema> schema_ptr;
94 THROW_IF_ERROR(s.GetSchemaByID(schema_id, &schema_ptr));
95 return schema_ptr;
96 })
97 .def("get_statistic_by_id", [](ShardHeader &s, int64_t statistic_id) {
98 std::shared_ptr<Statistics> statistics_ptr;
99 THROW_IF_ERROR(s.GetStatisticByID(statistic_id, &statistics_ptr));
100 return statistics_ptr;
101 });
102 }
103
BindShardWriter(py::module * m)104 void BindShardWriter(py::module *m) {
105 (void)py::class_<ShardWriter>(*m, "ShardWriter", py::module_local())
106 .def(py::init<>())
107 .def("open",
108 [](ShardWriter &s, const std::vector<std::string> &paths, bool append) {
109 THROW_IF_ERROR(s.Open(paths, append));
110 return SUCCESS;
111 })
112 .def("open_for_append",
113 [](ShardWriter &s, const std::string &path) {
114 THROW_IF_ERROR(s.OpenForAppend(path));
115 return SUCCESS;
116 })
117 .def("set_header_size",
118 [](ShardWriter &s, const uint64_t &header_size) {
119 THROW_IF_ERROR(s.SetHeaderSize(header_size));
120 return SUCCESS;
121 })
122 .def("set_page_size",
123 [](ShardWriter &s, const uint64_t &page_size) {
124 THROW_IF_ERROR(s.SetPageSize(page_size));
125 return SUCCESS;
126 })
127 .def("set_shard_header",
128 [](ShardWriter &s, std::shared_ptr<ShardHeader> header_data) {
129 THROW_IF_ERROR(s.SetShardHeader(header_data));
130 return SUCCESS;
131 })
132 .def("write_raw_data",
133 [](ShardWriter &s, std::map<uint64_t, std::vector<py::handle>> &raw_data, vector<vector<uint8_t>> &blob_data,
134 bool sign, bool parallel_writer) {
135 std::map<uint64_t, std::vector<json>> raw_data_json;
136 (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()),
137 [](const std::pair<uint64_t, std::vector<py::handle>> &p) {
138 auto &py_raw_data = p.second;
139 std::vector<json> json_raw_data;
140 (void)std::transform(
141 py_raw_data.begin(), py_raw_data.end(), std::back_inserter(json_raw_data),
142 [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); });
143 return std::make_pair(p.first, std::move(json_raw_data));
144 });
145 THROW_IF_ERROR(s.WriteRawData(raw_data_json, blob_data, sign, parallel_writer));
146 return SUCCESS;
147 })
148 .def("commit", [](ShardWriter &s) {
149 THROW_IF_ERROR(s.Commit());
150 return SUCCESS;
151 });
152 }
153
BindShardReader(const py::module * m)154 void BindShardReader(const py::module *m) {
155 (void)py::class_<ShardReader, std::shared_ptr<ShardReader>>(*m, "ShardReader", py::module_local())
156 .def(py::init<>())
157 .def("open",
158 [](ShardReader &s, const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer,
159 const std::vector<std::string> &selected_columns,
160 const std::vector<std::shared_ptr<ShardOperator>> &operators) {
161 THROW_IF_ERROR(s.Open(file_paths, load_dataset, n_consumer, selected_columns, operators));
162 return SUCCESS;
163 })
164 .def("launch",
165 [](ShardReader &s) {
166 THROW_IF_ERROR(s.Launch(false));
167 return SUCCESS;
168 })
169 .def("get_header", &ShardReader::GetShardHeader)
170 .def("get_blob_fields", &ShardReader::GetBlobFields)
171 .def("get_next",
172 [](ShardReader &s) {
173 auto data = s.GetNext();
174 vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> res;
175 std::transform(data.begin(), data.end(), std::back_inserter(res),
176 [&s](const std::tuple<std::vector<uint8_t>, json> &item) {
177 auto &j = std::get<1>(item);
178 pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
179 auto blob_data_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>();
180 (void)s.UnCompressBlob(std::get<0>(item), &blob_data_ptr);
181 return std::make_tuple(*blob_data_ptr, std::move(obj));
182 });
183 return res;
184 })
185 .def("close", &ShardReader::Close);
186 }
187
BindShardIndexGenerator(const py::module * m)188 void BindShardIndexGenerator(const py::module *m) {
189 (void)py::class_<ShardIndexGenerator>(*m, "ShardIndexGenerator", py::module_local())
190 .def(py::init<const std::string &, bool>())
191 .def("build",
192 [](ShardIndexGenerator &s) {
193 THROW_IF_ERROR(s.Build());
194 return SUCCESS;
195 })
196 .def("write_to_db", [](ShardIndexGenerator &s) {
197 THROW_IF_ERROR(s.WriteToDatabase());
198 return SUCCESS;
199 });
200 }
201
BindShardSegment(py::module * m)202 void BindShardSegment(py::module *m) {
203 (void)py::class_<ShardSegment>(*m, "ShardSegment", py::module_local())
204 .def(py::init<>())
205 .def("open",
206 [](ShardSegment &s, const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer,
207 const std::vector<std::string> &selected_columns,
208 const std::vector<std::shared_ptr<ShardOperator>> &operators) {
209 THROW_IF_ERROR(s.Open(file_paths, load_dataset, n_consumer, selected_columns, operators));
210 return SUCCESS;
211 })
212 .def("get_category_fields",
213 [](ShardSegment &s) {
214 auto fields_ptr = std::make_shared<vector<std::string>>();
215 THROW_IF_ERROR(s.GetCategoryFields(&fields_ptr));
216 return *fields_ptr;
217 })
218 .def("set_category_field",
219 [](ShardSegment &s, const std::string &category_field) {
220 THROW_IF_ERROR(s.SetCategoryField(category_field));
221 return SUCCESS;
222 })
223 .def("read_category_info",
224 [](ShardSegment &s) {
225 std::shared_ptr<std::string> category_ptr;
226 THROW_IF_ERROR(s.ReadCategoryInfo(&category_ptr));
227 return *category_ptr;
228 })
229 .def("read_at_page_by_id",
230 [](ShardSegment &s, int64_t category_id, int64_t page_no, int64_t n_rows_of_page) {
231 auto pages_load_ptr = std::make_shared<PAGES_LOAD>();
232 auto pages_ptr = std::make_shared<PAGES>();
233 THROW_IF_ERROR(s.ReadAllAtPageById(category_id, page_no, n_rows_of_page, &pages_ptr));
234 (void)std::transform(pages_ptr->begin(), pages_ptr->end(), std::back_inserter(*pages_load_ptr),
235 [](const std::tuple<std::vector<uint8_t>, json> &item) {
236 auto &j = std::get<1>(item);
237 pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
238 return std::make_tuple(std::get<0>(item), std::move(obj));
239 });
240 return *pages_load_ptr;
241 })
242 .def("read_at_page_by_name",
243 [](ShardSegment &s, std::string category_name, int64_t page_no, int64_t n_rows_of_page) {
244 auto pages_load_ptr = std::make_shared<PAGES_LOAD>();
245 auto pages_ptr = std::make_shared<PAGES>();
246 THROW_IF_ERROR(s.ReadAllAtPageByName(category_name, page_no, n_rows_of_page, &pages_ptr));
247 (void)std::transform(pages_ptr->begin(), pages_ptr->end(), std::back_inserter(*pages_load_ptr),
248 [](const std::tuple<std::vector<uint8_t>, json> &item) {
249 auto &j = std::get<1>(item);
250 pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
251 return std::make_tuple(std::get<0>(item), std::move(obj));
252 });
253 return *pages_load_ptr;
254 })
255 .def("get_header", &ShardSegment::GetShardHeader)
256 .def("get_blob_fields", [](ShardSegment &s) { return s.GetBlobFields(); });
257 }
258
BindGlobalParams(py::module * m)259 void BindGlobalParams(py::module *m) {
260 (*m).attr("MIN_HEADER_SIZE") = kMinHeaderSize;
261 (*m).attr("MAX_HEADER_SIZE") = kMaxHeaderSize;
262 (*m).attr("MIN_PAGE_SIZE") = kMinPageSize;
263 (*m).attr("MAX_PAGE_SIZE") = kMaxPageSize;
264 (*m).attr("MIN_SHARD_COUNT") = kMinShardCount;
265 (*m).attr("MAX_SHARD_COUNT") = kMaxShardCount;
266 (*m).attr("MAX_FILE_COUNT") = kMaxFileCount;
267 (*m).attr("MIN_CONSUMER_COUNT") = kMinConsumerCount;
268 (void)(*m).def("get_max_thread_num", &GetMaxThreadNum);
269 }
270
// Entry point of the `_c_mindrecord` extension module: sets the module docstring,
// registers the MSRStatus and ShardType enums (values exported into the module
// namespace), then registers the global parameters and every bound class.
PYBIND11_MODULE(_c_mindrecord, m) {
  m.doc() = "pybind11 mindrecord plugin";  // optional module docstring

  py::enum_<MSRStatus> status_enum(m, "MSRStatus", py::module_local());
  (void)status_enum.value("SUCCESS", SUCCESS).value("FAILED", FAILED).export_values();

  py::enum_<ShardType> shard_type_enum(m, "ShardType", py::module_local());
  (void)shard_type_enum.value("NLP", kNLP).value("CV", kCV).export_values();

  // Registration order is kept exactly as before.
  BindGlobalParams(&m);
  BindSchema(&m);
  BindStatistics(&m);
  BindShardHeader(&m);
  BindShardWriter(&m);
  BindShardReader(&m);
  BindShardIndexGenerator(&m);
  BindShardSegment(&m);
}
287 } // namespace mindrecord
288 } // namespace mindspore
289
290 namespace nlohmann {
291 namespace detail {
FromJsonImpl(const json & j)292 py::object FromJsonImpl(const json &j) {
293 if (j.is_null()) {
294 return py::none();
295 } else if (j.is_boolean()) {
296 return py::bool_(j.get<bool>());
297 } else if (j.is_number()) {
298 double number = j.get<double>();
299 if (fabs(number - std::floor(number)) < mindspore::mindrecord::kEpsilon) {
300 return py::int_(j.get<int64_t>());
301 } else {
302 return py::float_(number);
303 }
304 } else if (j.is_string()) {
305 return py::str(j.get<std::string>());
306 } else if (j.is_array()) {
307 py::list obj;
308 for (const auto &el : j) {
309 (void)obj.attr("append")(FromJsonImpl(el));
310 }
311 return std::move(obj);
312 } else {
313 py::dict obj;
314 for (json::const_iterator it = j.cbegin(); it != j.cend(); ++it) {
315 obj[py::str(it.key())] = FromJsonImpl(it.value());
316 }
317 return std::move(obj);
318 }
319 }
320
ToJsonImpl(const py::handle & obj)321 json ToJsonImpl(const py::handle &obj) {
322 if (obj.is_none()) {
323 return nullptr;
324 }
325 if (py::isinstance<py::bool_>(obj)) {
326 return obj.cast<bool>();
327 }
328 if (py::isinstance<py::int_>(obj)) {
329 return obj.cast<int64_t>();
330 }
331 if (py::isinstance<py::float_>(obj)) {
332 return obj.cast<double>();
333 }
334 if (py::isinstance<py::str>(obj)) {
335 return obj.cast<std::string>();
336 }
337 if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
338 auto out = json::array();
339 for (const py::handle &value : obj) {
340 out.push_back(ToJsonImpl(value));
341 }
342 return out;
343 }
344 if (py::isinstance<py::dict>(obj)) {
345 auto out = json::object();
346 for (const py::handle &key : obj) {
347 out[py::str(key).cast<std::string>()] = ToJsonImpl(obj[key]);
348 }
349 return out;
350 }
351 MS_LOG(ERROR) << "Failed to convert Python object to json, object is: " << py::cast<std::string>(obj);
352 return json();
353 }
354 } // namespace detail
355
FromJson(const json & j)356 py::object adl_serializer<py::object>::FromJson(const json &j) { return detail::FromJsonImpl(j); }
357
ToJson(json * j,const py::object & obj)358 void adl_serializer<py::object>::ToJson(json *j, const py::object &obj) {
359 *j = detail::ToJsonImpl(obj);
360 } // namespace detail
361 } // namespace nlohmann
362