• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_IMAGE_FOLDER_OP_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_IMAGE_FOLDER_OP_H_
18 
19 #include <deque>
20 #include <memory>
21 #include <queue>
22 #include <string>
23 #include <algorithm>
24 #include <map>
25 #include <set>
26 #include <utility>
27 #include <vector>
28 #include "minddata/dataset/core/tensor.h"
29 
30 #include "minddata/dataset/engine/data_schema.h"
31 #include "minddata/dataset/engine/datasetops/parallel_op.h"
32 #include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
33 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
34 #ifndef ENABLE_ANDROID
35 #include "minddata/dataset/kernels/image/image_utils.h"
36 #else
37 #include "minddata/dataset/kernels/image/lite_image_utils.h"
38 #endif
39 #include "minddata/dataset/util/path.h"
40 #include "minddata/dataset/util/queue.h"
41 #include "minddata/dataset/util/services.h"
42 #include "minddata/dataset/util/status.h"
43 #include "minddata/dataset/util/wait_post.h"
44 
45 namespace mindspore {
46 namespace dataset {
47 /// Forward declares
48 template <typename T>
49 class Queue;
50 
51 using ImageLabelPair = std::shared_ptr<std::pair<std::string, int32_t>>;
52 using FolderImagesPair = std::shared_ptr<std::pair<std::string, std::queue<ImageLabelPair>>>;
53 
54 class ImageFolderOp : public MappableLeafOp {
55  public:
56 #ifdef ENABLE_PYTHON
57   // Constructor
58   // @param int32_t num_wkrs - Num of workers reading images in parallel
59   // @param std::string - dir directory of ImageNetFolder
60   // @param int32_t queue_size - connector queue size
61   // @param bool recursive - read recursively
62   // @param bool do_decode - decode the images after reading
63   // @param std::set<std::string> &exts - set of file extensions to read, if empty, read everything under the dir
64   // @param std::map<std::string, int32_t> &map- map of folder name and class id
65   // @param std::unique_ptr<dataschema> data_schema - schema of data
66   // @param py::function decrypt - Image decryption function, which accepts the path of the encrypted image file
67   //     and returns the decrypted bytes data. Default: None, no decryption.
68   ImageFolderOp(int32_t num_wkrs, std::string file_dir, int32_t queue_size, bool recursive, bool do_decode,
69                 const std::set<std::string> &exts, const std::map<std::string, int32_t> &map,
70                 std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler,
71                 py::function decrypt = py::none());
72 #else
73   // Constructor
74   // @param int32_t num_wkrs - Num of workers reading images in parallel
75   // @param std::string - dir directory of ImageNetFolder
76   // @param int32_t queue_size - connector queue size
77   // @param bool recursive - read recursively
78   // @param bool do_decode - decode the images after reading
79   // @param std::set<std::string> &exts - set of file extensions to read, if empty, read everything under the dir
80   // @param std::map<std::string, int32_t> &map- map of folder name and class id
81   // @param std::unique_ptr<dataschema> data_schema - schema of data
82   ImageFolderOp(int32_t num_wkrs, std::string file_dir, int32_t queue_size, bool recursive, bool do_decode,
83                 const std::set<std::string> &exts, const std::map<std::string, int32_t> &map,
84                 std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
85 #endif
86 
87   /// Destructor.
88   ~ImageFolderOp() = default;
89 
90   /// Initialize ImageFOlderOp related var, calls the function to walk all files
91   /// @param - std::string dir file directory to  ImageNetFolder
92   /// @return Status The status code returned
93   Status PrepareData() override;
94 
95   // Worker thread pulls a number of IOBlock from IOBlock Queue, make a TensorRow and push it to Connector
96   // @param int32_t workerId - id of each worker
97   // @return Status The status code returned
98   Status PrescanWorkerEntry(int32_t worker_id);
99 
100   // Method derived from RandomAccess Op, enable Sampler to get all ids for each class
101   // @param (std::map<int64_t, std::vector<int64_t >> * map - key label, val all ids for this class
102   // @return Status The status code returned
103   Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
104 
105   /// A print method typically used for debugging
106   /// @param out
107   /// @param show_all
108   void Print(std::ostream &out, bool show_all) const override;
109 
110   /// This function is a hack! It is to return the num_class and num_rows. The result
111   /// returned by this function may not be consistent with what image_folder_op is going to return
112   /// user this at your own risk!
113   static Status CountRowsAndClasses(const std::string &path, const std::set<std::string> &exts, int64_t *num_rows,
114                                     int64_t *num_classes, std::map<std::string, int32_t> class_index);
115 
116   /// Op name getter
117   /// @return Name of the current Op
Name()118   std::string Name() const override { return "ImageFolderOp"; }
119 
120   // DatasetName name getter
121   // \return DatasetName of the current Op
122   virtual std::string DatasetName(bool upper = false) const { return upper ? "ImageFolder" : "image folder"; }
123 
124   //// \brief Base-class override for GetNumClasses
125   //// \param[out] num_classes the number of classes
126   //// \return Status of the function
127   Status GetNumClasses(int64_t *num_classes) override;
128 
129   //// \brief Gets the class indexing
130   //// \param[out] output_class_indexing The index mapping of dataset
131   //// \return Status The status code return
132   Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override;
133 
134  protected:
135   // Load a tensor row according to a pair
136   // @param row_id_type row_id - id for this tensor row
137   // @param ImageLabelPair pair - <imagefile,label>
138   // @param TensorRow row - image & label read into this tensor row
139   // @return Status The status code returned
140   Status LoadTensorRow(row_id_type row_id, TensorRow *row) override;
141 
142   /// @param std::string & dir - dir to walk all images
143   /// @param int64_t * cnt - number of non folder files under the current dir
144   /// @return
145   virtual Status RecursiveWalkFolder(Path *dir);
146 
147   /// start walking of all dirs
148   /// @return
149   Status StartAsyncWalk();
150 
151   // Called first when function is called
152   // @return
153   Status RegisterAndLaunchThreads() override;
154 
155   /// Private function for computing the assignment of the column name map.
156   /// @return - Status
157   Status ComputeColMap() override;
158 
159   /// Initialize pull mode, calls PrepareData() within
160   /// @return Status The status code returned
161   Status InitPullMode() override;
162 
163   std::string folder_path_;  // directory of image folder
164   bool recursive_;
165   bool decode_;
166   std::set<std::string> extensions_;  // extensions allowed
167   std::map<std::string, int32_t> class_index_;
168   std::unique_ptr<DataSchema> data_schema_;
169   int64_t sampler_ind_;
170   uint64_t dirname_offset_;
171   std::vector<ImageLabelPair> image_label_pairs_;
172   std::unique_ptr<Queue<std::string>> folder_name_queue_;
173   std::unique_ptr<Queue<FolderImagesPair>> image_name_queue_;
174 #ifdef ENABLE_PYTHON
175   py::function decrypt_;
176 #endif
177 };
178 }  // namespace dataset
179 }  // namespace mindspore
180 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_IMAGE_FOLDER_OP_H_
181