• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "minddata/dataset/engine/datasetops/source/fake_image_op.h"
17 
18 #include "minddata/dataset/core/config_manager.h"
19 #include "minddata/dataset/core/tensor_shape.h"
20 #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
21 #include "minddata/dataset/engine/execution_tree.h"
22 #include "utils/ms_utils.h"
23 
24 namespace mindspore {
25 namespace dataset {
FakeImageOp(int32_t num_images,const std::vector<int32_t> & image_size,int32_t num_classes,int32_t base_seed,int32_t num_workers,int32_t op_connector_size,std::unique_ptr<DataSchema> data_schema,std::shared_ptr<SamplerRT> sampler)26 FakeImageOp::FakeImageOp(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
27                          int32_t base_seed, int32_t num_workers, int32_t op_connector_size,
28                          std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
29     : MappableLeafOp(num_workers, op_connector_size, std::move(sampler)),
30       num_images_(num_images),
31       base_seed_(base_seed),
32       image_size_(image_size),
33       num_classes_(num_classes),
34       data_schema_(std::move(data_schema)),
35       image_total_size_(0),
36       label_list_({}),
37       image_tensor_({}) {}
38 
39 // Load 1 TensorRow (image, label) using 1 trow.
LoadTensorRow(row_id_type row_id,TensorRow * trow)40 Status FakeImageOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
41   RETURN_UNEXPECTED_IF_NULL(trow);
42   std::shared_ptr<Tensor> image, label;
43 
44   auto images_buf = std::make_unique<double[]>(image_total_size_);
45   auto pixels = &images_buf[0];
46   {
47     std::unique_lock<std::mutex> lock(access_mutex_);
48     if (image_tensor_[row_id] == nullptr) {
49       rand_gen_.seed(base_seed_ + row_id);  // set seed for random generator.
50       std::normal_distribution<double> distribution(0.0, 1.0);
51       for (int i = 0; i < image_total_size_; ++i) {
52         pixels[i] = distribution(rand_gen_);  // generate the Gaussian distribution pixel.
53         if (pixels[i] < 0) {
54           pixels[i] = 0;
55         }
56         if (pixels[i] > 255) {
57           pixels[i] = 255;
58         }
59       }
60       TensorShape img_tensor_shape = TensorShape({image_size_[0], image_size_[1], image_size_[2]});
61       RETURN_IF_NOT_OK(Tensor::CreateFromMemory(img_tensor_shape, data_schema_->Column(0).Type(),
62                                                 reinterpret_cast<unsigned char *>(pixels), &image));
63       RETURN_IF_NOT_OK(Tensor::CreateFromTensor(image, &image_tensor_[row_id]));
64     } else {
65       RETURN_IF_NOT_OK(Tensor::CreateFromTensor(image_tensor_[row_id], &image));
66     }
67   }
68   RETURN_IF_NOT_OK(Tensor::CreateScalar(label_list_[row_id], &label));
69   (*trow) = TensorRow(row_id, {std::move(image), std::move(label)});
70   return Status::OK();
71 }
72 
73 // A print method typically used for debugging.
Print(std::ostream & out,bool show_all) const74 void FakeImageOp::Print(std::ostream &out, bool show_all) const {
75   if (!show_all) {
76     // Call the super class for displaying any common 1-liner info.
77     ParallelOp::Print(out, show_all);
78   } else {
79     // Call the super class for displaying any common detailed info.
80     ParallelOp::Print(out, show_all);
81     // Then show any custom derived-internal stuff.
82     out << "\nNumber of images: " << num_images_ << "\nNumber of classes: " << num_classes_
83         << "\nBase seed: " << base_seed_ << "\n\n";
84   }
85 }
86 
87 // Derived from RandomAccessOp.
GetClassIds(std::map<int32_t,std::vector<int64_t>> * cls_ids) const88 Status FakeImageOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
89   if (cls_ids == nullptr || !cls_ids->empty() || label_list_.empty()) {
90     if (label_list_.empty()) {
91       RETURN_STATUS_UNEXPECTED(
92         "[Internal ERROR] No image found in dataset. Check if image was generated successfully.");
93     } else {
94       RETURN_STATUS_UNEXPECTED(
95         "[Internal ERROR] Map for storing image-index pair is nullptr or has been set in other place, "
96         "it must be empty before using GetClassIds.");
97     }
98   }
99   for (size_t i = 0; i < label_list_.size(); ++i) {
100     (*cls_ids)[label_list_[i]].push_back(i);
101   }
102   for (auto &pr : (*cls_ids)) {
103     pr.second.shrink_to_fit();
104   }
105   return Status::OK();
106 }
107 
GetItem(int32_t index)108 Status FakeImageOp::GetItem(int32_t index) {
109   // generate one target label according to index and save it in label_list_.
110   rand_gen_.seed(base_seed_ + index);  // set seed for random generator.
111   std::uniform_int_distribution<int32_t> dist(0, num_classes_ - 1);
112   uint32_t target = dist(rand_gen_);  // generate the target.
113   label_list_.emplace_back(target);
114 
115   return Status::OK();
116 }
117 
PrepareData()118 Status FakeImageOp::PrepareData() {
119   // FakeImage generate image with Gaussian distribution.
120   image_total_size_ = image_size_[0] * image_size_[1] * image_size_[2];
121 
122   for (size_t i = 0; i < num_images_; ++i) {
123     RETURN_IF_NOT_OK(GetItem(i));
124   }
125 
126   label_list_.shrink_to_fit();
127   num_rows_ = label_list_.size();
128   CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "Invalid data, generate fake data failed, please check dataset API.");
129   image_tensor_.clear();
130   image_tensor_.resize(num_rows_);
131   return Status::OK();
132 }
133 
ComputeColMap()134 Status FakeImageOp::ComputeColMap() {
135   // Extract the column name mapping from the schema and save it in the class.
136   if (column_name_id_map_.empty()) {
137     RETURN_IF_NOT_OK(data_schema_->GetColumnNameMap(&(column_name_id_map_)));
138   } else {
139     MS_LOG(WARNING) << "Column name map is already set!";
140   }
141   return Status::OK();
142 }
143 }  // namespace dataset
144 }  // namespace mindspore
145