1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_FAKE_IMAGE_OP_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_FAKE_IMAGE_OP_H_ 19 20 #include <algorithm> 21 #include <map> 22 #include <memory> 23 #include <string> 24 #include <utility> 25 #include <vector> 26 27 #include "minddata/dataset/core/tensor.h" 28 #include "minddata/dataset/engine/data_schema.h" 29 #include "minddata/dataset/engine/datasetops/parallel_op.h" 30 #include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" 31 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" 32 #include "minddata/dataset/util/path.h" 33 #include "minddata/dataset/util/queue.h" 34 #include "minddata/dataset/util/status.h" 35 #include "minddata/dataset/util/wait_post.h" 36 37 namespace mindspore { 38 namespace dataset { 39 40 class FakeImageOp : public MappableLeafOp { 41 public: 42 // Constructor. 43 // @param int32_t num_images - Number of generated fake images. 44 // @param const std::vector<int32_t> &image_size - The size of fake image. 45 // @param int32_t num_classes - Number of classes in fake images. 46 // @param int32_t base_seed - A base seed which is used in generating fake image randomly. 47 // @param int32_t num_workers - Number of workers reading images in parallel. 48 // @param int32_t op_connector_size - Connector queue size. 49 // @param std::unique_ptr<DataSchema> data_schema - The schema of the fake image dataset. 50 // @param td::unique_ptr<Sampler> sampler - Sampler tells FakeImageOp what to read. 51 FakeImageOp(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes, int32_t base_seed, 52 int32_t num_workers, int32_t op_connector_size, std::unique_ptr<DataSchema> data_schema, 53 std::shared_ptr<SamplerRT> sampler); 54 55 // Destructor. 56 ~FakeImageOp() = default; 57 58 // Method derived from RandomAccess Op, enable Sampler to get all ids for each class. 59 // @param std::map<int32_t, std::vector<int64_t>> *cls_ids - Key label, val all ids for this class. 60 // @return Status The status code returned. 61 Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override; 62 63 // A print method typically used for debugging. 64 // @param out - The output stream to write output to. 65 // @param show_all - A bool to control if you want to show all info or just a summary. 66 void Print(std::ostream &out, bool show_all) const override; 67 68 // Function to count the number of samples in the FakeImage dataset. 69 // @return Number of images. GetTotalRows()70 int64_t GetTotalRows() const { return num_images_; } 71 72 // Op name getter. 73 // @return Name of the current Op. Name()74 std::string Name() const override { return "FakeImageOp"; } 75 76 // Get a image from index 77 // @param int32_t index - Generate one image according to index. 78 Status GetItem(int32_t index); 79 80 private: 81 // Load a tensor row according to a lable_list. 82 // @param row_id_type row_id - Id for this tensor row. 83 // @param TensorRow *row - Image & label read into this tensor row. 84 // @return Status The status code returned. 85 Status LoadTensorRow(row_id_type row_id, TensorRow *row) override; 86 87 // Generate all labels of FakeImage dataset 88 // @return Status The status code returned. 89 Status PrepareData(); 90 91 // Private function for computing the assignment of the column name map. 92 // @return Status The status code returned. 93 Status ComputeColMap() override; 94 95 int32_t num_images_; 96 int32_t base_seed_; 97 std::vector<int> image_size_; 98 int32_t num_classes_; 99 100 std::unique_ptr<DataSchema> data_schema_; 101 102 int32_t image_total_size_; 103 std::vector<uint32_t> label_list_; 104 std::vector<std::shared_ptr<Tensor>> image_tensor_; 105 std::mt19937 rand_gen_; 106 std::mutex access_mutex_; 107 }; 108 109 } // namespace dataset 110 } // namespace mindspore 111 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_FAKE_IMAGE_OP_H_ 112