• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <string>
18 #include <vector>
19 
20 #include "utils/ms_utils.h"
21 #include "minddata/dataset/util/md_log_adapter.h"
22 #include "minddata/mindrecord/include/common/log_adapter.h"
23 #include "minddata/mindrecord/include/common/shard_utils.h"
24 #include "minddata/mindrecord/include/shard_error.h"
25 #include "minddata/mindrecord/include/shard_index_generator.h"
26 #include "minddata/mindrecord/include/shard_reader.h"
27 #include "minddata/mindrecord/include/shard_segment.h"
28 #include "minddata/mindrecord/include/shard_writer.h"
29 #include "nlohmann/json.hpp"
30 #include "pybind11/pybind11.h"
31 #include "pybind11/stl.h"
32 
33 namespace py = pybind11;
34 using mindspore::dataset::MDLogAdapter;
35 
36 namespace mindspore {
37 namespace mindrecord {
// Evaluates the Status-valued expression `s` exactly once. When the resulting status reports an error, throws
// std::runtime_error whose message is the status text after it has been passed through MDLogAdapter::Apply.
#define THROW_IF_ERROR(s)                                            \
  do {                                                               \
    Status rc = std::move(s);                                        \
    if (rc.IsError()) {                                              \
      throw std::runtime_error(MDLogAdapter::Apply(&rc).ToString()); \
    }                                                                \
  } while (false)
43 
BindSchema(py::module * m)44 void BindSchema(py::module *m) {
45   (void)py::class_<Schema, std::shared_ptr<Schema>>(*m, "Schema", py::module_local())
46     .def_static("build",
47                 [](std::string desc, const pybind11::handle &schema) {
48                   json schema_json = nlohmann::detail::ToJsonImpl(schema);
49                   return Schema::Build(std::move(desc), schema_json);
50                 })
51     .def("get_desc", &Schema::GetDesc)
52     .def("get_schema_content",
53          [](Schema &s) {
54            json schema_json = s.GetSchema();
55            return nlohmann::detail::FromJsonImpl(schema_json);
56          })
57     .def("get_blob_fields", &Schema::GetBlobFields)
58     .def("get_schema_id", &Schema::GetSchemaID);
59 }
60 
BindStatistics(const py::module * m)61 void BindStatistics(const py::module *m) {
62   (void)py::class_<Statistics, std::shared_ptr<Statistics>>(*m, "Statistics", py::module_local())
63     .def_static("build",
64                 [](std::string desc, const pybind11::handle &statistics) {
65                   json statistics_json = nlohmann::detail::ToJsonImpl(statistics);
66                   return Statistics::Build(std::move(desc), statistics_json);
67                 })
68     .def("get_desc", &Statistics::GetDesc)
69     .def("get_statistics",
70          [](Statistics &s) {
71            json statistics_json = s.GetStatistics();
72            return nlohmann::detail::FromJsonImpl(statistics_json);
73          })
74     .def("get_statistics_id", &Statistics::GetStatisticsID);
75 }
76 
BindShardHeader(const py::module * m)77 void BindShardHeader(const py::module *m) {
78   (void)py::class_<ShardHeader, std::shared_ptr<ShardHeader>>(*m, "ShardHeader", py::module_local())
79     .def(py::init<>())
80     .def("add_schema", &ShardHeader::AddSchema)
81     .def("add_statistics", &ShardHeader::AddStatistic)
82     .def("add_index_fields",
83          [](ShardHeader &s, const std::vector<std::string> &fields) {
84            THROW_IF_ERROR(s.AddIndexFields(fields));
85            return SUCCESS;
86          })
87     .def("get_meta", &ShardHeader::GetSchemas)
88     .def("get_statistics", &ShardHeader::GetStatistics)
89     .def("get_fields", &ShardHeader::GetFields)
90     .def("get_schema_by_id",
91          [](ShardHeader &s, int64_t schema_id) {
92            std::shared_ptr<Schema> schema_ptr;
93            THROW_IF_ERROR(s.GetSchemaByID(schema_id, &schema_ptr));
94            return schema_ptr;
95          })
96     .def("get_statistic_by_id", [](ShardHeader &s, int64_t statistic_id) {
97       std::shared_ptr<Statistics> statistics_ptr;
98       THROW_IF_ERROR(s.GetStatisticByID(statistic_id, &statistics_ptr));
99       return statistics_ptr;
100     });
101 }
102 
BindShardWriter(py::module * m)103 void BindShardWriter(py::module *m) {
104   (void)py::class_<ShardWriter>(*m, "ShardWriter", py::module_local())
105     .def(py::init<>())
106     .def("open",
107          [](ShardWriter &s, const std::vector<std::string> &paths, bool append, bool overwrite) {
108            THROW_IF_ERROR(s.Open(paths, append, overwrite));
109            return SUCCESS;
110          })
111     .def("open_for_append",
112          [](ShardWriter &s, const std::string &path) {
113            THROW_IF_ERROR(s.OpenForAppend(path));
114            return SUCCESS;
115          })
116     .def("set_header_size",
117          [](ShardWriter &s, const uint64_t &header_size) {
118            THROW_IF_ERROR(s.SetHeaderSize(header_size));
119            return SUCCESS;
120          })
121     .def("set_page_size",
122          [](ShardWriter &s, const uint64_t &page_size) {
123            THROW_IF_ERROR(s.SetPageSize(page_size));
124            return SUCCESS;
125          })
126     .def("set_shard_header",
127          [](ShardWriter &s, std::shared_ptr<ShardHeader> header_data) {
128            THROW_IF_ERROR(s.SetShardHeader(header_data));
129            return SUCCESS;
130          })
131     .def("write_raw_data",
132          [](ShardWriter &s, std::map<uint64_t, std::vector<py::handle>> &raw_data, vector<py::bytes> &blob_data,
133             bool sign, bool parallel_writer) {
134            // convert the raw_data from dict to json
135            std::map<uint64_t, std::vector<json>> raw_data_json;
136            (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()),
137                                 [](const std::pair<uint64_t, std::vector<py::handle>> &p) {
138                                   auto &py_raw_data = p.second;
139                                   std::vector<json> json_raw_data;
140                                   (void)std::transform(
141                                     py_raw_data.begin(), py_raw_data.end(), std::back_inserter(json_raw_data),
142                                     [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); });
143                                   return std::make_pair(p.first, std::move(json_raw_data));
144                                 });
145 
146            // parallel convert blob_data from vector<py::bytes> to vector<vector<uint8_t>>
147            int32_t parallel_convert = kParallelConvert;
148            if (parallel_convert > blob_data.size()) {
149              parallel_convert = blob_data.size();
150            }
151            parallel_convert = parallel_convert != 0 ? parallel_convert : 1;
152            std::vector<std::thread> thread_set(parallel_convert);
153            vector<vector<uint8_t>> vector_blob_data(blob_data.size());
154            uint32_t step = uint32_t(blob_data.size() / parallel_convert);
155            if (blob_data.size() % parallel_convert != 0) {
156              step = step + 1;
157            }
158            for (int x = 0; x < parallel_convert; ++x) {
159              uint32_t start = x * step;
160              uint32_t end = ((x + 1) * step) < blob_data.size() ? ((x + 1) * step) : blob_data.size();
161              thread_set[x] = std::thread([&vector_blob_data, &blob_data, start, end]() {
162 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
163                pthread_setname_np(
164                  pthread_self(),
165                  std::string("ParallelConvert" + std::to_string(start) + ":" + std::to_string(end)).c_str());
166 #endif
167                for (auto i = start; i < end; i++) {
168                  char *buffer = nullptr;
169                  ssize_t length = 0;
170                  if (PYBIND11_BYTES_AS_STRING_AND_SIZE(blob_data[i].ptr(), &buffer, &length) != 0) {
171                    MS_LOG(ERROR) << "Unable to extract bytes contents!";
172                    return FAILED;
173                  }
174                  vector<uint8_t> blob_data_item(length);
175                  if (length < SECUREC_MEM_MAX_LEN) {
176                    int ret_code = memcpy_s(&blob_data_item[0], length, buffer, length);
177                    if (ret_code != EOK) {
178                      MS_LOG(ERROR) << "memcpy_s failed for py::bytes to vector<uint8_t>.";
179                      return FAILED;
180                    }
181                  } else {
182                    auto ret_code = std::memcpy(&blob_data_item[0], buffer, length);
183                    if (ret_code != &blob_data_item[0]) {
184                      MS_LOG(ERROR) << "memcpy failed for py::bytes to vector<uint8_t>.";
185                      return FAILED;
186                    }
187                  }
188                  vector_blob_data[i] = blob_data_item;
189                }
190                return SUCCESS;
191              });
192            }
193 
194            // wait for the threads join
195            for (int x = 0; x < parallel_convert; ++x) {
196              thread_set[x].join();
197            }
198            THROW_IF_ERROR(s.WriteRawData(raw_data_json, vector_blob_data, sign, parallel_writer));
199            return SUCCESS;
200          })
201     .def("commit", [](ShardWriter &s) {
202       THROW_IF_ERROR(s.Commit());
203       return SUCCESS;
204     });
205 }
206 
BindShardReader(const py::module * m)207 void BindShardReader(const py::module *m) {
208   (void)py::class_<ShardReader, std::shared_ptr<ShardReader>>(*m, "ShardReader", py::module_local())
209     .def(py::init<>())
210     .def("open",
211          [](ShardReader &s, const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer,
212             const std::vector<std::string> &selected_columns,
213             const std::vector<std::shared_ptr<ShardOperator>> &operators) {
214            THROW_IF_ERROR(s.Open(file_paths, load_dataset, n_consumer, selected_columns, operators));
215            return SUCCESS;
216          })
217     .def("launch",
218          [](ShardReader &s) {
219            THROW_IF_ERROR(s.Launch(false));
220            return SUCCESS;
221          })
222     .def("get_header", &ShardReader::GetShardHeader)
223     .def("get_blob_fields", &ShardReader::GetBlobFields)
224     .def("get_next",
225          [](ShardReader &s) {
226            auto data = s.GetNext();
227            std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, pybind11::object>> res;
228            std::transform(data.begin(), data.end(), std::back_inserter(res),
229                           [&s](const std::tuple<std::map<std::string, std::vector<uint8_t>>, json> &item) {
230                             auto &j = std::get<1>(item);
231                             pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
232                             return std::make_tuple(std::move(std::get<0>(item)), std::move(obj));
233                           });
234            return res;
235          })
236     .def("close", &ShardReader::Close)
237     .def("len", &ShardReader::GetNumRows);
238 }
239 
BindShardIndexGenerator(const py::module * m)240 void BindShardIndexGenerator(const py::module *m) {
241   (void)py::class_<ShardIndexGenerator>(*m, "ShardIndexGenerator", py::module_local())
242     .def(py::init<const std::string &, bool>())
243     .def("build",
244          [](ShardIndexGenerator &s) {
245            THROW_IF_ERROR(s.Build());
246            return SUCCESS;
247          })
248     .def("write_to_db", [](ShardIndexGenerator &s) {
249       THROW_IF_ERROR(s.WriteToDatabase());
250       return SUCCESS;
251     });
252 }
253 
BindShardSegment(py::module * m)254 void BindShardSegment(py::module *m) {
255   (void)py::class_<ShardSegment>(*m, "ShardSegment", py::module_local())
256     .def(py::init<>())
257     .def("open",
258          [](ShardSegment &s, const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer,
259             const std::vector<std::string> &selected_columns,
260             const std::vector<std::shared_ptr<ShardOperator>> &operators) {
261            THROW_IF_ERROR(s.Open(file_paths, load_dataset, n_consumer, selected_columns, operators));
262            return SUCCESS;
263          })
264     .def("get_category_fields",
265          [](ShardSegment &s) {
266            auto fields_ptr = std::make_shared<vector<std::string>>();
267            THROW_IF_ERROR(s.GetCategoryFields(&fields_ptr));
268            return *fields_ptr;
269          })
270     .def("set_category_field",
271          [](ShardSegment &s, const std::string &category_field) {
272            THROW_IF_ERROR(s.SetCategoryField(category_field));
273            return SUCCESS;
274          })
275     .def("read_category_info",
276          [](ShardSegment &s) {
277            std::shared_ptr<std::string> category_ptr;
278            THROW_IF_ERROR(s.ReadCategoryInfo(&category_ptr));
279            return *category_ptr;
280          })
281     .def("read_at_page_by_id",
282          [](ShardSegment &s, int64_t category_id, int64_t page_no, int64_t n_rows_of_page) {
283            auto pages_load_ptr = std::make_shared<PAGES_LOAD_WITH_BLOBS>();
284            auto pages_ptr = std::make_shared<PAGES_WITH_BLOBS>();
285            THROW_IF_ERROR(s.ReadAllAtPageById(category_id, page_no, n_rows_of_page, &pages_ptr));
286            (void)std::transform(pages_ptr->begin(), pages_ptr->end(), std::back_inserter(*pages_load_ptr),
287                                 [](const std::tuple<std::map<std::string, std::vector<uint8_t>>, json> &item) {
288                                   auto &j = std::get<1>(item);
289                                   pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
290                                   return std::make_tuple(std::get<0>(item), std::move(obj));
291                                 });
292            return *pages_load_ptr;
293          })
294     .def("read_at_page_by_name",
295          [](ShardSegment &s, std::string category_name, int64_t page_no, int64_t n_rows_of_page) {
296            auto pages_load_ptr = std::make_shared<PAGES_LOAD_WITH_BLOBS>();
297            auto pages_ptr = std::make_shared<PAGES_WITH_BLOBS>();
298            THROW_IF_ERROR(s.ReadAllAtPageByName(category_name, page_no, n_rows_of_page, &pages_ptr));
299            (void)std::transform(pages_ptr->begin(), pages_ptr->end(), std::back_inserter(*pages_load_ptr),
300                                 [](const std::tuple<std::map<std::string, std::vector<uint8_t>>, json> &item) {
301                                   auto &j = std::get<1>(item);
302                                   pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
303                                   return std::make_tuple(std::get<0>(item), std::move(obj));
304                                 });
305            return *pages_load_ptr;
306          })
307     .def("get_header", &ShardSegment::GetShardHeader)
308     .def("get_blob_fields", [](ShardSegment &s) { return s.GetBlobFields(); });
309 }
310 
BindGlobalParams(py::module * m)311 void BindGlobalParams(py::module *m) {
312   (*m).attr("MIN_HEADER_SIZE") = kMinHeaderSize;
313   (*m).attr("MAX_HEADER_SIZE") = kMaxHeaderSize;
314   (*m).attr("MIN_PAGE_SIZE") = kMinPageSize;
315   (*m).attr("MAX_PAGE_SIZE") = kMaxPageSize;
316   (*m).attr("MIN_SHARD_COUNT") = kMinShardCount;
317   (*m).attr("MAX_SHARD_COUNT") = kMaxShardCount;
318   (*m).attr("MIN_CONSUMER_COUNT") = kMinConsumerCount;
319   (*m).attr("MIN_FILE_SIZE") = kMinFileSize;
320   (void)(*m).def("get_max_thread_num", &GetMaxThreadNum);
321 }
322 
// Entry point of the Python extension module "_c_mindrecord": registers the MSRStatus and ShardType enums and
// then every binding group defined in this file.
PYBIND11_MODULE(_c_mindrecord, m) {
  m.doc() = "pybind11 mindrecord plugin";  // optional module docstring

  // Status enum; export_values() also publishes SUCCESS / FAILED at module scope.
  py::enum_<MSRStatus> status_enum(m, "MSRStatus", py::module_local());
  (void)status_enum.value("SUCCESS", SUCCESS);
  (void)status_enum.value("FAILED", FAILED);
  (void)status_enum.export_values();

  // Shard type enum; export_values() also publishes NLP / CV at module scope.
  py::enum_<ShardType> shard_type_enum(m, "ShardType", py::module_local());
  (void)shard_type_enum.value("NLP", kNLP);
  (void)shard_type_enum.value("CV", kCV);
  (void)shard_type_enum.export_values();

  // Registration order is kept exactly as in the original.
  BindGlobalParams(&m);
  BindSchema(&m);
  BindStatistics(&m);
  BindShardHeader(&m);
  BindShardWriter(&m);
  BindShardReader(&m);
  BindShardIndexGenerator(&m);
  BindShardSegment(&m);
}
339 }  // namespace mindrecord
340 }  // namespace mindspore
341 
342 namespace nlohmann {
343 namespace detail {
FromJsonImpl(const json & j)344 py::object FromJsonImpl(const json &j) {
345   if (j.is_null()) {
346     return py::none();
347   } else if (j.is_boolean()) {
348     return py::bool_(j.get<bool>());
349   } else if (j.is_number()) {
350     double number = j.get<double>();
351     if (fabs(number - std::floor(number)) < mindspore::mindrecord::kEpsilon) {
352       return py::int_(j.get<int64_t>());
353     } else {
354       return py::float_(number);
355     }
356   } else if (j.is_string()) {
357     return py::str(j.get<std::string>());
358   } else if (j.is_array()) {
359     py::list obj;
360     for (const auto &el : j) {
361       (void)obj.attr("append")(FromJsonImpl(el));
362     }
363     return std::move(obj);
364   } else {
365     py::dict obj;
366     for (json::const_iterator it = j.cbegin(); it != j.cend(); ++it) {
367       obj[py::str(it.key())] = FromJsonImpl(it.value());
368     }
369     return std::move(obj);
370   }
371 }
372 
ToJsonImpl(const py::handle & obj)373 json ToJsonImpl(const py::handle &obj) {
374   if (obj.is_none()) {
375     return nullptr;
376   }
377   if (py::isinstance<py::bool_>(obj)) {
378     return obj.cast<bool>();
379   }
380   if (py::isinstance<py::int_>(obj)) {
381     return obj.cast<int64_t>();
382   }
383   if (py::isinstance<py::float_>(obj)) {
384     return obj.cast<double>();
385   }
386   if (py::isinstance<py::str>(obj)) {
387     return obj.cast<std::string>();
388   }
389   if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
390     auto out = json::array();
391     for (const py::handle &value : obj) {
392       out.push_back(ToJsonImpl(value));
393     }
394     return out;
395   }
396   if (py::isinstance<py::dict>(obj)) {
397     auto out = json::object();
398     for (const py::handle &key : obj) {
399       out[py::str(key).cast<std::string>()] = ToJsonImpl(obj[key]);
400     }
401     return out;
402   }
403   MS_LOG(ERROR) << "[Internal ERROR] Failed to convert Python object: " << py::cast<std::string>(obj)
404                 << " to type json.";
405   return json();
406 }
407 }  // namespace detail
408 
FromJson(const json & j)409 py::object adl_serializer<py::object>::FromJson(const json &j) { return detail::FromJsonImpl(j); }
410 
ToJson(json * j,const py::object & obj)411 void adl_serializer<py::object>::ToJson(json *j, const py::object &obj) {
412   *j = detail::ToJsonImpl(obj);
413 }  // namespace detail
414 }  // namespace nlohmann
415