1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iterator>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>
19
20 #include "utils/ms_utils.h"
21 #include "minddata/dataset/util/md_log_adapter.h"
22 #include "minddata/mindrecord/include/common/log_adapter.h"
23 #include "minddata/mindrecord/include/common/shard_utils.h"
24 #include "minddata/mindrecord/include/shard_error.h"
25 #include "minddata/mindrecord/include/shard_index_generator.h"
26 #include "minddata/mindrecord/include/shard_reader.h"
27 #include "minddata/mindrecord/include/shard_segment.h"
28 #include "minddata/mindrecord/include/shard_writer.h"
29 #include "nlohmann/json.hpp"
30 #include "pybind11/pybind11.h"
31 #include "pybind11/stl.h"
32
33 namespace py = pybind11;
34 using mindspore::dataset::MDLogAdapter;
35
36 namespace mindspore {
37 namespace mindrecord {
// Evaluates an expression yielding a Status. If that Status reports an error,
// throws std::runtime_error carrying the text produced by MDLogAdapter::Apply, which
// pybind11 translates into a Python RuntimeError. Otherwise does nothing.
#define THROW_IF_ERROR(status_expr)                                \
  do {                                                             \
    Status rc = std::move(status_expr);                            \
    if (!rc.IsError()) {                                           \
      break;                                                       \
    }                                                              \
    throw std::runtime_error(MDLogAdapter::Apply(&rc).ToString()); \
  } while (false)
43
BindSchema(py::module * m)44 void BindSchema(py::module *m) {
45 (void)py::class_<Schema, std::shared_ptr<Schema>>(*m, "Schema", py::module_local())
46 .def_static("build",
47 [](std::string desc, const pybind11::handle &schema) {
48 json schema_json = nlohmann::detail::ToJsonImpl(schema);
49 return Schema::Build(std::move(desc), schema_json);
50 })
51 .def("get_desc", &Schema::GetDesc)
52 .def("get_schema_content",
53 [](Schema &s) {
54 json schema_json = s.GetSchema();
55 return nlohmann::detail::FromJsonImpl(schema_json);
56 })
57 .def("get_blob_fields", &Schema::GetBlobFields)
58 .def("get_schema_id", &Schema::GetSchemaID);
59 }
60
BindStatistics(const py::module * m)61 void BindStatistics(const py::module *m) {
62 (void)py::class_<Statistics, std::shared_ptr<Statistics>>(*m, "Statistics", py::module_local())
63 .def_static("build",
64 [](std::string desc, const pybind11::handle &statistics) {
65 json statistics_json = nlohmann::detail::ToJsonImpl(statistics);
66 return Statistics::Build(std::move(desc), statistics_json);
67 })
68 .def("get_desc", &Statistics::GetDesc)
69 .def("get_statistics",
70 [](Statistics &s) {
71 json statistics_json = s.GetStatistics();
72 return nlohmann::detail::FromJsonImpl(statistics_json);
73 })
74 .def("get_statistics_id", &Statistics::GetStatisticsID);
75 }
76
BindShardHeader(const py::module * m)77 void BindShardHeader(const py::module *m) {
78 (void)py::class_<ShardHeader, std::shared_ptr<ShardHeader>>(*m, "ShardHeader", py::module_local())
79 .def(py::init<>())
80 .def("add_schema", &ShardHeader::AddSchema)
81 .def("add_statistics", &ShardHeader::AddStatistic)
82 .def("add_index_fields",
83 [](ShardHeader &s, const std::vector<std::string> &fields) {
84 THROW_IF_ERROR(s.AddIndexFields(fields));
85 return SUCCESS;
86 })
87 .def("get_meta", &ShardHeader::GetSchemas)
88 .def("get_statistics", &ShardHeader::GetStatistics)
89 .def("get_fields", &ShardHeader::GetFields)
90 .def("get_schema_by_id",
91 [](ShardHeader &s, int64_t schema_id) {
92 std::shared_ptr<Schema> schema_ptr;
93 THROW_IF_ERROR(s.GetSchemaByID(schema_id, &schema_ptr));
94 return schema_ptr;
95 })
96 .def("get_statistic_by_id", [](ShardHeader &s, int64_t statistic_id) {
97 std::shared_ptr<Statistics> statistics_ptr;
98 THROW_IF_ERROR(s.GetStatisticByID(statistic_id, &statistics_ptr));
99 return statistics_ptr;
100 });
101 }
102
BindShardWriter(py::module * m)103 void BindShardWriter(py::module *m) {
104 (void)py::class_<ShardWriter>(*m, "ShardWriter", py::module_local())
105 .def(py::init<>())
106 .def("open",
107 [](ShardWriter &s, const std::vector<std::string> &paths, bool append, bool overwrite) {
108 THROW_IF_ERROR(s.Open(paths, append, overwrite));
109 return SUCCESS;
110 })
111 .def("open_for_append",
112 [](ShardWriter &s, const std::string &path) {
113 THROW_IF_ERROR(s.OpenForAppend(path));
114 return SUCCESS;
115 })
116 .def("set_header_size",
117 [](ShardWriter &s, const uint64_t &header_size) {
118 THROW_IF_ERROR(s.SetHeaderSize(header_size));
119 return SUCCESS;
120 })
121 .def("set_page_size",
122 [](ShardWriter &s, const uint64_t &page_size) {
123 THROW_IF_ERROR(s.SetPageSize(page_size));
124 return SUCCESS;
125 })
126 .def("set_shard_header",
127 [](ShardWriter &s, std::shared_ptr<ShardHeader> header_data) {
128 THROW_IF_ERROR(s.SetShardHeader(header_data));
129 return SUCCESS;
130 })
131 .def("write_raw_data",
132 [](ShardWriter &s, std::map<uint64_t, std::vector<py::handle>> &raw_data, vector<py::bytes> &blob_data,
133 bool sign, bool parallel_writer) {
134 // convert the raw_data from dict to json
135 std::map<uint64_t, std::vector<json>> raw_data_json;
136 (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()),
137 [](const std::pair<uint64_t, std::vector<py::handle>> &p) {
138 auto &py_raw_data = p.second;
139 std::vector<json> json_raw_data;
140 (void)std::transform(
141 py_raw_data.begin(), py_raw_data.end(), std::back_inserter(json_raw_data),
142 [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); });
143 return std::make_pair(p.first, std::move(json_raw_data));
144 });
145
146 // parallel convert blob_data from vector<py::bytes> to vector<vector<uint8_t>>
147 int32_t parallel_convert = kParallelConvert;
148 if (parallel_convert > blob_data.size()) {
149 parallel_convert = blob_data.size();
150 }
151 parallel_convert = parallel_convert != 0 ? parallel_convert : 1;
152 std::vector<std::thread> thread_set(parallel_convert);
153 vector<vector<uint8_t>> vector_blob_data(blob_data.size());
154 uint32_t step = uint32_t(blob_data.size() / parallel_convert);
155 if (blob_data.size() % parallel_convert != 0) {
156 step = step + 1;
157 }
158 for (int x = 0; x < parallel_convert; ++x) {
159 uint32_t start = x * step;
160 uint32_t end = ((x + 1) * step) < blob_data.size() ? ((x + 1) * step) : blob_data.size();
161 thread_set[x] = std::thread([&vector_blob_data, &blob_data, start, end]() {
162 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
163 pthread_setname_np(
164 pthread_self(),
165 std::string("ParallelConvert" + std::to_string(start) + ":" + std::to_string(end)).c_str());
166 #endif
167 for (auto i = start; i < end; i++) {
168 char *buffer = nullptr;
169 ssize_t length = 0;
170 if (PYBIND11_BYTES_AS_STRING_AND_SIZE(blob_data[i].ptr(), &buffer, &length) != 0) {
171 MS_LOG(ERROR) << "Unable to extract bytes contents!";
172 return FAILED;
173 }
174 vector<uint8_t> blob_data_item(length);
175 if (length < SECUREC_MEM_MAX_LEN) {
176 int ret_code = memcpy_s(&blob_data_item[0], length, buffer, length);
177 if (ret_code != EOK) {
178 MS_LOG(ERROR) << "memcpy_s failed for py::bytes to vector<uint8_t>.";
179 return FAILED;
180 }
181 } else {
182 auto ret_code = std::memcpy(&blob_data_item[0], buffer, length);
183 if (ret_code != &blob_data_item[0]) {
184 MS_LOG(ERROR) << "memcpy failed for py::bytes to vector<uint8_t>.";
185 return FAILED;
186 }
187 }
188 vector_blob_data[i] = blob_data_item;
189 }
190 return SUCCESS;
191 });
192 }
193
194 // wait for the threads join
195 for (int x = 0; x < parallel_convert; ++x) {
196 thread_set[x].join();
197 }
198 THROW_IF_ERROR(s.WriteRawData(raw_data_json, vector_blob_data, sign, parallel_writer));
199 return SUCCESS;
200 })
201 .def("commit", [](ShardWriter &s) {
202 THROW_IF_ERROR(s.Commit());
203 return SUCCESS;
204 });
205 }
206
BindShardReader(const py::module * m)207 void BindShardReader(const py::module *m) {
208 (void)py::class_<ShardReader, std::shared_ptr<ShardReader>>(*m, "ShardReader", py::module_local())
209 .def(py::init<>())
210 .def("open",
211 [](ShardReader &s, const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer,
212 const std::vector<std::string> &selected_columns,
213 const std::vector<std::shared_ptr<ShardOperator>> &operators) {
214 THROW_IF_ERROR(s.Open(file_paths, load_dataset, n_consumer, selected_columns, operators));
215 return SUCCESS;
216 })
217 .def("launch",
218 [](ShardReader &s) {
219 THROW_IF_ERROR(s.Launch(false));
220 return SUCCESS;
221 })
222 .def("get_header", &ShardReader::GetShardHeader)
223 .def("get_blob_fields", &ShardReader::GetBlobFields)
224 .def("get_next",
225 [](ShardReader &s) {
226 auto data = s.GetNext();
227 std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, pybind11::object>> res;
228 std::transform(data.begin(), data.end(), std::back_inserter(res),
229 [&s](const std::tuple<std::map<std::string, std::vector<uint8_t>>, json> &item) {
230 auto &j = std::get<1>(item);
231 pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
232 return std::make_tuple(std::move(std::get<0>(item)), std::move(obj));
233 });
234 return res;
235 })
236 .def("close", &ShardReader::Close)
237 .def("len", &ShardReader::GetNumRows);
238 }
239
BindShardIndexGenerator(const py::module * m)240 void BindShardIndexGenerator(const py::module *m) {
241 (void)py::class_<ShardIndexGenerator>(*m, "ShardIndexGenerator", py::module_local())
242 .def(py::init<const std::string &, bool>())
243 .def("build",
244 [](ShardIndexGenerator &s) {
245 THROW_IF_ERROR(s.Build());
246 return SUCCESS;
247 })
248 .def("write_to_db", [](ShardIndexGenerator &s) {
249 THROW_IF_ERROR(s.WriteToDatabase());
250 return SUCCESS;
251 });
252 }
253
BindShardSegment(py::module * m)254 void BindShardSegment(py::module *m) {
255 (void)py::class_<ShardSegment>(*m, "ShardSegment", py::module_local())
256 .def(py::init<>())
257 .def("open",
258 [](ShardSegment &s, const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer,
259 const std::vector<std::string> &selected_columns,
260 const std::vector<std::shared_ptr<ShardOperator>> &operators) {
261 THROW_IF_ERROR(s.Open(file_paths, load_dataset, n_consumer, selected_columns, operators));
262 return SUCCESS;
263 })
264 .def("get_category_fields",
265 [](ShardSegment &s) {
266 auto fields_ptr = std::make_shared<vector<std::string>>();
267 THROW_IF_ERROR(s.GetCategoryFields(&fields_ptr));
268 return *fields_ptr;
269 })
270 .def("set_category_field",
271 [](ShardSegment &s, const std::string &category_field) {
272 THROW_IF_ERROR(s.SetCategoryField(category_field));
273 return SUCCESS;
274 })
275 .def("read_category_info",
276 [](ShardSegment &s) {
277 std::shared_ptr<std::string> category_ptr;
278 THROW_IF_ERROR(s.ReadCategoryInfo(&category_ptr));
279 return *category_ptr;
280 })
281 .def("read_at_page_by_id",
282 [](ShardSegment &s, int64_t category_id, int64_t page_no, int64_t n_rows_of_page) {
283 auto pages_load_ptr = std::make_shared<PAGES_LOAD_WITH_BLOBS>();
284 auto pages_ptr = std::make_shared<PAGES_WITH_BLOBS>();
285 THROW_IF_ERROR(s.ReadAllAtPageById(category_id, page_no, n_rows_of_page, &pages_ptr));
286 (void)std::transform(pages_ptr->begin(), pages_ptr->end(), std::back_inserter(*pages_load_ptr),
287 [](const std::tuple<std::map<std::string, std::vector<uint8_t>>, json> &item) {
288 auto &j = std::get<1>(item);
289 pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
290 return std::make_tuple(std::get<0>(item), std::move(obj));
291 });
292 return *pages_load_ptr;
293 })
294 .def("read_at_page_by_name",
295 [](ShardSegment &s, std::string category_name, int64_t page_no, int64_t n_rows_of_page) {
296 auto pages_load_ptr = std::make_shared<PAGES_LOAD_WITH_BLOBS>();
297 auto pages_ptr = std::make_shared<PAGES_WITH_BLOBS>();
298 THROW_IF_ERROR(s.ReadAllAtPageByName(category_name, page_no, n_rows_of_page, &pages_ptr));
299 (void)std::transform(pages_ptr->begin(), pages_ptr->end(), std::back_inserter(*pages_load_ptr),
300 [](const std::tuple<std::map<std::string, std::vector<uint8_t>>, json> &item) {
301 auto &j = std::get<1>(item);
302 pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
303 return std::make_tuple(std::get<0>(item), std::move(obj));
304 });
305 return *pages_load_ptr;
306 })
307 .def("get_header", &ShardSegment::GetShardHeader)
308 .def("get_blob_fields", [](ShardSegment &s) { return s.GetBlobFields(); });
309 }
310
BindGlobalParams(py::module * m)311 void BindGlobalParams(py::module *m) {
312 (*m).attr("MIN_HEADER_SIZE") = kMinHeaderSize;
313 (*m).attr("MAX_HEADER_SIZE") = kMaxHeaderSize;
314 (*m).attr("MIN_PAGE_SIZE") = kMinPageSize;
315 (*m).attr("MAX_PAGE_SIZE") = kMaxPageSize;
316 (*m).attr("MIN_SHARD_COUNT") = kMinShardCount;
317 (*m).attr("MAX_SHARD_COUNT") = kMaxShardCount;
318 (*m).attr("MIN_CONSUMER_COUNT") = kMinConsumerCount;
319 (*m).attr("MIN_FILE_SIZE") = kMinFileSize;
320 (void)(*m).def("get_max_thread_num", &GetMaxThreadNum);
321 }
322
// Entry point of the _c_mindrecord extension module: registers the status/shard-type
// enums, the global limits and every mindrecord class binding.
PYBIND11_MODULE(_c_mindrecord, m) {
  m.doc() = "pybind11 mindrecord plugin";  // optional module docstring

  py::enum_<MSRStatus> status_enum(m, "MSRStatus", py::module_local());
  (void)status_enum.value("SUCCESS", SUCCESS).value("FAILED", FAILED).export_values();

  py::enum_<ShardType> shard_type_enum(m, "ShardType", py::module_local());
  (void)shard_type_enum.value("NLP", kNLP).value("CV", kCV).export_values();

  // Registration order is kept as-is.
  BindGlobalParams(&m);
  BindSchema(&m);
  BindStatistics(&m);
  BindShardHeader(&m);
  BindShardWriter(&m);
  BindShardReader(&m);
  BindShardIndexGenerator(&m);
  BindShardSegment(&m);
}
339 } // namespace mindrecord
340 } // namespace mindspore
341
342 namespace nlohmann {
343 namespace detail {
FromJsonImpl(const json & j)344 py::object FromJsonImpl(const json &j) {
345 if (j.is_null()) {
346 return py::none();
347 } else if (j.is_boolean()) {
348 return py::bool_(j.get<bool>());
349 } else if (j.is_number()) {
350 double number = j.get<double>();
351 if (fabs(number - std::floor(number)) < mindspore::mindrecord::kEpsilon) {
352 return py::int_(j.get<int64_t>());
353 } else {
354 return py::float_(number);
355 }
356 } else if (j.is_string()) {
357 return py::str(j.get<std::string>());
358 } else if (j.is_array()) {
359 py::list obj;
360 for (const auto &el : j) {
361 (void)obj.attr("append")(FromJsonImpl(el));
362 }
363 return std::move(obj);
364 } else {
365 py::dict obj;
366 for (json::const_iterator it = j.cbegin(); it != j.cend(); ++it) {
367 obj[py::str(it.key())] = FromJsonImpl(it.value());
368 }
369 return std::move(obj);
370 }
371 }
372
ToJsonImpl(const py::handle & obj)373 json ToJsonImpl(const py::handle &obj) {
374 if (obj.is_none()) {
375 return nullptr;
376 }
377 if (py::isinstance<py::bool_>(obj)) {
378 return obj.cast<bool>();
379 }
380 if (py::isinstance<py::int_>(obj)) {
381 return obj.cast<int64_t>();
382 }
383 if (py::isinstance<py::float_>(obj)) {
384 return obj.cast<double>();
385 }
386 if (py::isinstance<py::str>(obj)) {
387 return obj.cast<std::string>();
388 }
389 if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
390 auto out = json::array();
391 for (const py::handle &value : obj) {
392 out.push_back(ToJsonImpl(value));
393 }
394 return out;
395 }
396 if (py::isinstance<py::dict>(obj)) {
397 auto out = json::object();
398 for (const py::handle &key : obj) {
399 out[py::str(key).cast<std::string>()] = ToJsonImpl(obj[key]);
400 }
401 return out;
402 }
403 MS_LOG(ERROR) << "[Internal ERROR] Failed to convert Python object: " << py::cast<std::string>(obj)
404 << " to type json.";
405 return json();
406 }
407 } // namespace detail
408
FromJson(const json & j)409 py::object adl_serializer<py::object>::FromJson(const json &j) { return detail::FromJsonImpl(j); }
410
ToJson(json * j,const py::object & obj)411 void adl_serializer<py::object>::ToJson(json *j, const py::object &obj) {
412 *j = detail::ToJsonImpl(obj);
413 } // namespace detail
414 } // namespace nlohmann
415