• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CIFAR_OP_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CIFAR_OP_H_
18 
19 #include <map>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "minddata/dataset/core/tensor.h"
26 
27 #include "minddata/dataset/engine/data_schema.h"
28 #include "minddata/dataset/engine/datasetops/parallel_op.h"
29 #include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
30 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
31 #include "minddata/dataset/util/path.h"
32 #include "minddata/dataset/util/queue.h"
33 #include "minddata/dataset/util/services.h"
34 #include "minddata/dataset/util/status.h"
35 #include "minddata/dataset/util/wait_post.h"
36 
37 namespace mindspore {
38 namespace dataset {
39 class CifarOp : public MappableLeafOp {
40  public:
41   enum CifarType { kCifar10, kCifar100 };
42 
43   // Constructor
44   // @param CifarType type - Cifar10 or Cifar100
45   // @param const std::string &usage - Usage of this dataset, can be 'train', 'test' or 'all'
46   // @param uint32_t numWorks - Num of workers reading images in parallel
47   // @param std::string - dir directory of cifar dataset
48   // @param uint32_t - queueSize - connector queue size
49   // @param std::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
50   CifarOp(CifarType type, const std::string &usage, int32_t num_works, const std::string &file_dir, int32_t queue_size,
51           std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
52   // Destructor.
53   ~CifarOp() = default;
54 
55   // A print method typically used for debugging
56   // @param out
57   // @param show_all
58   void Print(std::ostream &out, bool show_all) const override;
59 
60   /// Function to count the number of samples in the CIFAR dataset
61   /// @param dir path to the CIFAR directory
62   /// @param isCIFAR10 true if CIFAR10 and false if CIFAR100
63   /// @param count output arg that will hold the actual dataset size
64   /// @return
65   static Status CountTotalRows(const std::string &dir, const std::string &usage, bool isCIFAR10, int64_t *count);
66 
67   /// Op name getter
68   /// @return Name of the current Op
Name()69   std::string Name() const override { return "CifarOp"; }
70 
71  private:
72   // Load a tensor row according to a pair
73   // @param uint64_t index - index need to load
74   // @param TensorRow row - image & label read into this tensor row
75   // @return Status The status code returned
76   Status LoadTensorRow(row_id_type index, TensorRow *trow) override;
77 
78  private:
79   // Read block data from cifar file
80   // @return
81   Status ReadCifarBlockDataAsync();
82 
83   // Called first when function is called
84   // @return
85   Status LaunchThreadsAndInitOp() override;
86 
87   /// Get cifar files in dir
88   /// @return
89   Status GetCifarFiles();
90 
91   /// Read cifar10 data as block
92   /// @return
93   Status ReadCifar10BlockData();
94 
95   /// Read cifar100 data as block
96   /// @return
97   Status ReadCifar100BlockData();
98 
99   /// Parse cifar data
100   /// @return
101   Status ParseCifarData();
102 
103   /// Method derived from RandomAccess Op, enable Sampler to get all ids for each class
104   /// @param (std::map<uint32_t, std::vector<uint32_t >> *cls_ids - val all ids for this class
105   /// @return Status The status code returned
106   Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
107 
108   /// Private function for computing the assignment of the column name map.
109   /// @return - Status
110   Status ComputeColMap() override;
111 
112   CifarType cifar_type_;
113   std::string folder_path_;
114   std::unique_ptr<DataSchema> data_schema_;
115 
116   const std::string usage_;  // can only be either "train" or "test"
117   std::unique_ptr<Queue<std::vector<unsigned char>>> cifar_raw_data_block_;
118   std::vector<std::string> cifar_files_;
119   std::vector<std::string> path_record_;
120   std::vector<std::pair<std::shared_ptr<Tensor>, std::vector<uint32_t>>> cifar_image_label_pairs_;
121 };
122 }  // namespace dataset
123 }  // namespace mindspore
124 #endif  /// MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CIFAR_OP_H_
125