• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_
19 
#include <sys/stat.h>
#include <unistd.h>

#include <algorithm>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "include/api/dual_abi_helper.h"
#include "include/api/types.h"
#include "include/dataset/iterator.h"
#include "include/dataset/samplers.h"
#include "include/dataset/transforms.h"
38 
39 namespace mindspore {
40 namespace dataset {
41 
// Tensor and tree-runtime types declared elsewhere in the project.
class Tensor;
class TensorShape;
class TreeAdapter;
class TreeAdapterLite;
class TreeGetters;

// IR node and cache types referenced by the Dataset API.
class DatasetCache;
class DatasetNode;

class Iterator;

// Operation, schema and sampler objects.
class TensorOperation;
class SchemaObj;
class SamplerObj;

// Callback type accepted by Dataset::Map.
class DSCallback;

// Dataset classes returned by the Dataset pipeline methods (in alphabetical order).
class BatchDataset;
class MapDataset;
class ProjectDataset;
class ShuffleDataset;
64 /// \class Dataset datasets.h
65 /// \brief A base class to represent a dataset in the data pipeline.
66 class Dataset : public std::enable_shared_from_this<Dataset> {
67  public:
68   // need friend class so they can access the children_ field
69   friend class Iterator;
70   friend class TransferNode;
71 
72   /// \brief Constructor
73   Dataset();
74 
75   /// \brief Destructor
76   ~Dataset() = default;
77 
78   /// \brief Gets the dataset size
79   /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
80   ///     dataset size at the expense of accuracy.
81   /// \return dataset size. If failed, return -1
82   int64_t GetDatasetSize(bool estimate = false);
83 
84   /// \brief Gets the output type
85   /// \return a vector of DataType. If failed, return an empty vector
86   std::vector<mindspore::DataType> GetOutputTypes();
87 
88   /// \brief Gets the output shape
89   /// \return a vector of TensorShape. If failed, return an empty vector
90   std::vector<std::vector<int64_t>> GetOutputShapes();
91 
92   /// \brief Gets the batch size
93   /// \return int64_t
94   int64_t GetBatchSize();
95 
96   /// \brief Gets the repeat count
97   /// \return int64_t
98   int64_t GetRepeatCount();
99 
100   /// \brief Gets the number of classes
101   /// \return number of classes. If failed, return -1
102   int64_t GetNumClasses();
103 
104   /// \brief Gets the column names
105   /// \return Names of the columns. If failed, return an empty vector
GetColumnNames()106   std::vector<std::string> GetColumnNames() { return VectorCharToString(GetColumnNamesCharIF()); }
107 
108   /// \brief Gets the class indexing
109   /// \return a map of ClassIndexing. If failed, return an empty map
GetClassIndexing()110   std::vector<std::pair<std::string, std::vector<int32_t>>> GetClassIndexing() {
111     return ClassIndexCharToString(GetClassIndexingCharIF());
112   }
113 
114   /// \brief Setter function for runtime number of workers
115   /// \param[in] num_workers The number of threads in this operator
116   /// \return Shared pointer to the original object
117   std::shared_ptr<Dataset> SetNumWorkers(int32_t num_workers);
118 
119   /// \brief Function to create an PullBasedIterator over the Dataset
120   /// \param[in] columns List of columns to be used to specify the order of columns
121   /// \return Shared pointer to the Iterator
122   std::shared_ptr<PullIterator> CreatePullBasedIterator(std::vector<std::vector<char>> columns = {});
123 
124   /// \brief Function to create an Iterator over the Dataset pipeline
125   /// \param[in] columns List of columns to be used to specify the order of columns
126   /// \param[in] num_epochs Number of epochs to run through the pipeline, default -1 which means infinite epochs.
127   ///     An empty row is returned at the end of each epoch
128   /// \return Shared pointer to the Iterator
129   std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {}, int32_t num_epochs = -1) {
130     return CreateIteratorCharIF(VectorStringToChar(columns), num_epochs);
131   }
132 
133   /// \brief Function to transfer data through a device.
134   /// \notes If device is Ascend, features of data will be transferred one by one. The limitation
135   ///     of data transmission per time is 256M.
136   /// \param[in] queue_name Channel name (default="", create new unique name).
137   /// \param[in] device_type Type of device (default="", get from MSContext).
138   /// \param[in] device_id id of device (default=1, get from MSContext).
139   /// \param[in] num_epochs Number of epochs (default=-1, infinite epochs).
140   /// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=true).
141   /// \param[in] total_batches Number of batches to be sent to the device (default=0, all data).
142   /// \param[in] create_data_info_queue Whether to create queue which stores types and shapes
143   ///     of data or not(default=false).
144   /// \return Returns true if no error encountered else false.
145   bool DeviceQueue(std::string queue_name = "", std::string device_type = "", int32_t device_id = 0,
146                    int32_t num_epochs = -1, bool send_epoch_end = true, int32_t total_batches = 0,
147                    bool create_data_info_queue = false) {
148     return DeviceQueueCharIF(StringToChar(queue_name), StringToChar(device_type), device_id, num_epochs, send_epoch_end,
149                              total_batches, create_data_info_queue);
150   }
151 
152   /// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline
153   /// \note Usage restrictions:
154   ///     1. Supported dataset formats: 'mindrecord' only
155   ///     2. To save the samples in order, set dataset's shuffle to false and num_files to 1.
156   ///     3. Before calling the function, do not use batch operator, repeat operator or data augmentation operators
157   ///        with random attribute in map operator.
158   ///     4. Mindrecord does not support bool, uint64, multi-dimensional uint8(drop dimension) nor
159   ///        multi-dimensional string.
160   /// \param[in] file_name Path to dataset file
161   /// \param[in] num_files Number of dataset files (default=1)
162   /// \param[in] file_type Dataset format (default="mindrecord")
163   /// \return Returns true if no error encountered else false
164   bool Save(std::string dataset_path, int32_t num_files = 1, std::string dataset_type = "mindrecord") {
165     return SaveCharIF(StringToChar(dataset_path), num_files, StringToChar(dataset_type));
166   }
167 
168   /// \brief Function to create a BatchDataset
169   /// \notes Combines batch_size number of consecutive rows into batches
170   /// \param[in] batch_size The number of rows each batch is created with
171   /// \param[in] drop_remainder Determines whether or not to drop the last possibly incomplete
172   ///     batch. If true, and if there are less than batch_size rows
173   ///     available to make the last batch, then those rows will
174   ///     be dropped and not propagated to the next node
175   /// \return Shared pointer to the current BatchDataset
176   std::shared_ptr<BatchDataset> Batch(int32_t batch_size, bool drop_remainder = false);
177 
178   /// \brief Function to create a MapDataset
179   /// \notes Applies each operation in operations to this dataset
180   /// \param[in] operations Vector of raw pointers to TensorTransform objects to be applied on the dataset. Operations
181   ///     are applied in the order they appear in this list
182   /// \param[in] input_columns Vector of the names of the columns that will be passed to the first
183   ///     operation as input. The size of this list must match the number of
184   ///     input columns expected by the first operator. The default input_columns
185   ///     is the first column
186   /// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation
187   ///     This parameter is mandatory if len(input_columns) != len(output_columns)
188   ///     The size of this list must match the number of output columns of the
189   ///     last operation. The default output_columns will have the same
190   ///     name as the input columns, i.e., the columns will be replaced
191   /// \param[in] project_columns A list of column names to project
192   /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
193   /// \return Shared pointer to the current MapDataset
194   std::shared_ptr<MapDataset> Map(std::vector<TensorTransform *> operations,
195                                   const std::vector<std::string> &input_columns = {},
196                                   const std::vector<std::string> &output_columns = {},
197                                   const std::vector<std::string> &project_columns = {},
198                                   const std::shared_ptr<DatasetCache> &cache = nullptr,
199                                   std::vector<std::shared_ptr<DSCallback>> callbacks = {}) {
200     std::vector<std::shared_ptr<TensorOperation>> transform_ops;
201     (void)std::transform(
202       operations.begin(), operations.end(), std::back_inserter(transform_ops),
203       [](TensorTransform *op) -> std::shared_ptr<TensorOperation> { return op != nullptr ? op->Parse() : nullptr; });
204     return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns),
205                                         VectorStringToChar(output_columns), VectorStringToChar(project_columns), cache,
206                                         callbacks);
207   }
208 
209   /// \brief Function to create a MapDataset
210   /// \notes Applies each operation in operations to this dataset
211   /// \param[in] operations Vector of shared pointers to TensorTransform objects to be applied on the dataset.
212   ///     Operations are applied in the order they appear in this list
213   /// \param[in] input_columns Vector of the names of the columns that will be passed to the first
214   ///     operation as input. The size of this list must match the number of
215   ///     input columns expected by the first operator. The default input_columns
216   ///     is the first column
217   /// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation
218   ///     This parameter is mandatory if len(input_columns) != len(output_columns)
219   ///     The size of this list must match the number of output columns of the
220   ///     last operation. The default output_columns will have the same
221   ///     name as the input columns, i.e., the columns will be replaced
222   /// \param[in] project_columns A list of column names to project
223   /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
224   /// \return Shared pointer to the current MapDataset
225   std::shared_ptr<MapDataset> Map(std::vector<std::shared_ptr<TensorTransform>> operations,
226                                   const std::vector<std::string> &input_columns = {},
227                                   const std::vector<std::string> &output_columns = {},
228                                   const std::vector<std::string> &project_columns = {},
229                                   const std::shared_ptr<DatasetCache> &cache = nullptr,
230                                   std::vector<std::shared_ptr<DSCallback>> callbacks = {}) {
231     std::vector<std::shared_ptr<TensorOperation>> transform_ops;
232     (void)std::transform(operations.begin(), operations.end(), std::back_inserter(transform_ops),
233                          [](std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> {
234                            return op != nullptr ? op->Parse() : nullptr;
235                          });
236     return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns),
237                                         VectorStringToChar(output_columns), VectorStringToChar(project_columns), cache,
238                                         callbacks);
239   }
240 
241   /// \brief Function to create a MapDataset
242   /// \notes Applies each operation in operations to this dataset
243   /// \param[in] operations Vector of TensorTransform objects to be applied on the dataset. Operations are applied in
244   ///     the order they appear in this list
245   /// \param[in] input_columns Vector of the names of the columns that will be passed to the first
246   ///     operation as input. The size of this list must match the number of
247   ///     input columns expected by the first operator. The default input_columns
248   ///     is the first column
249   /// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation
250   ///     This parameter is mandatory if len(input_columns) != len(output_columns)
251   ///     The size of this list must match the number of output columns of the
252   ///     last operation. The default output_columns will have the same
253   ///     name as the input columns, i.e., the columns will be replaced
254   /// \param[in] project_columns A list of column names to project
255   /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
256   /// \return Shared pointer to the current MapDataset
257   std::shared_ptr<MapDataset> Map(const std::vector<std::reference_wrapper<TensorTransform>> operations,
258                                   const std::vector<std::string> &input_columns = {},
259                                   const std::vector<std::string> &output_columns = {},
260                                   const std::vector<std::string> &project_columns = {},
261                                   const std::shared_ptr<DatasetCache> &cache = nullptr,
262                                   std::vector<std::shared_ptr<DSCallback>> callbacks = {}) {
263     std::vector<std::shared_ptr<TensorOperation>> transform_ops;
264     (void)std::transform(operations.begin(), operations.end(), std::back_inserter(transform_ops),
265                          [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
266     return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns),
267                                         VectorStringToChar(output_columns), VectorStringToChar(project_columns), cache,
268                                         callbacks);
269   }
270 
271   /// \brief Function to create a Project Dataset
272   /// \notes Applies project to the dataset
273   /// \param[in] columns The name of columns to project
274   /// \return Shared pointer to the current Dataset
Project(const std::vector<std::string> & columns)275   std::shared_ptr<ProjectDataset> Project(const std::vector<std::string> &columns) {
276     return std::make_shared<ProjectDataset>(shared_from_this(), VectorStringToChar(columns));
277   }
278 
279   /// \brief Function to create a Shuffle Dataset
280   /// \notes Randomly shuffles the rows of this dataset
281   /// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling
282   /// \return Shared pointer to the current ShuffleDataset
Shuffle(int32_t buffer_size)283   std::shared_ptr<ShuffleDataset> Shuffle(int32_t buffer_size) {
284     return std::make_shared<ShuffleDataset>(shared_from_this(), buffer_size);
285   }
286 
IRNode()287   std::shared_ptr<DatasetNode> IRNode() { return ir_node_; }
288 
289  protected:
290   std::shared_ptr<TreeGetters> tree_getters_;
291   std::shared_ptr<DatasetNode> ir_node_;
292 
293  private:
294   // Char interface(CharIF) of GetColumnNames
295   std::vector<std::vector<char>> GetColumnNamesCharIF();
296 
297   // Char interface(CharIF) of GetClassIndexing
298   std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> GetClassIndexingCharIF();
299 
300   // Char interface(CharIF) of CreateIterator
301   std::shared_ptr<Iterator> CreateIteratorCharIF(std::vector<std::vector<char>> columns, int32_t num_epochs);
302 
303   // Char interface(CharIF) of DeviceQueue
304   bool DeviceQueueCharIF(const std::vector<char> &queue_name, const std::vector<char> &device_type, int32_t device_id,
305                          int32_t num_epochs, bool send_epoch_end, int32_t total_batches, bool create_data_info_queue);
306 
307   // Char interface(CharIF) of Save
308   bool SaveCharIF(const std::vector<char> &dataset_path, int32_t num_files, const std::vector<char> &dataset_type);
309 };
310 
311 class SchemaObj {
312  public:
313   /// \brief Constructor
SchemaObj(StringToChar (schema_file))314   explicit SchemaObj(const std::string &schema_file = "") : SchemaObj(StringToChar(schema_file)) {}
315 
316   /// \brief Destructor
317   ~SchemaObj() = default;
318 
319   /// \brief SchemaObj Init function
320   /// \return bool true if schema initialization is successful
321   Status Init();
322 
323   /// \brief Add new column to the schema with unknown shape of rank 1
324   /// \param[in] name Name of the column.
325   /// \param[in] ms_type Data type of the column(mindspore::DataType).
326   /// \return Status code
add_column(const std::string & name,mindspore::DataType ms_type)327   Status add_column(const std::string &name, mindspore::DataType ms_type) {
328     return add_column_char(StringToChar(name), ms_type);
329   }
330 
331   /// \brief Add new column to the schema with unknown shape of rank 1
332   /// \param[in] name Name of the column.
333   /// \param[in] ms_type Data type of the column(std::string).
334   /// \param[in] shape Shape of the column.
335   /// \return Status code
add_column(const std::string & name,const std::string & ms_type)336   Status add_column(const std::string &name, const std::string &ms_type) {
337     return add_column_char(StringToChar(name), StringToChar(ms_type));
338   }
339 
340   /// \brief Add new column to the schema
341   /// \param[in] name Name of the column.
342   /// \param[in] ms_type Data type of the column(mindspore::DataType).
343   /// \param[in] shape Shape of the column.
344   /// \return Status code
add_column(const std::string & name,mindspore::DataType ms_type,const std::vector<int32_t> & shape)345   Status add_column(const std::string &name, mindspore::DataType ms_type, const std::vector<int32_t> &shape) {
346     return add_column_char(StringToChar(name), ms_type, shape);
347   }
348 
349   /// \brief Add new column to the schema
350   /// \param[in] name Name of the column.
351   /// \param[in] ms_type Data type of the column(std::string).
352   /// \param[in] shape Shape of the column.
353   /// \return Status code
add_column(const std::string & name,const std::string & ms_type,const std::vector<int32_t> & shape)354   Status add_column(const std::string &name, const std::string &ms_type, const std::vector<int32_t> &shape) {
355     return add_column_char(StringToChar(name), StringToChar(ms_type), shape);
356   }
357 
358   /// \brief Get a JSON string of the schema
359   /// \return JSON string of the schema
to_json()360   std::string to_json() { return CharToString(to_json_char()); }
361 
362   /// \brief Get a JSON string of the schema
to_string()363   std::string to_string() { return to_json(); }
364 
365   /// \brief Set a new value to dataset_type
366   void set_dataset_type(std::string dataset_type);
367 
368   /// \brief Set a new value to num_rows
369   void set_num_rows(int32_t num_rows);
370 
371   /// \brief Get the current num_rows
372   int32_t get_num_rows() const;
373 
374   /// \brief Get schema file from JSON file
375   /// \param[in] json_string Name of JSON file to be parsed.
376   /// \return Status code
FromJSONString(const std::string & json_string)377   Status FromJSONString(const std::string &json_string) { return FromJSONStringCharIF(StringToChar(json_string)); }
378 
379   /// \brief Parse and add column information
380   /// \param[in] json_string Name of JSON string for column dataset attribute information, decoded from schema file.
381   /// \return Status code
ParseColumnString(const std::string & json_string)382   Status ParseColumnString(const std::string &json_string) {
383     return ParseColumnStringCharIF(StringToChar(json_string));
384   }
385 
386  private:
387   // Char constructor of SchemaObj
388   explicit SchemaObj(const std::vector<char> &schema_file);
389 
390   // Char interface of add_column
391   Status add_column_char(const std::vector<char> &name, mindspore::DataType ms_type);
392 
393   Status add_column_char(const std::vector<char> &name, const std::vector<char> &ms_type);
394 
395   Status add_column_char(const std::vector<char> &name, mindspore::DataType ms_type, const std::vector<int32_t> &shape);
396 
397   Status add_column_char(const std::vector<char> &name, const std::vector<char> &ms_type,
398                          const std::vector<int32_t> &shape);
399 
400   // Char interface of to_json
401   const std::vector<char> to_json_char();
402 
403   // Char interface of FromJSONString
404   Status FromJSONStringCharIF(const std::vector<char> &json_string);
405 
406   // Char interface of ParseColumnString
407   Status ParseColumnStringCharIF(const std::vector<char> &json_string);
408 
409   struct Data;
410   std::shared_ptr<Data> data_;
411 };
412 
413 class BatchDataset : public Dataset {
414  public:
415   BatchDataset(std::shared_ptr<Dataset> input, int32_t batch_size, bool drop_remainder = false);
416   ~BatchDataset() = default;
417 };
418 
419 class MapDataset : public Dataset {
420  public:
421   MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations,
422              const std::vector<std::vector<char>> &input_columns, const std::vector<std::vector<char>> &output_columns,
423              const std::vector<std::vector<char>> &project_columns, const std::shared_ptr<DatasetCache> &cache,
424              std::vector<std::shared_ptr<DSCallback>> callbacks);
425   ~MapDataset() = default;
426 };
427 
428 class ProjectDataset : public Dataset {
429  public:
430   ProjectDataset(std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &columns);
431   ~ProjectDataset() = default;
432 };
433 
434 class ShuffleDataset : public Dataset {
435  public:
436   ShuffleDataset(std::shared_ptr<Dataset> input, int32_t buffer_size);
437   ~ShuffleDataset() = default;
438 };
439 
440 /// \brief Function to create a SchemaObj.
441 /// \param[in] schema_file Path of schema file.
442 /// \note The reason for using this API is that std::string will be constrained by the
443 ///    compiler option '_GLIBCXX_USE_CXX11_ABI' while char is free of this restriction.
444 /// \return Shared pointer to the current schema.
445 std::shared_ptr<SchemaObj> SchemaCharIF(const std::vector<char> &schema_file);
446 
447 /// \brief Function to create a SchemaObj.
448 /// \param[in] schema_file Path of schema file.
449 /// \return Shared pointer to the current schema.
450 inline std::shared_ptr<SchemaObj> Schema(const std::string &schema_file = "") {
451   return SchemaCharIF(StringToChar(schema_file));
452 }
453 
454 class AlbumDataset : public Dataset {
455  public:
456   AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
457                const std::vector<std::vector<char>> &column_names, bool decode, const std::shared_ptr<Sampler> &sampler,
458                const std::shared_ptr<DatasetCache> &cache);
459   AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
460                const std::vector<std::vector<char>> &column_names, bool decode, const Sampler *sampler,
461                const std::shared_ptr<DatasetCache> &cache);
462   AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
463                const std::vector<std::vector<char>> &column_names, bool decode,
464                const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache);
465   ~AlbumDataset() = default;
466 };
467 
468 /// \brief Function to create an AlbumDataset
469 /// \notes The generated dataset is specified through setting a schema
470 /// \param[in] dataset_dir Path to the root directory that contains the dataset
471 /// \param[in] data_schema Path to dataset schema file
472 /// \param[in] column_names Column names used to specify columns to load, if empty, will read all columns.
473 ///     (default = {})
474 /// \param[in] decode the option to decode the images in dataset (default = false)
475 /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
476 /// given,
477 ///     a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
478 /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
479 /// \return Shared pointer to the current Dataset
480 inline std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema,
481                                            const std::vector<std::string> &column_names = {}, bool decode = false,
482                                            const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
483                                            const std::shared_ptr<DatasetCache> &cache = nullptr) {
484   return std::make_shared<AlbumDataset>(StringToChar(dataset_dir), StringToChar(data_schema),
485                                         VectorStringToChar(column_names), decode, sampler, cache);
486 }
487 /// \brief Function to create an AlbumDataset
488 /// \notes The generated dataset is specified through setting a schema
489 /// \param[in] dataset_dir Path to the root directory that contains the dataset
490 /// \param[in] data_schema Path to dataset schema file
491 /// \param[in] column_names Column names used to specify columns to load
492 /// \param[in] decode the option to decode the images in dataset
493 /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
494 /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
495 /// \return Shared pointer to the current Dataset
496 inline std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema,
497                                            const std::vector<std::string> &column_names, bool decode,
498                                            const Sampler *sampler,
499                                            const std::shared_ptr<DatasetCache> &cache = nullptr) {
500   return std::make_shared<AlbumDataset>(StringToChar(dataset_dir), StringToChar(data_schema),
501                                         VectorStringToChar(column_names), decode, sampler, cache);
502 }
503 /// \brief Function to create an AlbumDataset
504 /// \notes The generated dataset is specified through setting a schema
505 /// \param[in] dataset_dir Path to the root directory that contains the dataset
506 /// \param[in] data_schema Path to dataset schema file
507 /// \param[in] column_names Column names used to specify columns to load
508 /// \param[in] decode the option to decode the images in dataset
509 /// \param[in] sampler Sampler object used to choose samples from the dataset.
510 /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
511 /// \return Shared pointer to the current Dataset
512 inline std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema,
513                                            const std::vector<std::string> &column_names, bool decode,
514                                            const std::reference_wrapper<Sampler> sampler,
515                                            const std::shared_ptr<DatasetCache> &cache = nullptr) {
516   return std::make_shared<AlbumDataset>(StringToChar(dataset_dir), StringToChar(data_schema),
517                                         VectorStringToChar(column_names), decode, sampler, cache);
518 }
519 
520 class MnistDataset : public Dataset {
521  public:
522   MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
523                const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
524   MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
525                const std::shared_ptr<DatasetCache> &cache);
526   MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
527                const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache);
528   ~MnistDataset() = default;
529 };
530 
531 /// \brief Function to create a MnistDataset
532 /// \notes The generated dataset has two columns ["image", "label"]
533 /// \param[in] dataset_dir Path to the root directory that contains the dataset
534 /// \param[in] usage of MNIST, can be "train", "test" or "all" (default = "all").
535 /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
536 /// given,
537 ///     a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
538 /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
539 /// \return Shared pointer to the current MnistDataset
540 inline std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage = "all",
541                                            const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
542                                            const std::shared_ptr<DatasetCache> &cache = nullptr) {
543   return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
544 }
545 
546 /// \brief Function to create a MnistDataset
547 /// \notes The generated dataset has two columns ["image", "label"]
548 /// \param[in] dataset_dir Path to the root directory that contains the dataset
549 /// \param[in] usage of MNIST, can be "train", "test" or "all"
550 /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
551 /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
552 /// \return Shared pointer to the current MnistDataset
553 inline std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage,
554                                            const Sampler *sampler,
555                                            const std::shared_ptr<DatasetCache> &cache = nullptr) {
556   return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
557 }
558 
559 /// \brief Function to create a MnistDataset
560 /// \notes The generated dataset has two columns ["image", "label"]
561 /// \param[in] dataset_dir Path to the root directory that contains the dataset
562 /// \param[in] usage of MNIST, can be "train", "test" or "all"
563 /// \param[in] sampler Sampler object used to choose samples from the dataset.
564 /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
565 /// \return Shared pointer to the current MnistDataset
566 inline std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage,
567                                            const std::reference_wrapper<Sampler> sampler,
568                                            const std::shared_ptr<DatasetCache> &cache = nullptr) {
569   return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
570 }
571 }  // namespace dataset
572 }  // namespace mindspore
573 
574 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_
575