• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "minddata/dataset/engine/datasetops/source/stl10_op.h"
17 
18 #include <algorithm>
19 #include <fstream>
20 #include <iomanip>
21 #include <set>
22 #include <utility>
23 
24 #include "include/common/debug/common.h"
25 #include "minddata/dataset/core/config_manager.h"
26 #include "minddata/dataset/core/tensor_shape.h"
27 #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
28 #include "minddata/dataset/engine/execution_tree.h"
29 #include "utils/file_utils.h"
30 #include "utils/ms_utils.h"
31 
32 namespace mindspore {
33 namespace dataset {
// Dimensions of one STL10 sample as used by this op: rows x cols pixels, kSTLImageChannel values per pixel.
constexpr uint32_t kSTLImageRows{96};
constexpr uint32_t kSTLImageCols{96};
constexpr uint32_t kSTLImageChannel{3};
// Number of bytes one image occupies inside a *_X.bin file (one byte per channel value).
constexpr uint32_t kSTLImageSize{kSTLImageRows * kSTLImageCols * kSTLImageChannel};
38 
STL10Op(const std::string & usage,int32_t num_workers,const std::string & folder_path,int32_t queue_size,std::unique_ptr<DataSchema> data_schema,std::shared_ptr<SamplerRT> sampler)39 STL10Op::STL10Op(const std::string &usage, int32_t num_workers, const std::string &folder_path, int32_t queue_size,
40                  std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
41     : MappableLeafOp(num_workers, queue_size, std::move(sampler)),
42       folder_path_(folder_path),
43       usage_(usage),
44       data_schema_(std::move(data_schema)),
45       image_path_({}),
46       label_path_({}) {}
47 
LoadTensorRow(row_id_type index,TensorRow * trow)48 Status STL10Op::LoadTensorRow(row_id_type index, TensorRow *trow) {
49   RETURN_UNEXPECTED_IF_NULL(trow);
50   std::pair<std::shared_ptr<Tensor>, int32_t> stl10_pair = stl10_image_label_pairs_[index];
51   std::shared_ptr<Tensor> image, label;
52   // make a copy of cached tensor.
53   RETURN_IF_NOT_OK(Tensor::CreateFromTensor(stl10_pair.first, &image));
54   RETURN_IF_NOT_OK(Tensor::CreateScalar(stl10_pair.second, &label));
55 
56   (*trow) = TensorRow(index, {std::move(image), std::move(label)});
57   trow->setPath({image_path_[index], label_path_[index]});
58 
59   return Status::OK();
60 }
61 
GetClassIds(std::map<int32_t,std::vector<int64_t>> * cls_ids) const62 Status STL10Op::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
63   if (cls_ids == nullptr || !cls_ids->empty() || stl10_image_label_pairs_.empty()) {
64     if (stl10_image_label_pairs_.empty()) {
65       RETURN_STATUS_UNEXPECTED("No image found in dataset. Check if image was generated successfully.");
66     } else {
67       RETURN_STATUS_UNEXPECTED(
68         "[Internal ERROR] Map for containing image-index pair is nullptr or has been set in other place, "
69         "it must be empty before using GetClassIds.");
70     }
71   }
72   for (size_t i = 0; i < stl10_image_label_pairs_.size(); ++i) {
73     (*cls_ids)[stl10_image_label_pairs_[i].second].push_back(i);
74   }
75   for (auto &pair : (*cls_ids)) {
76     pair.second.shrink_to_fit();
77   }
78   return Status::OK();
79 }
80 
Print(std::ostream & out,bool show_all) const81 void STL10Op::Print(std::ostream &out, bool show_all) const {
82   if (!show_all) {
83     // Call the super class for displaying any common 1-liner info.
84     ParallelOp::Print(out, show_all);
85     // Then show any custom derived-internal 1-liner info for this op.
86     out << "\n";
87   } else {
88     // Call the super class for displaying any common detailed info.
89     ParallelOp::Print(out, show_all);
90     // Then show any custom derived-internal stuff.
91     out << "\nNumber of rows: " << num_rows_ << "\nSTL10 directory: " << folder_path_ << "\n\n";
92   }
93 }
94 
WalkAllFiles()95 Status STL10Op::WalkAllFiles() {
96   auto real_dataset_dir = FileUtils::GetRealPath(folder_path_.c_str());
97   CHECK_FAIL_RETURN_UNEXPECTED(real_dataset_dir.has_value(),
98                                "Invalid file, get real path failed, path: " + folder_path_);
99   Path root_dir(real_dataset_dir.value());
100 
101   const Path train_data_file("train_X.bin");
102   const Path train_label_file("train_y.bin");
103   const Path test_data_file("test_X.bin");
104   const Path test_label_file("test_y.bin");
105   const Path unlabeled_data_file("unlabeled_X.bin");
106 
107   bool use_train = false;
108   bool use_test = false;
109   bool use_unlabeled = false;
110 
111   if (usage_ == "train") {
112     use_train = true;
113   } else if (usage_ == "test") {
114     use_test = true;
115   } else if (usage_ == "unlabeled") {
116     use_unlabeled = true;
117   } else if (usage_ == "train+unlabeled") {
118     use_train = true;
119     use_unlabeled = true;
120   } else if (usage_ == "all") {
121     use_train = true;
122     use_test = true;
123     use_unlabeled = true;
124   } else {
125     RETURN_STATUS_UNEXPECTED(
126       "Invalid parameter, usage should be \"train\", \"test\", \"unlabeled\", "
127       "\"train+unlabeled\", \"all\", got " +
128       usage_);
129   }
130 
131   if (use_train) {
132     Path train_data_path = root_dir / train_data_file;
133     Path train_label_path = root_dir / train_label_file;
134     CHECK_FAIL_RETURN_UNEXPECTED(
135       train_data_path.Exists() && !train_data_path.IsDirectory(),
136       "Invalid file, failed to find STL10 " + usage_ + " data file: " + train_data_path.ToString());
137     CHECK_FAIL_RETURN_UNEXPECTED(
138       train_label_path.Exists() && !train_label_path.IsDirectory(),
139       "Invalid file, failed to find STL10 " + usage_ + " label file: " + train_label_path.ToString());
140     image_names_.push_back(train_data_path.ToString());
141     label_names_.push_back(train_label_path.ToString());
142     MS_LOG(INFO) << "STL10 operator found train data file " << train_data_path.ToString() << ".";
143     MS_LOG(INFO) << "STL10 operator found train label file " << train_label_path.ToString() << ".";
144   }
145 
146   if (use_test) {
147     Path test_data_path = root_dir / test_data_file;
148     Path test_label_path = root_dir / test_label_file;
149     CHECK_FAIL_RETURN_UNEXPECTED(
150       test_data_path.Exists() && !test_data_path.IsDirectory(),
151       "Invalid file, failed to find STL10 " + usage_ + " data file: " + test_data_path.ToString());
152     CHECK_FAIL_RETURN_UNEXPECTED(
153       test_label_path.Exists() && !test_label_path.IsDirectory(),
154       "Invalid file, failed to find STL10 " + usage_ + " label file: " + test_label_path.ToString());
155     image_names_.push_back(test_data_path.ToString());
156     label_names_.push_back(test_label_path.ToString());
157     MS_LOG(INFO) << "STL10 operator found test data file " << test_data_path.ToString() << ".";
158     MS_LOG(INFO) << "STL10 operator found test label file " << test_label_path.ToString() << ".";
159   }
160 
161   if (use_unlabeled) {
162     Path unlabeled_data_path = root_dir / unlabeled_data_file;
163     CHECK_FAIL_RETURN_UNEXPECTED(
164       unlabeled_data_path.Exists() && !unlabeled_data_path.IsDirectory(),
165       "Invalid file, failed to find STL10 " + usage_ + " data file: " + unlabeled_data_path.ToString());
166     image_names_.push_back(unlabeled_data_path.ToString());
167     MS_LOG(INFO) << "STL10 operator found unlabeled data file " << unlabeled_data_path.ToString() << ".";
168   }
169 
170   std::sort(image_names_.begin(), image_names_.end());
171   std::sort(label_names_.begin(), label_names_.end());
172 
173   return Status::OK();
174 }
175 
ParseSTLData()176 Status STL10Op::ParseSTLData() {
177   // STL10 contains 5 files, *_X.bin are image files, *_y.bin are labels.
178   // training files contain 5k images and testing files contain 8K examples.
179   // unlabeled file contain 10k images and they DO NOT have labels (i.e. no "unlabeled_y.bin" file).
180   for (size_t i = 0; i < image_names_.size(); ++i) {
181     std::ifstream image_reader, label_reader;
182     if (image_names_[i].find("unlabeled") == std::string::npos) {
183       image_reader.open(image_names_[i], std::ios::binary | std::ios::ate);
184       label_reader.open(label_names_[i], std::ios::binary | std::ios::ate);
185 
186       Status s = ReadImageAndLabel(&image_reader, &label_reader, i);
187       // Close the readers.
188       image_reader.close();
189       label_reader.close();
190 
191       RETURN_IF_NOT_OK(s);
192     } else {  // unlabeled data -> no labels.
193       image_reader.open(image_names_[i], std::ios::binary | std::ios::ate);
194 
195       Status s = ReadImageAndLabel(&image_reader, nullptr, i);
196       // Close the readers.
197       image_reader.close();
198 
199       RETURN_IF_NOT_OK(s);
200     }
201   }
202   stl10_image_label_pairs_.shrink_to_fit();
203   num_rows_ = stl10_image_label_pairs_.size();
204   if (num_rows_ == 0) {
205     RETURN_STATUS_UNEXPECTED(
206       "Invalid data, no valid data matching the dataset API STL10Dataset. Please check file path or dataset API.");
207   }
208 
209   return Status::OK();
210 }
211 
ReadImageAndLabel(std::ifstream * image_reader,std::ifstream * label_reader,size_t index)212 Status STL10Op::ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index) {
213   RETURN_UNEXPECTED_IF_NULL(image_reader);
214 
215   Path image_path(image_names_[index]);
216   bool has_label_file = image_path.Basename().find("unlabeled") == std::string::npos;
217 
218   std::streamsize image_size = image_reader->tellg();
219 
220   image_reader->seekg(0, std::ios::beg);
221   auto images_buf = std::make_unique<char[]>(image_size);
222   auto labels_buf = std::make_unique<char[]>(0);
223 
224   if (images_buf == nullptr) {
225     std::string err_msg = "Failed to allocate memory for STL10 buffer.";
226     MS_LOG(ERROR) << err_msg.c_str();
227     RETURN_STATUS_UNEXPECTED(err_msg);
228   }
229 
230   uint64_t num_images = static_cast<uint64_t>(image_size / kSTLImageSize);
231   (void)image_reader->read(images_buf.get(), image_size);
232   if (image_reader->fail()) {
233     RETURN_STATUS_UNEXPECTED("Invalid file, failed to read image: " + image_names_[index] +
234                              ", size:" + std::to_string(kSTLImageSize * num_images));
235   }
236 
237   if (has_label_file) {
238     RETURN_UNEXPECTED_IF_NULL(label_reader);
239     std::streamsize label_size = label_reader->tellg();
240     if (static_cast<uint64_t>(label_size) != num_images) {
241       RETURN_STATUS_UNEXPECTED("Invalid file, error in " + label_names_[index] +
242                                ": the number of labels is not equal to the number of images in " + image_names_[index] +
243                                "! Please check the file integrity!");
244     }
245 
246     label_reader->seekg(0, std::ios::beg);
247     labels_buf = std::make_unique<char[]>(label_size);
248     if (labels_buf == nullptr) {
249       std::string err_msg = "Failed to allocate memory for STL10 buffer.";
250       MS_LOG(ERROR) << err_msg.c_str();
251       RETURN_STATUS_UNEXPECTED(err_msg);
252     }
253 
254     (void)label_reader->read(labels_buf.get(), label_size);
255     if (label_reader->fail()) {
256       RETURN_STATUS_UNEXPECTED("Invalid file, failed to read label:" + label_names_[index] +
257                                ", size: " + std::to_string(num_images));
258     }
259   }
260 
261   for (int64_t j = 0; j < num_images; ++j) {
262     int32_t label = (has_label_file ? labels_buf[j] - 1 : -1);
263 
264     std::shared_ptr<Tensor> image_tensor;
265     RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({kSTLImageRows, kSTLImageCols, kSTLImageChannel}),
266                                          data_schema_->Column(0).Type(), &image_tensor));
267 
268     auto iter = image_tensor->begin<uint8_t>();
269     uint64_t total_pix = kSTLImageRows * kSTLImageCols;
270     // stl10: Column major order.
271     for (uint64_t count = 0, pix = 0; count < total_pix; count++) {
272       if (count % kSTLImageRows == 0) {
273         pix = count / kSTLImageRows;
274       }
275 
276       for (int ch = 0; ch < kSTLImageChannel; ch++) {
277         *iter = images_buf[j * kSTLImageSize + ch * total_pix + pix];
278         iter++;
279       }
280       pix += kSTLImageRows;
281     }
282 
283     (void)stl10_image_label_pairs_.emplace_back(std::make_pair(image_tensor, label));
284     image_path_.push_back(image_names_[index]);
285     label_path_.push_back(has_label_file ? label_names_[index] : "no label");
286   }
287 
288   return Status::OK();
289 }
290 
PrepareData()291 Status STL10Op::PrepareData() {
292   RETURN_IF_NOT_OK(this->WalkAllFiles());
293   RETURN_IF_NOT_OK(this->ParseSTLData());  // Parse stl10 data and get num rows, blocking.
294 
295   return Status::OK();
296 }
297 
CountTotalRows(const std::string & dir,const std::string & usage,int64_t * count)298 Status STL10Op::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) {
299   RETURN_UNEXPECTED_IF_NULL(count);
300   // the logic of counting the number of samples is copied from ParseSTLData().
301   const int64_t num_samples = 0;
302   const int64_t start_index = 0;
303   auto sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
304   auto schema = std::make_unique<DataSchema>();
305   RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
306   TensorShape scalar = TensorShape::CreateScalar();
307   RETURN_IF_NOT_OK(
308     schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
309   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
310   int32_t num_workers = cfg->num_parallel_workers();
311   int32_t op_connect_size = cfg->op_connector_size();
312   auto op = std::make_shared<STL10Op>(usage, num_workers, dir, op_connect_size, std::move(schema), std::move(sampler));
313 
314   RETURN_IF_NOT_OK(op->WalkAllFiles());
315 
316   bool use_train = false;
317   bool use_test = false;
318   bool use_unlabeled = false;
319 
320   if (usage == "train") {
321     use_train = true;
322   } else if (usage == "test") {
323     use_test = true;
324   } else if (usage == "unlabeled") {
325     use_unlabeled = true;
326   } else if (usage == "train+unlabeled") {
327     use_train = true;
328     use_unlabeled = true;
329   } else if (usage == "all") {
330     use_train = true;
331     use_test = true;
332     use_unlabeled = true;
333   } else {
334     RETURN_STATUS_UNEXPECTED(
335       "Invalid parameter, usage should be \"train\", \"test\", \"unlabeled\", "
336       "\"train+unlabeled\", \"all\", got " +
337       usage);
338   }
339 
340   *count = 0;
341   uint64_t num_stl10_records = 0;
342   uint64_t total_image_size = 0;
343 
344   if (use_train) {
345     uint32_t index = (usage == "all" ? 1 : 0);
346     Path train_image_path(op->image_names_[index]);
347     CHECK_FAIL_RETURN_UNEXPECTED(train_image_path.Exists() && !train_image_path.IsDirectory(),
348                                  "Invalid file, failed to open stl10 file: " + train_image_path.ToString());
349 
350     std::ifstream train_image_file(train_image_path.ToString(), std::ios::binary | std::ios::ate);
351     CHECK_FAIL_RETURN_UNEXPECTED(train_image_file.is_open(),
352                                  "Invalid file, failed to open stl10 file: " + train_image_path.ToString());
353     total_image_size += static_cast<uint64_t>(train_image_file.tellg());
354 
355     train_image_file.close();
356   }
357 
358   if (use_test) {
359     uint32_t index = 0;
360     Path test_image_path(op->image_names_[index]);
361     CHECK_FAIL_RETURN_UNEXPECTED(test_image_path.Exists() && !test_image_path.IsDirectory(),
362                                  "Invalid file, failed to open stl10 file: " + test_image_path.ToString());
363 
364     std::ifstream test_image_file(test_image_path.ToString(), std::ios::binary | std::ios::ate);
365     CHECK_FAIL_RETURN_UNEXPECTED(test_image_file.is_open(),
366                                  "Invalid file, failed to open stl10 file: " + test_image_path.ToString());
367     total_image_size += static_cast<uint64_t>(test_image_file.tellg());
368 
369     test_image_file.close();
370   }
371 
372   if (use_unlabeled) {
373     uint32_t index = (usage == "unlabeled" ? 0 : (usage == "train+unlabeled" ? 1 : 2));
374     Path unlabeled_image_path(op->image_names_[index]);
375     CHECK_FAIL_RETURN_UNEXPECTED(unlabeled_image_path.Exists() && !unlabeled_image_path.IsDirectory(),
376                                  "Invalid file, failed to open stl10 file: " + unlabeled_image_path.ToString());
377 
378     std::ifstream unlabeled_image_file(unlabeled_image_path.ToString(), std::ios::binary | std::ios::ate);
379     CHECK_FAIL_RETURN_UNEXPECTED(unlabeled_image_file.is_open(),
380                                  "Invalid file, failed to open stl10 file: " + unlabeled_image_path.ToString());
381     total_image_size += static_cast<uint64_t>(unlabeled_image_file.tellg());
382 
383     unlabeled_image_file.close();
384   }
385 
386   num_stl10_records = static_cast<uint64_t>(total_image_size / kSTLImageSize);
387 
388   *count = *count + num_stl10_records;
389 
390   return Status::OK();
391 }
392 
ComputeColMap()393 Status STL10Op::ComputeColMap() {
394   // set the column Name map (base class field).
395   if (column_name_id_map_.empty()) {
396     for (uint32_t i = 0; i < data_schema_->NumColumns(); ++i) {
397       column_name_id_map_[data_schema_->Column(i).Name()] = i;
398     }
399   } else {
400     MS_LOG(WARNING) << "Column Name map is already set!";
401   }
402   return Status::OK();
403 }
404 }  // namespace dataset
405 }  // namespace mindspore
406