/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_

#include <sys/stat.h>
#include <unistd.h>

#include <algorithm>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "include/api/dual_abi_helper.h"
#include "include/api/types.h"
#include "include/dataset/iterator.h"
#include "include/dataset/samplers.h"
#include "include/dataset/transforms.h"

namespace mindspore {
namespace dataset {

class Tensor;
class TensorShape;
class TreeAdapter;
class TreeAdapterLite;
class TreeGetters;

class DatasetCache;
class DatasetNode;

class Iterator;

class TensorOperation;
class SchemaObj;
class SamplerObj;

// Dataset classes (in alphabetical order)
class BatchDataset;
class MapDataset;
class ProjectDataset;
class ShuffleDataset;
class DSCallback;

/// \class Dataset datasets.h
/// \brief A base class to represent a dataset in the data pipeline.
66 class Dataset : public std::enable_shared_from_this<Dataset> { 67 public: 68 // need friend class so they can access the children_ field 69 friend class Iterator; 70 friend class TransferNode; 71 72 /// \brief Constructor 73 Dataset(); 74 75 /// \brief Destructor 76 ~Dataset() = default; 77 78 /// \brief Gets the dataset size 79 /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting 80 /// dataset size at the expense of accuracy. 81 /// \return dataset size. If failed, return -1 82 int64_t GetDatasetSize(bool estimate = false); 83 84 /// \brief Gets the output type 85 /// \return a vector of DataType. If failed, return an empty vector 86 std::vector<mindspore::DataType> GetOutputTypes(); 87 88 /// \brief Gets the output shape 89 /// \return a vector of TensorShape. If failed, return an empty vector 90 std::vector<std::vector<int64_t>> GetOutputShapes(); 91 92 /// \brief Gets the batch size 93 /// \return int64_t 94 int64_t GetBatchSize(); 95 96 /// \brief Gets the repeat count 97 /// \return int64_t 98 int64_t GetRepeatCount(); 99 100 /// \brief Gets the number of classes 101 /// \return number of classes. If failed, return -1 102 int64_t GetNumClasses(); 103 104 /// \brief Gets the column names 105 /// \return Names of the columns. If failed, return an empty vector GetColumnNames()106 std::vector<std::string> GetColumnNames() { return VectorCharToString(GetColumnNamesCharIF()); } 107 108 /// \brief Gets the class indexing 109 /// \return a map of ClassIndexing. 
If failed, return an empty map GetClassIndexing()110 std::vector<std::pair<std::string, std::vector<int32_t>>> GetClassIndexing() { 111 return ClassIndexCharToString(GetClassIndexingCharIF()); 112 } 113 114 /// \brief Setter function for runtime number of workers 115 /// \param[in] num_workers The number of threads in this operator 116 /// \return Shared pointer to the original object 117 std::shared_ptr<Dataset> SetNumWorkers(int32_t num_workers); 118 119 /// \brief Function to create an PullBasedIterator over the Dataset 120 /// \param[in] columns List of columns to be used to specify the order of columns 121 /// \return Shared pointer to the Iterator 122 std::shared_ptr<PullIterator> CreatePullBasedIterator(std::vector<std::vector<char>> columns = {}); 123 124 /// \brief Function to create an Iterator over the Dataset pipeline 125 /// \param[in] columns List of columns to be used to specify the order of columns 126 /// \param[in] num_epochs Number of epochs to run through the pipeline, default -1 which means infinite epochs. 127 /// An empty row is returned at the end of each epoch 128 /// \return Shared pointer to the Iterator 129 std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {}, int32_t num_epochs = -1) { 130 return CreateIteratorCharIF(VectorStringToChar(columns), num_epochs); 131 } 132 133 /// \brief Function to transfer data through a device. 134 /// \notes If device is Ascend, features of data will be transferred one by one. The limitation 135 /// of data transmission per time is 256M. 136 /// \param[in] queue_name Channel name (default="", create new unique name). 137 /// \param[in] device_type Type of device (default="", get from MSContext). 138 /// \param[in] device_id id of device (default=1, get from MSContext). 139 /// \param[in] num_epochs Number of epochs (default=-1, infinite epochs). 140 /// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=true). 
141 /// \param[in] total_batches Number of batches to be sent to the device (default=0, all data). 142 /// \param[in] create_data_info_queue Whether to create queue which stores types and shapes 143 /// of data or not(default=false). 144 /// \return Returns true if no error encountered else false. 145 bool DeviceQueue(std::string queue_name = "", std::string device_type = "", int32_t device_id = 0, 146 int32_t num_epochs = -1, bool send_epoch_end = true, int32_t total_batches = 0, 147 bool create_data_info_queue = false) { 148 return DeviceQueueCharIF(StringToChar(queue_name), StringToChar(device_type), device_id, num_epochs, send_epoch_end, 149 total_batches, create_data_info_queue); 150 } 151 152 /// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline 153 /// \note Usage restrictions: 154 /// 1. Supported dataset formats: 'mindrecord' only 155 /// 2. To save the samples in order, set dataset's shuffle to false and num_files to 1. 156 /// 3. Before calling the function, do not use batch operator, repeat operator or data augmentation operators 157 /// with random attribute in map operator. 158 /// 4. Mindrecord does not support bool, uint64, multi-dimensional uint8(drop dimension) nor 159 /// multi-dimensional string. 
160 /// \param[in] file_name Path to dataset file 161 /// \param[in] num_files Number of dataset files (default=1) 162 /// \param[in] file_type Dataset format (default="mindrecord") 163 /// \return Returns true if no error encountered else false 164 bool Save(std::string dataset_path, int32_t num_files = 1, std::string dataset_type = "mindrecord") { 165 return SaveCharIF(StringToChar(dataset_path), num_files, StringToChar(dataset_type)); 166 } 167 168 /// \brief Function to create a BatchDataset 169 /// \notes Combines batch_size number of consecutive rows into batches 170 /// \param[in] batch_size The number of rows each batch is created with 171 /// \param[in] drop_remainder Determines whether or not to drop the last possibly incomplete 172 /// batch. If true, and if there are less than batch_size rows 173 /// available to make the last batch, then those rows will 174 /// be dropped and not propagated to the next node 175 /// \return Shared pointer to the current BatchDataset 176 std::shared_ptr<BatchDataset> Batch(int32_t batch_size, bool drop_remainder = false); 177 178 /// \brief Function to create a MapDataset 179 /// \notes Applies each operation in operations to this dataset 180 /// \param[in] operations Vector of raw pointers to TensorTransform objects to be applied on the dataset. Operations 181 /// are applied in the order they appear in this list 182 /// \param[in] input_columns Vector of the names of the columns that will be passed to the first 183 /// operation as input. The size of this list must match the number of 184 /// input columns expected by the first operator. The default input_columns 185 /// is the first column 186 /// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation 187 /// This parameter is mandatory if len(input_columns) != len(output_columns) 188 /// The size of this list must match the number of output columns of the 189 /// last operation. 
The default output_columns will have the same 190 /// name as the input columns, i.e., the columns will be replaced 191 /// \param[in] project_columns A list of column names to project 192 /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). 193 /// \return Shared pointer to the current MapDataset 194 std::shared_ptr<MapDataset> Map(std::vector<TensorTransform *> operations, 195 const std::vector<std::string> &input_columns = {}, 196 const std::vector<std::string> &output_columns = {}, 197 const std::vector<std::string> &project_columns = {}, 198 const std::shared_ptr<DatasetCache> &cache = nullptr, 199 std::vector<std::shared_ptr<DSCallback>> callbacks = {}) { 200 std::vector<std::shared_ptr<TensorOperation>> transform_ops; 201 (void)std::transform( 202 operations.begin(), operations.end(), std::back_inserter(transform_ops), 203 [](TensorTransform *op) -> std::shared_ptr<TensorOperation> { return op != nullptr ? op->Parse() : nullptr; }); 204 return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns), 205 VectorStringToChar(output_columns), VectorStringToChar(project_columns), cache, 206 callbacks); 207 } 208 209 /// \brief Function to create a MapDataset 210 /// \notes Applies each operation in operations to this dataset 211 /// \param[in] operations Vector of shared pointers to TensorTransform objects to be applied on the dataset. 212 /// Operations are applied in the order they appear in this list 213 /// \param[in] input_columns Vector of the names of the columns that will be passed to the first 214 /// operation as input. The size of this list must match the number of 215 /// input columns expected by the first operator. 
The default input_columns 216 /// is the first column 217 /// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation 218 /// This parameter is mandatory if len(input_columns) != len(output_columns) 219 /// The size of this list must match the number of output columns of the 220 /// last operation. The default output_columns will have the same 221 /// name as the input columns, i.e., the columns will be replaced 222 /// \param[in] project_columns A list of column names to project 223 /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). 224 /// \return Shared pointer to the current MapDataset 225 std::shared_ptr<MapDataset> Map(std::vector<std::shared_ptr<TensorTransform>> operations, 226 const std::vector<std::string> &input_columns = {}, 227 const std::vector<std::string> &output_columns = {}, 228 const std::vector<std::string> &project_columns = {}, 229 const std::shared_ptr<DatasetCache> &cache = nullptr, 230 std::vector<std::shared_ptr<DSCallback>> callbacks = {}) { 231 std::vector<std::shared_ptr<TensorOperation>> transform_ops; 232 (void)std::transform(operations.begin(), operations.end(), std::back_inserter(transform_ops), 233 [](std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> { 234 return op != nullptr ? op->Parse() : nullptr; 235 }); 236 return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns), 237 VectorStringToChar(output_columns), VectorStringToChar(project_columns), cache, 238 callbacks); 239 } 240 241 /// \brief Function to create a MapDataset 242 /// \notes Applies each operation in operations to this dataset 243 /// \param[in] operations Vector of TensorTransform objects to be applied on the dataset. Operations are applied in 244 /// the order they appear in this list 245 /// \param[in] input_columns Vector of the names of the columns that will be passed to the first 246 /// operation as input. 
The size of this list must match the number of 247 /// input columns expected by the first operator. The default input_columns 248 /// is the first column 249 /// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation 250 /// This parameter is mandatory if len(input_columns) != len(output_columns) 251 /// The size of this list must match the number of output columns of the 252 /// last operation. The default output_columns will have the same 253 /// name as the input columns, i.e., the columns will be replaced 254 /// \param[in] project_columns A list of column names to project 255 /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). 256 /// \return Shared pointer to the current MapDataset 257 std::shared_ptr<MapDataset> Map(const std::vector<std::reference_wrapper<TensorTransform>> operations, 258 const std::vector<std::string> &input_columns = {}, 259 const std::vector<std::string> &output_columns = {}, 260 const std::vector<std::string> &project_columns = {}, 261 const std::shared_ptr<DatasetCache> &cache = nullptr, 262 std::vector<std::shared_ptr<DSCallback>> callbacks = {}) { 263 std::vector<std::shared_ptr<TensorOperation>> transform_ops; 264 (void)std::transform(operations.begin(), operations.end(), std::back_inserter(transform_ops), 265 [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); }); 266 return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns), 267 VectorStringToChar(output_columns), VectorStringToChar(project_columns), cache, 268 callbacks); 269 } 270 271 /// \brief Function to create a Project Dataset 272 /// \notes Applies project to the dataset 273 /// \param[in] columns The name of columns to project 274 /// \return Shared pointer to the current Dataset Project(const std::vector<std::string> & columns)275 std::shared_ptr<ProjectDataset> Project(const std::vector<std::string> &columns) { 276 
return std::make_shared<ProjectDataset>(shared_from_this(), VectorStringToChar(columns)); 277 } 278 279 /// \brief Function to create a Shuffle Dataset 280 /// \notes Randomly shuffles the rows of this dataset 281 /// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling 282 /// \return Shared pointer to the current ShuffleDataset Shuffle(int32_t buffer_size)283 std::shared_ptr<ShuffleDataset> Shuffle(int32_t buffer_size) { 284 return std::make_shared<ShuffleDataset>(shared_from_this(), buffer_size); 285 } 286 IRNode()287 std::shared_ptr<DatasetNode> IRNode() { return ir_node_; } 288 289 protected: 290 std::shared_ptr<TreeGetters> tree_getters_; 291 std::shared_ptr<DatasetNode> ir_node_; 292 293 private: 294 // Char interface(CharIF) of GetColumnNames 295 std::vector<std::vector<char>> GetColumnNamesCharIF(); 296 297 // Char interface(CharIF) of GetClassIndexing 298 std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> GetClassIndexingCharIF(); 299 300 // Char interface(CharIF) of CreateIterator 301 std::shared_ptr<Iterator> CreateIteratorCharIF(std::vector<std::vector<char>> columns, int32_t num_epochs); 302 303 // Char interface(CharIF) of DeviceQueue 304 bool DeviceQueueCharIF(const std::vector<char> &queue_name, const std::vector<char> &device_type, int32_t device_id, 305 int32_t num_epochs, bool send_epoch_end, int32_t total_batches, bool create_data_info_queue); 306 307 // Char interface(CharIF) of Save 308 bool SaveCharIF(const std::vector<char> &dataset_path, int32_t num_files, const std::vector<char> &dataset_type); 309 }; 310 311 class SchemaObj { 312 public: 313 /// \brief Constructor SchemaObj(StringToChar (schema_file))314 explicit SchemaObj(const std::string &schema_file = "") : SchemaObj(StringToChar(schema_file)) {} 315 316 /// \brief Destructor 317 ~SchemaObj() = default; 318 319 /// \brief SchemaObj Init function 320 /// \return bool true if schema initialization is successful 321 Status Init(); 322 323 /// 
\brief Add new column to the schema with unknown shape of rank 1 324 /// \param[in] name Name of the column. 325 /// \param[in] ms_type Data type of the column(mindspore::DataType). 326 /// \return Status code add_column(const std::string & name,mindspore::DataType ms_type)327 Status add_column(const std::string &name, mindspore::DataType ms_type) { 328 return add_column_char(StringToChar(name), ms_type); 329 } 330 331 /// \brief Add new column to the schema with unknown shape of rank 1 332 /// \param[in] name Name of the column. 333 /// \param[in] ms_type Data type of the column(std::string). 334 /// \param[in] shape Shape of the column. 335 /// \return Status code add_column(const std::string & name,const std::string & ms_type)336 Status add_column(const std::string &name, const std::string &ms_type) { 337 return add_column_char(StringToChar(name), StringToChar(ms_type)); 338 } 339 340 /// \brief Add new column to the schema 341 /// \param[in] name Name of the column. 342 /// \param[in] ms_type Data type of the column(mindspore::DataType). 343 /// \param[in] shape Shape of the column. 344 /// \return Status code add_column(const std::string & name,mindspore::DataType ms_type,const std::vector<int32_t> & shape)345 Status add_column(const std::string &name, mindspore::DataType ms_type, const std::vector<int32_t> &shape) { 346 return add_column_char(StringToChar(name), ms_type, shape); 347 } 348 349 /// \brief Add new column to the schema 350 /// \param[in] name Name of the column. 351 /// \param[in] ms_type Data type of the column(std::string). 352 /// \param[in] shape Shape of the column. 
353 /// \return Status code add_column(const std::string & name,const std::string & ms_type,const std::vector<int32_t> & shape)354 Status add_column(const std::string &name, const std::string &ms_type, const std::vector<int32_t> &shape) { 355 return add_column_char(StringToChar(name), StringToChar(ms_type), shape); 356 } 357 358 /// \brief Get a JSON string of the schema 359 /// \return JSON string of the schema to_json()360 std::string to_json() { return CharToString(to_json_char()); } 361 362 /// \brief Get a JSON string of the schema to_string()363 std::string to_string() { return to_json(); } 364 365 /// \brief Set a new value to dataset_type 366 void set_dataset_type(std::string dataset_type); 367 368 /// \brief Set a new value to num_rows 369 void set_num_rows(int32_t num_rows); 370 371 /// \brief Get the current num_rows 372 int32_t get_num_rows() const; 373 374 /// \brief Get schema file from JSON file 375 /// \param[in] json_string Name of JSON file to be parsed. 376 /// \return Status code FromJSONString(const std::string & json_string)377 Status FromJSONString(const std::string &json_string) { return FromJSONStringCharIF(StringToChar(json_string)); } 378 379 /// \brief Parse and add column information 380 /// \param[in] json_string Name of JSON string for column dataset attribute information, decoded from schema file. 
381 /// \return Status code ParseColumnString(const std::string & json_string)382 Status ParseColumnString(const std::string &json_string) { 383 return ParseColumnStringCharIF(StringToChar(json_string)); 384 } 385 386 private: 387 // Char constructor of SchemaObj 388 explicit SchemaObj(const std::vector<char> &schema_file); 389 390 // Char interface of add_column 391 Status add_column_char(const std::vector<char> &name, mindspore::DataType ms_type); 392 393 Status add_column_char(const std::vector<char> &name, const std::vector<char> &ms_type); 394 395 Status add_column_char(const std::vector<char> &name, mindspore::DataType ms_type, const std::vector<int32_t> &shape); 396 397 Status add_column_char(const std::vector<char> &name, const std::vector<char> &ms_type, 398 const std::vector<int32_t> &shape); 399 400 // Char interface of to_json 401 const std::vector<char> to_json_char(); 402 403 // Char interface of FromJSONString 404 Status FromJSONStringCharIF(const std::vector<char> &json_string); 405 406 // Char interface of ParseColumnString 407 Status ParseColumnStringCharIF(const std::vector<char> &json_string); 408 409 struct Data; 410 std::shared_ptr<Data> data_; 411 }; 412 413 class BatchDataset : public Dataset { 414 public: 415 BatchDataset(std::shared_ptr<Dataset> input, int32_t batch_size, bool drop_remainder = false); 416 ~BatchDataset() = default; 417 }; 418 419 class MapDataset : public Dataset { 420 public: 421 MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations, 422 const std::vector<std::vector<char>> &input_columns, const std::vector<std::vector<char>> &output_columns, 423 const std::vector<std::vector<char>> &project_columns, const std::shared_ptr<DatasetCache> &cache, 424 std::vector<std::shared_ptr<DSCallback>> callbacks); 425 ~MapDataset() = default; 426 }; 427 428 class ProjectDataset : public Dataset { 429 public: 430 ProjectDataset(std::shared_ptr<Dataset> input, const 
std::vector<std::vector<char>> &columns); 431 ~ProjectDataset() = default; 432 }; 433 434 class ShuffleDataset : public Dataset { 435 public: 436 ShuffleDataset(std::shared_ptr<Dataset> input, int32_t buffer_size); 437 ~ShuffleDataset() = default; 438 }; 439 440 /// \brief Function to create a SchemaObj. 441 /// \param[in] schema_file Path of schema file. 442 /// \note The reason for using this API is that std::string will be constrained by the 443 /// compiler option '_GLIBCXX_USE_CXX11_ABI' while char is free of this restriction. 444 /// \return Shared pointer to the current schema. 445 std::shared_ptr<SchemaObj> SchemaCharIF(const std::vector<char> &schema_file); 446 447 /// \brief Function to create a SchemaObj. 448 /// \param[in] schema_file Path of schema file. 449 /// \return Shared pointer to the current schema. 450 inline std::shared_ptr<SchemaObj> Schema(const std::string &schema_file = "") { 451 return SchemaCharIF(StringToChar(schema_file)); 452 } 453 454 class AlbumDataset : public Dataset { 455 public: 456 AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema, 457 const std::vector<std::vector<char>> &column_names, bool decode, const std::shared_ptr<Sampler> &sampler, 458 const std::shared_ptr<DatasetCache> &cache); 459 AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema, 460 const std::vector<std::vector<char>> &column_names, bool decode, const Sampler *sampler, 461 const std::shared_ptr<DatasetCache> &cache); 462 AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema, 463 const std::vector<std::vector<char>> &column_names, bool decode, 464 const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache); 465 ~AlbumDataset() = default; 466 }; 467 468 /// \brief Function to create an AlbumDataset 469 /// \notes The generated dataset is specified through setting a schema 470 /// \param[in] dataset_dir Path to the root 
directory that contains the dataset 471 /// \param[in] data_schema Path to dataset schema file 472 /// \param[in] column_names Column names used to specify columns to load, if empty, will read all columns. 473 /// (default = {}) 474 /// \param[in] decode the option to decode the images in dataset (default = false) 475 /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not 476 /// given, 477 /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) 478 /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). 479 /// \return Shared pointer to the current Dataset 480 inline std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema, 481 const std::vector<std::string> &column_names = {}, bool decode = false, 482 const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(), 483 const std::shared_ptr<DatasetCache> &cache = nullptr) { 484 return std::make_shared<AlbumDataset>(StringToChar(dataset_dir), StringToChar(data_schema), 485 VectorStringToChar(column_names), decode, sampler, cache); 486 } 487 /// \brief Function to create an AlbumDataset 488 /// \notes The generated dataset is specified through setting a schema 489 /// \param[in] dataset_dir Path to the root directory that contains the dataset 490 /// \param[in] data_schema Path to dataset schema file 491 /// \param[in] column_names Column names used to specify columns to load 492 /// \param[in] decode the option to decode the images in dataset 493 /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset. 494 /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). 
495 /// \return Shared pointer to the current Dataset 496 inline std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema, 497 const std::vector<std::string> &column_names, bool decode, 498 const Sampler *sampler, 499 const std::shared_ptr<DatasetCache> &cache = nullptr) { 500 return std::make_shared<AlbumDataset>(StringToChar(dataset_dir), StringToChar(data_schema), 501 VectorStringToChar(column_names), decode, sampler, cache); 502 } 503 /// \brief Function to create an AlbumDataset 504 /// \notes The generated dataset is specified through setting a schema 505 /// \param[in] dataset_dir Path to the root directory that contains the dataset 506 /// \param[in] data_schema Path to dataset schema file 507 /// \param[in] column_names Column names used to specify columns to load 508 /// \param[in] decode the option to decode the images in dataset 509 /// \param[in] sampler Sampler object used to choose samples from the dataset. 510 /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). 
511 /// \return Shared pointer to the current Dataset 512 inline std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema, 513 const std::vector<std::string> &column_names, bool decode, 514 const std::reference_wrapper<Sampler> sampler, 515 const std::shared_ptr<DatasetCache> &cache = nullptr) { 516 return std::make_shared<AlbumDataset>(StringToChar(dataset_dir), StringToChar(data_schema), 517 VectorStringToChar(column_names), decode, sampler, cache); 518 } 519 520 class MnistDataset : public Dataset { 521 public: 522 MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, 523 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache); 524 MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler, 525 const std::shared_ptr<DatasetCache> &cache); 526 MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, 527 const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache); 528 ~MnistDataset() = default; 529 }; 530 531 /// \brief Function to create a MnistDataset 532 /// \notes The generated dataset has two columns ["image", "label"] 533 /// \param[in] dataset_dir Path to the root directory that contains the dataset 534 /// \param[in] usage of MNIST, can be "train", "test" or "all" (default = "all"). 535 /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not 536 /// given, 537 /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) 538 /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). 
539 /// \return Shared pointer to the current MnistDataset 540 inline std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage = "all", 541 const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(), 542 const std::shared_ptr<DatasetCache> &cache = nullptr) { 543 return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache); 544 } 545 546 /// \brief Function to create a MnistDataset 547 /// \notes The generated dataset has two columns ["image", "label"] 548 /// \param[in] dataset_dir Path to the root directory that contains the dataset 549 /// \param[in] usage of MNIST, can be "train", "test" or "all" 550 /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset. 551 /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). 552 /// \return Shared pointer to the current MnistDataset 553 inline std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage, 554 const Sampler *sampler, 555 const std::shared_ptr<DatasetCache> &cache = nullptr) { 556 return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache); 557 } 558 559 /// \brief Function to create a MnistDataset 560 /// \notes The generated dataset has two columns ["image", "label"] 561 /// \param[in] dataset_dir Path to the root directory that contains the dataset 562 /// \param[in] usage of MNIST, can be "train", "test" or "all" 563 /// \param[in] sampler Sampler object used to choose samples from the dataset. 564 /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). 
565 /// \return Shared pointer to the current MnistDataset 566 inline std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage, 567 const std::reference_wrapper<Sampler> sampler, 568 const std::shared_ptr<DatasetCache> &cache = nullptr) { 569 return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache); 570 } 571 } // namespace dataset 572 } // namespace mindspore 573 574 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_ 575