• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_FAKE_IMAGE_OP_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_FAKE_IMAGE_OP_H_
19 
#include <algorithm>
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <ostream>
#include <random>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/wait_post.h"
36 
37 namespace mindspore {
38 namespace dataset {
39 
40 class FakeImageOp : public MappableLeafOp {
41  public:
42   // Constructor.
43   // @param int32_t num_images - Number of generated fake images.
44   // @param const std::vector<int32_t> &image_size - The size of fake image.
45   // @param int32_t num_classes - Number of classes in fake images.
46   // @param int32_t base_seed - A base seed which is used in generating fake image randomly.
47   // @param int32_t num_workers - Number of workers reading images in parallel.
48   // @param int32_t op_connector_size - Connector queue size.
49   // @param std::unique_ptr<DataSchema> data_schema - The schema of the fake image dataset.
50   // @param td::unique_ptr<Sampler> sampler - Sampler tells FakeImageOp what to read.
51   FakeImageOp(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes, int32_t base_seed,
52               int32_t num_workers, int32_t op_connector_size, std::unique_ptr<DataSchema> data_schema,
53               std::shared_ptr<SamplerRT> sampler);
54 
55   // Destructor.
56   ~FakeImageOp() = default;
57 
58   // Method derived from RandomAccess Op, enable Sampler to get all ids for each class.
59   // @param std::map<int32_t, std::vector<int64_t>> *cls_ids - Key label, val all ids for this class.
60   // @return Status The status code returned.
61   Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
62 
63   // A print method typically used for debugging.
64   // @param out - The output stream to write output to.
65   // @param show_all - A bool to control if you want to show all info or just a summary.
66   void Print(std::ostream &out, bool show_all) const override;
67 
68   // Function to count the number of samples in the FakeImage dataset.
69   // @return Number of images.
GetTotalRows()70   int64_t GetTotalRows() const { return num_images_; }
71 
72   // Op name getter.
73   // @return Name of the current Op.
Name()74   std::string Name() const override { return "FakeImageOp"; }
75 
76   // Get a image from index
77   // @param int32_t index - Generate one image according to index.
78   Status GetItem(int32_t index);
79 
80  private:
81   // Load a tensor row according to a lable_list.
82   // @param row_id_type row_id - Id for this tensor row.
83   // @param TensorRow *row - Image & label read into this tensor row.
84   // @return Status The status code returned.
85   Status LoadTensorRow(row_id_type row_id, TensorRow *row) override;
86 
87   // Generate all labels of FakeImage dataset
88   // @return Status The status code returned.
89   Status PrepareData();
90 
91   // Private function for computing the assignment of the column name map.
92   // @return Status The status code returned.
93   Status ComputeColMap() override;
94 
95   int32_t num_images_;
96   int32_t base_seed_;
97   std::vector<int> image_size_;
98   int32_t num_classes_;
99 
100   std::unique_ptr<DataSchema> data_schema_;
101 
102   int32_t image_total_size_;
103   std::vector<uint32_t> label_list_;
104   std::vector<std::shared_ptr<Tensor>> image_tensor_;
105   std::mt19937 rand_gen_;
106   std::mutex access_mutex_;
107 };
108 
109 }  // namespace dataset
110 }  // namespace mindspore
111 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_FAKE_IMAGE_OP_H_
112