/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_EXAMPLES_TRANSFER_LEARNING_SRC_DATASET_H_
#define MINDSPORE_LITE_EXAMPLES_TRANSFER_LEARNING_SRC_DATASET_H_

#include <cstddef>
#include <tuple>
#include <string>
#include <vector>

// One sample: a raw char buffer paired with an int.
// NOTE(review): presumably (sample bytes, class label); confirm against the code that fills the vectors.
using DataLabelTuple = std::tuple<char *, int>;

// An int paired with a string.
// NOTE(review): presumably (class index, file path) as produced by DataSet::ReadDir; confirm.
using FileTuple = std::tuple<int, std::string>;

// Selects the on-disk layout passed to DataSet::Init. Values are 0, 1 and 2 in declaration order.
enum database_type { DS_CIFAR10_BINARY = 0, DS_MNIST_BINARY, DS_OTHER };

// Utility function. Declared here, defined elsewhere.
// NOTE(review): presumably returns a buffer holding the contents of `file` and stores its length in `*size`;
// buffer ownership and the failure result are not visible from this header, so confirm before relying on them.
char *ReadFile(const std::string &file, size_t *size);

// Container for the train / test / validation sample lists of a dataset.
//
// NOTE(review): the stored tuples hold raw `char *` buffers and the destructor is user-declared (defined
// elsewhere). If it frees those buffers, copying a DataSet would free them twice; confirm, and consider
// deleting the copy operations.
class DataSet {
 public:
  DataSet() = default;
  ~DataSet();

  // Initializes the dataset from `data_base_directory`, interpreted according to `type` (DS_OTHER by default).
  // Returns an int status. NOTE(review): the meaning of the status value is defined elsewhere; confirm.
  int Init(const std::string &data_base_directory, database_type type = DS_OTHER);

  // Read-only views of the stored sample lists.
  const std::vector<DataLabelTuple> &train_data() const { return train_data_; }
  const std::vector<DataLabelTuple> &test_data() const { return test_data_; }
  const std::vector<DataLabelTuple> &val_data() const { return val_data_; }

  // Current class count (0 until something sets it).
  unsigned int num_of_classes() const { return num_of_classes_; }

  // Stores / returns the expected size of one sample (0 until set).
  // NOTE(review): the unit (presumably bytes) is not visible from this header; confirm.
  void set_expected_data_size(unsigned int expected_data_size) { expected_data_size_ = expected_data_size; }
  unsigned int expected_data_size() const { return expected_data_size_; }

 private:
  // Signatures are kept exactly as declared so the out-of-line definitions continue to match.
  std::vector<FileTuple> ReadDir(const std::string dpath);
  void InitializeBMPFoldersDatabase(std::string dpath);

  std::vector<DataLabelTuple> train_data_;
  std::vector<DataLabelTuple> test_data_;
  std::vector<DataLabelTuple> val_data_;
  unsigned int num_of_classes_ = 0;
  unsigned int expected_data_size_ = 0;
};

#endif  // MINDSPORE_LITE_EXAMPLES_TRANSFER_LEARNING_SRC_DATASET_H_