1 /** 2 * Copyright 2019-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_COCO_OP_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_COCO_OP_H_ 18 19 #include <map> 20 #include <memory> 21 #include <set> 22 #include <string> 23 #include <utility> 24 #include <vector> 25 26 #include "minddata/dataset/core/tensor.h" 27 28 #include "minddata/dataset/engine/data_schema.h" 29 #include "minddata/dataset/engine/datasetops/parallel_op.h" 30 #include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" 31 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" 32 #ifndef ENABLE_ANDROID 33 #include "minddata/dataset/kernels/image/image_utils.h" 34 #else 35 #include "minddata/dataset/kernels/image/lite_image_utils.h" 36 #endif 37 #include "minddata/dataset/util/path.h" 38 #include "minddata/dataset/util/queue.h" 39 #include "minddata/dataset/util/status.h" 40 #include "minddata/dataset/util/wait_post.h" 41 42 namespace mindspore { 43 namespace dataset { 44 // Forward declares 45 template <typename T> 46 class Queue; 47 48 using CoordinateRow = std::vector<std::vector<float>>; 49 50 class CocoOp : public MappableLeafOp { 51 public: 52 enum class TaskType { Detection = 0, Stuff = 1, Panoptic = 2, Keypoint = 3, Captioning = 4 }; 53 54 class Builder { 55 public: 56 // Constructor for Builder class of ImageFolderOp 57 // @param uint32_t numWrks - number of parallel workers 58 // @param dir - directory folder got ImageNetFolder 59 Builder(); 60 61 // Destructor. 62 ~Builder() = default; 63 64 // Setter method. 65 // @param const std::string & build_dir 66 // @return Builder setter method returns reference to the builder. SetDir(const std::string & build_dir)67 Builder &SetDir(const std::string &build_dir) { 68 builder_dir_ = build_dir; 69 return *this; 70 } 71 72 // Setter method. 73 // @param const std::string & build_file 74 // @return Builder setter method returns reference to the builder. SetFile(const std::string & build_file)75 Builder &SetFile(const std::string &build_file) { 76 builder_file_ = build_file; 77 return *this; 78 } 79 80 // Setter method. 81 // @param const std::string & task_type 82 // @return Builder setter method returns reference to the builder. SetTask(const std::string & task_type)83 Builder &SetTask(const std::string &task_type) { 84 if (task_type == "Detection") { 85 builder_task_type_ = TaskType::Detection; 86 } else if (task_type == "Stuff") { 87 builder_task_type_ = TaskType::Stuff; 88 } else if (task_type == "Panoptic") { 89 builder_task_type_ = TaskType::Panoptic; 90 } else if (task_type == "Keypoint") { 91 builder_task_type_ = TaskType::Keypoint; 92 } 93 return *this; 94 } 95 96 // Setter method. 97 // @param int32_t num_workers 98 // @return Builder setter method returns reference to the builder. SetNumWorkers(int32_t num_workers)99 Builder &SetNumWorkers(int32_t num_workers) { 100 builder_num_workers_ = num_workers; 101 return *this; 102 } 103 104 // Setter method. 105 // @param int32_t op_connector_size 106 // @return Builder setter method returns reference to the builder. SetOpConnectorSize(int32_t op_connector_size)107 Builder &SetOpConnectorSize(int32_t op_connector_size) { 108 builder_op_connector_size_ = op_connector_size; 109 return *this; 110 } 111 112 // Setter method. 113 // @param std::shared_ptr<Sampler> sampler 114 // @return Builder setter method returns reference to the builder. SetSampler(std::shared_ptr<SamplerRT> sampler)115 Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { 116 builder_sampler_ = std::move(sampler); 117 return *this; 118 } 119 120 // Setter method. 121 // @param bool do_decode 122 // @return Builder setter method returns reference to the builder. SetDecode(bool do_decode)123 Builder &SetDecode(bool do_decode) { 124 builder_decode_ = do_decode; 125 return *this; 126 } 127 128 // Check validity of input args 129 // @return Status The status code returned 130 Status SanityCheck(); 131 132 // The builder "Build" method creates the final object. 133 // @param std::shared_ptr<CocoOp> *op - DatasetOp 134 // @return Status The status code returned 135 Status Build(std::shared_ptr<CocoOp> *op); 136 137 private: 138 bool builder_decode_; 139 std::string builder_dir_; 140 std::string builder_file_; 141 TaskType builder_task_type_; 142 int32_t builder_num_workers_; 143 int32_t builder_op_connector_size_; 144 int32_t builder_rows_per_buffer_; 145 std::shared_ptr<SamplerRT> builder_sampler_; 146 std::unique_ptr<DataSchema> builder_schema_; 147 }; 148 149 #ifdef ENABLE_PYTHON 150 /// \brief Constructor. 151 /// \param[in] task_type Task type of Coco. 152 /// \param[in] image_folder_path Image folder path of Coco. 153 /// \param[in] annotation_path Annotation json path of Coco. 154 /// \param[in] num_workers Number of workers reading images in parallel. 155 /// \param[in] queue_size Connector queue size. 156 /// \param[in] num_samples Number of samples to read. 157 /// \param[in] decode Whether to decode images. 158 /// \param[in] data_schema The schema of the Coco dataset. 159 /// \param[in] sampler Sampler tells CocoOp what to read. 160 /// \param[in] decrypt - Image decryption function, which accepts the path of the encrypted image file 161 /// and returns the decrypted bytes data. Default: None, no decryption. 162 CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path, 163 int32_t num_workers, int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, 164 std::shared_ptr<SamplerRT> sampler, bool extra_metadata, py::function decrypt = py::none()); 165 #else 166 /// \brief Constructor. 167 /// \param[in] task_type Task type of Coco. 168 /// \param[in] image_folder_path Image folder path of Coco. 169 /// \param[in] annotation_path Annotation json path of Coco. 170 /// \param[in] num_workers Number of workers reading images in parallel. 171 /// \param[in] queue_size Connector queue size. 172 /// \param[in] num_samples Number of samples to read. 173 /// \param[in] decode Whether to decode images. 174 /// \param[in] data_schema The schema of the Coco dataset. 175 /// \param[in] sampler Sampler tells CocoOp what to read. 176 CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path, 177 int32_t num_workers, int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, 178 std::shared_ptr<SamplerRT> sampler, bool extra_metadata); 179 #endif 180 181 /// \brief Destructor. 182 ~CocoOp() = default; 183 184 /// \brief A print method typically used for debugging. 185 /// \param[out] out The output stream to write output to. 186 /// \param[in] show_all A bool to control if you want to show all info or just a summary. 187 void Print(std::ostream &out, bool show_all) const override; 188 189 /// \param[out] count Output rows number of CocoDataset. 190 Status CountTotalRows(int64_t *count); 191 192 /// \brief Op name getter. 193 /// \return Name of the current Op. Name()194 std::string Name() const override { return "CocoOp"; } 195 196 /// \brief Gets the class indexing. 197 /// \return Status The status code returned. 198 Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override; 199 200 private: 201 /// \brief Load a tensor row according to image id. 202 /// \param[in] row_id Id for this tensor row. 203 /// \param[out] row Image & target read into this tensor row. 204 /// \return Status The status code returned. 205 Status LoadTensorRow(row_id_type row_id, TensorRow *row) override; 206 207 /// \brief Load a tensor row with vector which a vector to a tensor, for "Detection" task. 208 /// \param[in] row_id Id for this tensor row. 209 /// \param[in] image_id Image id. 210 /// \param[in] image Image tensor. 211 /// \param[in] coordinate Coordinate tensor. 212 /// \param[out] row Image & target read into this tensor row. 213 /// \return Status The status code returned. 214 Status LoadDetectionTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image, 215 std::shared_ptr<Tensor> coordinate, TensorRow *trow); 216 217 /// \brief Load a tensor row with vector which a vector to a tensor, for "Stuff/Keypoint" task. 218 /// \param[in] row_id Id for this tensor row. 219 /// \param[in] image_id Image id. 220 /// \param[in] image Image tensor. 221 /// \param[in] coordinate Coordinate tensor. 222 /// \param[out] row Image & target read into this tensor row. 223 /// \return Status The status code returned. 224 Status LoadSimpleTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image, 225 std::shared_ptr<Tensor> coordinate, TensorRow *trow); 226 227 /// \brief Load a tensor row with vector which a vector to multi-tensor, for "Panoptic" task. 228 /// \param[in] row_id Id for this tensor row. 229 /// \param[in] image_id Image id. 230 /// \param[in] image Image tensor. 231 /// \param[in] coordinate Coordinate tensor. 232 /// \param[out] row Image & target read into this tensor row. 233 /// \return Status The status code returned. 234 Status LoadMixTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image, 235 std::shared_ptr<Tensor> coordinate, TensorRow *trow); 236 237 /// \brief Load a tensor row with vector which a vector to multi-tensor, for "Captioning" task. 238 /// \param[in] row_id Id for this tensor row. 239 /// \param[in] image_id Image id. 240 /// \param[in] image Image tensor. 241 /// \param[in] captions Captions tensor. 242 /// \param[out] trow Image & target read into this tensor row. 243 /// \return Status The status code returned. 244 Status LoadCaptioningTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image, 245 std::shared_ptr<Tensor> captions, TensorRow *trow); 246 247 /// \param[in] path Path to the image file. 248 /// \param[out] tensor Returned tensor. 249 /// \return Status The status code returned. 250 Status ReadImageToTensor(const std::string &path, std::shared_ptr<Tensor> *tensor) const; 251 252 /// \brief Read annotation from Annotation folder. 253 /// \return Status The status code returned. 254 Status PrepareData() override; 255 256 /// \param[in] image_tree Image tree of json. 257 /// \param[out] image_vec Image id list of json. 258 /// \return Status The status code returned. 259 Status ImageColumnLoad(const nlohmann::json &image_tree, std::vector<std::string> *image_vec); 260 261 /// \param[in] categories_tree Categories tree of json. 262 /// \return Status The status code returned. 263 Status CategoriesColumnLoad(const nlohmann::json &categories_tree); 264 265 /// \param[in] categories_tree Categories tree of json. 266 /// \param[in] image_file Current image name in annotation. 267 /// \param[in] id Current unique id of annotation. 268 /// \return Status The status code returned. 269 Status DetectionColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id); 270 271 /// \param[in] categories_tree Categories tree of json. 272 /// \param[in] image_file Current image name in annotation. 273 /// \param[in] id Current unique id of annotation. 274 /// \return Status The status code returned. 275 Status StuffColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id); 276 277 /// \param[in] categories_tree Categories tree of json. 278 /// \param[in] image_file Current image name in annotation. 279 /// \param[in] id Current unique id of annotation. 280 /// \return Status The status code returned. 281 Status KeypointColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id); 282 283 /// \param[in] categories_tree Categories tree of json. 284 /// \param[in] image_file Current image name in annotation. 285 /// \param[in] image_id Current unique id of annotation. 286 /// \return Status The status code returned. 287 Status PanopticColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, 288 const int32_t &image_id); 289 290 /// \brief Function for finding a caption in annotation_tree. 291 /// \param[in] annotation_tree Annotation tree of json. 292 /// \param[in] image_file Current image name in annotation. 293 /// \param[in] id Current unique id of annotation. 294 /// \return Status The status code returned. 295 Status CaptionColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id); 296 297 template <typename T> 298 Status SearchNodeInJson(const nlohmann::json &input_tree, std::string node_name, T *output_node); 299 300 /// \brief Private function for computing the assignment of the column name map. 301 /// \return Status The status code returned. 302 Status ComputeColMap() override; 303 304 bool decode_; 305 std::string image_folder_path_; 306 std::string annotation_path_; 307 TaskType task_type_; 308 std::unique_ptr<DataSchema> data_schema_; 309 bool extra_metadata_; 310 311 std::vector<std::string> image_ids_; 312 std::map<int32_t, std::string> image_index_; 313 std::vector<std::pair<std::string, std::vector<int32_t>>> label_index_; 314 std::map<std::string, CoordinateRow> coordinate_map_; 315 std::map<std::string, std::vector<uint32_t>> simple_item_map_; 316 std::map<std::string, std::vector<std::string>> captions_map_; 317 std::set<uint32_t> category_set_; 318 #ifdef ENABLE_PYTHON 319 py::function decrypt_; 320 #endif 321 }; 322 } // namespace dataset 323 } // namespace mindspore 324 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_COCO_OP_H_ 325