• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_COCO_OP_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_COCO_OP_H_
18 
19 #include <map>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #include "minddata/dataset/core/tensor.h"
27 
28 #include "minddata/dataset/engine/data_schema.h"
29 #include "minddata/dataset/engine/datasetops/parallel_op.h"
30 #include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
31 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
32 #ifndef ENABLE_ANDROID
33 #include "minddata/dataset/kernels/image/image_utils.h"
34 #else
35 #include "minddata/dataset/kernels/image/lite_image_utils.h"
36 #endif
37 #include "minddata/dataset/util/path.h"
38 #include "minddata/dataset/util/queue.h"
39 #include "minddata/dataset/util/status.h"
40 #include "minddata/dataset/util/wait_post.h"
41 
42 namespace mindspore {
43 namespace dataset {
// Forward declaration of the Queue class template; it is defined outside this header.
template <class T>
class Queue;

// One row of coordinate data: a list of float vectors.
using CoordinateRow = std::vector<std::vector<float>>;
49 
50 class CocoOp : public MappableLeafOp {
51  public:
52   enum class TaskType { Detection = 0, Stuff = 1, Panoptic = 2, Keypoint = 3, Captioning = 4 };
53 
54   class Builder {
55    public:
56     // Constructor for Builder class of ImageFolderOp
57     // @param  uint32_t numWrks - number of parallel workers
58     // @param dir - directory folder got ImageNetFolder
59     Builder();
60 
61     // Destructor.
62     ~Builder() = default;
63 
64     // Setter method.
65     // @param const std::string & build_dir
66     // @return Builder setter method returns reference to the builder.
SetDir(const std::string & build_dir)67     Builder &SetDir(const std::string &build_dir) {
68       builder_dir_ = build_dir;
69       return *this;
70     }
71 
72     // Setter method.
73     // @param const std::string & build_file
74     // @return Builder setter method returns reference to the builder.
SetFile(const std::string & build_file)75     Builder &SetFile(const std::string &build_file) {
76       builder_file_ = build_file;
77       return *this;
78     }
79 
80     // Setter method.
81     // @param const std::string & task_type
82     // @return Builder setter method returns reference to the builder.
SetTask(const std::string & task_type)83     Builder &SetTask(const std::string &task_type) {
84       if (task_type == "Detection") {
85         builder_task_type_ = TaskType::Detection;
86       } else if (task_type == "Stuff") {
87         builder_task_type_ = TaskType::Stuff;
88       } else if (task_type == "Panoptic") {
89         builder_task_type_ = TaskType::Panoptic;
90       } else if (task_type == "Keypoint") {
91         builder_task_type_ = TaskType::Keypoint;
92       }
93       return *this;
94     }
95 
96     // Setter method.
97     // @param int32_t num_workers
98     // @return Builder setter method returns reference to the builder.
SetNumWorkers(int32_t num_workers)99     Builder &SetNumWorkers(int32_t num_workers) {
100       builder_num_workers_ = num_workers;
101       return *this;
102     }
103 
104     // Setter method.
105     // @param int32_t op_connector_size
106     // @return Builder setter method returns reference to the builder.
SetOpConnectorSize(int32_t op_connector_size)107     Builder &SetOpConnectorSize(int32_t op_connector_size) {
108       builder_op_connector_size_ = op_connector_size;
109       return *this;
110     }
111 
112     // Setter method.
113     // @param std::shared_ptr<Sampler> sampler
114     // @return Builder setter method returns reference to the builder.
SetSampler(std::shared_ptr<SamplerRT> sampler)115     Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
116       builder_sampler_ = std::move(sampler);
117       return *this;
118     }
119 
120     // Setter method.
121     // @param bool do_decode
122     // @return Builder setter method returns reference to the builder.
SetDecode(bool do_decode)123     Builder &SetDecode(bool do_decode) {
124       builder_decode_ = do_decode;
125       return *this;
126     }
127 
128     // Check validity of input args
129     // @return Status The status code returned
130     Status SanityCheck();
131 
132     // The builder "Build" method creates the final object.
133     // @param std::shared_ptr<CocoOp> *op - DatasetOp
134     // @return Status The status code returned
135     Status Build(std::shared_ptr<CocoOp> *op);
136 
137    private:
138     bool builder_decode_;
139     std::string builder_dir_;
140     std::string builder_file_;
141     TaskType builder_task_type_;
142     int32_t builder_num_workers_;
143     int32_t builder_op_connector_size_;
144     int32_t builder_rows_per_buffer_;
145     std::shared_ptr<SamplerRT> builder_sampler_;
146     std::unique_ptr<DataSchema> builder_schema_;
147   };
148 
149 #ifdef ENABLE_PYTHON
150   /// \brief Constructor.
151   /// \param[in] task_type Task type of Coco.
152   /// \param[in] image_folder_path Image folder path of Coco.
153   /// \param[in] annotation_path Annotation json path of Coco.
154   /// \param[in] num_workers Number of workers reading images in parallel.
155   /// \param[in] queue_size Connector queue size.
156   /// \param[in] num_samples Number of samples to read.
157   /// \param[in] decode Whether to decode images.
158   /// \param[in] data_schema The schema of the Coco dataset.
159   /// \param[in] sampler Sampler tells CocoOp what to read.
160   /// \param[in] decrypt - Image decryption function, which accepts the path of the encrypted image file
161   ///     and returns the decrypted bytes data. Default: None, no decryption.
162   CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path,
163          int32_t num_workers, int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema,
164          std::shared_ptr<SamplerRT> sampler, bool extra_metadata, py::function decrypt = py::none());
165 #else
166   /// \brief Constructor.
167   /// \param[in] task_type Task type of Coco.
168   /// \param[in] image_folder_path Image folder path of Coco.
169   /// \param[in] annotation_path Annotation json path of Coco.
170   /// \param[in] num_workers Number of workers reading images in parallel.
171   /// \param[in] queue_size Connector queue size.
172   /// \param[in] num_samples Number of samples to read.
173   /// \param[in] decode Whether to decode images.
174   /// \param[in] data_schema The schema of the Coco dataset.
175   /// \param[in] sampler Sampler tells CocoOp what to read.
176   CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path,
177          int32_t num_workers, int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema,
178          std::shared_ptr<SamplerRT> sampler, bool extra_metadata);
179 #endif
180 
181   /// \brief Destructor.
182   ~CocoOp() = default;
183 
184   /// \brief A print method typically used for debugging.
185   /// \param[out] out The output stream to write output to.
186   /// \param[in] show_all A bool to control if you want to show all info or just a summary.
187   void Print(std::ostream &out, bool show_all) const override;
188 
189   /// \param[out] count Output rows number of CocoDataset.
190   Status CountTotalRows(int64_t *count);
191 
192   /// \brief Op name getter.
193   /// \return Name of the current Op.
Name()194   std::string Name() const override { return "CocoOp"; }
195 
196   /// \brief Gets the class indexing.
197   /// \return Status The status code returned.
198   Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override;
199 
200  private:
201   /// \brief Load a tensor row according to image id.
202   /// \param[in] row_id Id for this tensor row.
203   /// \param[out] row Image & target read into this tensor row.
204   /// \return Status The status code returned.
205   Status LoadTensorRow(row_id_type row_id, TensorRow *row) override;
206 
207   /// \brief Load a tensor row with vector which a vector to a tensor, for "Detection" task.
208   /// \param[in] row_id Id for this tensor row.
209   /// \param[in] image_id Image id.
210   /// \param[in] image Image tensor.
211   /// \param[in] coordinate Coordinate tensor.
212   /// \param[out] row Image & target read into this tensor row.
213   /// \return Status The status code returned.
214   Status LoadDetectionTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image,
215                                 std::shared_ptr<Tensor> coordinate, TensorRow *trow);
216 
217   /// \brief Load a tensor row with vector which a vector to a tensor, for "Stuff/Keypoint" task.
218   /// \param[in] row_id Id for this tensor row.
219   /// \param[in] image_id Image id.
220   /// \param[in] image Image tensor.
221   /// \param[in] coordinate Coordinate tensor.
222   /// \param[out] row Image & target read into this tensor row.
223   /// \return Status The status code returned.
224   Status LoadSimpleTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image,
225                              std::shared_ptr<Tensor> coordinate, TensorRow *trow);
226 
227   /// \brief Load a tensor row with vector which a vector to multi-tensor, for "Panoptic" task.
228   /// \param[in] row_id Id for this tensor row.
229   /// \param[in] image_id Image id.
230   /// \param[in] image Image tensor.
231   /// \param[in] coordinate Coordinate tensor.
232   /// \param[out] row Image & target read into this tensor row.
233   /// \return Status The status code returned.
234   Status LoadMixTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image,
235                           std::shared_ptr<Tensor> coordinate, TensorRow *trow);
236 
237   /// \brief Load a tensor row with vector which a vector to multi-tensor, for "Captioning" task.
238   /// \param[in] row_id Id for this tensor row.
239   /// \param[in] image_id Image id.
240   /// \param[in] image Image tensor.
241   /// \param[in] captions Captions tensor.
242   /// \param[out] trow Image & target read into this tensor row.
243   /// \return Status The status code returned.
244   Status LoadCaptioningTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image,
245                                  std::shared_ptr<Tensor> captions, TensorRow *trow);
246 
247   /// \param[in] path Path to the image file.
248   /// \param[out] tensor Returned tensor.
249   /// \return Status The status code returned.
250   Status ReadImageToTensor(const std::string &path, std::shared_ptr<Tensor> *tensor) const;
251 
252   /// \brief Read annotation from Annotation folder.
253   /// \return Status The status code returned.
254   Status PrepareData() override;
255 
256   /// \param[in] image_tree Image tree of json.
257   /// \param[out] image_vec Image id list of json.
258   /// \return Status The status code returned.
259   Status ImageColumnLoad(const nlohmann::json &image_tree, std::vector<std::string> *image_vec);
260 
261   /// \param[in] categories_tree Categories tree of json.
262   /// \return Status The status code returned.
263   Status CategoriesColumnLoad(const nlohmann::json &categories_tree);
264 
265   /// \param[in] categories_tree Categories tree of json.
266   /// \param[in] image_file Current image name in annotation.
267   /// \param[in] id Current unique id of annotation.
268   /// \return Status The status code returned.
269   Status DetectionColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id);
270 
271   /// \param[in] categories_tree Categories tree of json.
272   /// \param[in] image_file Current image name in annotation.
273   /// \param[in] id Current unique id of annotation.
274   /// \return Status The status code returned.
275   Status StuffColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id);
276 
277   /// \param[in] categories_tree Categories tree of json.
278   /// \param[in] image_file Current image name in annotation.
279   /// \param[in] id Current unique id of annotation.
280   /// \return Status The status code returned.
281   Status KeypointColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id);
282 
283   /// \param[in] categories_tree Categories tree of json.
284   /// \param[in] image_file Current image name in annotation.
285   /// \param[in] image_id Current unique id of annotation.
286   /// \return Status The status code returned.
287   Status PanopticColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file,
288                             const int32_t &image_id);
289 
290   /// \brief Function for finding a caption in annotation_tree.
291   /// \param[in] annotation_tree Annotation tree of json.
292   /// \param[in] image_file Current image name in annotation.
293   /// \param[in] id Current unique id of annotation.
294   /// \return Status The status code returned.
295   Status CaptionColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id);
296 
297   template <typename T>
298   Status SearchNodeInJson(const nlohmann::json &input_tree, std::string node_name, T *output_node);
299 
300   /// \brief Private function for computing the assignment of the column name map.
301   /// \return Status The status code returned.
302   Status ComputeColMap() override;
303 
304   bool decode_;
305   std::string image_folder_path_;
306   std::string annotation_path_;
307   TaskType task_type_;
308   std::unique_ptr<DataSchema> data_schema_;
309   bool extra_metadata_;
310 
311   std::vector<std::string> image_ids_;
312   std::map<int32_t, std::string> image_index_;
313   std::vector<std::pair<std::string, std::vector<int32_t>>> label_index_;
314   std::map<std::string, CoordinateRow> coordinate_map_;
315   std::map<std::string, std::vector<uint32_t>> simple_item_map_;
316   std::map<std::string, std::vector<std::string>> captions_map_;
317   std::set<uint32_t> category_set_;
318 #ifdef ENABLE_PYTHON
319   py::function decrypt_;
320 #endif
321 };
322 }  // namespace dataset
323 }  // namespace mindspore
324 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_COCO_OP_H_
325