1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "minddata/dataset/engine/datasetops/source/stl10_op.h"
17
18 #include <algorithm>
19 #include <fstream>
20 #include <iomanip>
21 #include <set>
22 #include <utility>
23
24 #include "include/common/debug/common.h"
25 #include "minddata/dataset/core/config_manager.h"
26 #include "minddata/dataset/core/tensor_shape.h"
27 #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
28 #include "minddata/dataset/engine/execution_tree.h"
29 #include "utils/file_utils.h"
30 #include "utils/ms_utils.h"
31
32 namespace mindspore {
33 namespace dataset {
// Geometry of one decoded STL10 sample: 96 x 96 pixels with 3 channels.
constexpr uint32_t kSTLImageRows = 96;
constexpr uint32_t kSTLImageCols = 96;
constexpr uint32_t kSTLImageChannel = 3;
// Number of values making up one image; the binary image files are split into records of this size.
constexpr uint32_t kSTLImageSize = kSTLImageChannel * kSTLImageRows * kSTLImageCols;
38
STL10Op(const std::string & usage,int32_t num_workers,const std::string & folder_path,int32_t queue_size,std::unique_ptr<DataSchema> data_schema,std::shared_ptr<SamplerRT> sampler)39 STL10Op::STL10Op(const std::string &usage, int32_t num_workers, const std::string &folder_path, int32_t queue_size,
40 std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
41 : MappableLeafOp(num_workers, queue_size, std::move(sampler)),
42 folder_path_(folder_path),
43 usage_(usage),
44 data_schema_(std::move(data_schema)),
45 image_path_({}),
46 label_path_({}) {}
47
LoadTensorRow(row_id_type index,TensorRow * trow)48 Status STL10Op::LoadTensorRow(row_id_type index, TensorRow *trow) {
49 RETURN_UNEXPECTED_IF_NULL(trow);
50 std::pair<std::shared_ptr<Tensor>, int32_t> stl10_pair = stl10_image_label_pairs_[index];
51 std::shared_ptr<Tensor> image, label;
52 // make a copy of cached tensor.
53 RETURN_IF_NOT_OK(Tensor::CreateFromTensor(stl10_pair.first, &image));
54 RETURN_IF_NOT_OK(Tensor::CreateScalar(stl10_pair.second, &label));
55
56 (*trow) = TensorRow(index, {std::move(image), std::move(label)});
57 trow->setPath({image_path_[index], label_path_[index]});
58
59 return Status::OK();
60 }
61
GetClassIds(std::map<int32_t,std::vector<int64_t>> * cls_ids) const62 Status STL10Op::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
63 if (cls_ids == nullptr || !cls_ids->empty() || stl10_image_label_pairs_.empty()) {
64 if (stl10_image_label_pairs_.empty()) {
65 RETURN_STATUS_UNEXPECTED("No image found in dataset. Check if image was generated successfully.");
66 } else {
67 RETURN_STATUS_UNEXPECTED(
68 "[Internal ERROR] Map for containing image-index pair is nullptr or has been set in other place, "
69 "it must be empty before using GetClassIds.");
70 }
71 }
72 for (size_t i = 0; i < stl10_image_label_pairs_.size(); ++i) {
73 (*cls_ids)[stl10_image_label_pairs_[i].second].push_back(i);
74 }
75 for (auto &pair : (*cls_ids)) {
76 pair.second.shrink_to_fit();
77 }
78 return Status::OK();
79 }
80
Print(std::ostream & out,bool show_all) const81 void STL10Op::Print(std::ostream &out, bool show_all) const {
82 if (!show_all) {
83 // Call the super class for displaying any common 1-liner info.
84 ParallelOp::Print(out, show_all);
85 // Then show any custom derived-internal 1-liner info for this op.
86 out << "\n";
87 } else {
88 // Call the super class for displaying any common detailed info.
89 ParallelOp::Print(out, show_all);
90 // Then show any custom derived-internal stuff.
91 out << "\nNumber of rows: " << num_rows_ << "\nSTL10 directory: " << folder_path_ << "\n\n";
92 }
93 }
94
WalkAllFiles()95 Status STL10Op::WalkAllFiles() {
96 auto real_dataset_dir = FileUtils::GetRealPath(folder_path_.c_str());
97 CHECK_FAIL_RETURN_UNEXPECTED(real_dataset_dir.has_value(),
98 "Invalid file, get real path failed, path: " + folder_path_);
99 Path root_dir(real_dataset_dir.value());
100
101 const Path train_data_file("train_X.bin");
102 const Path train_label_file("train_y.bin");
103 const Path test_data_file("test_X.bin");
104 const Path test_label_file("test_y.bin");
105 const Path unlabeled_data_file("unlabeled_X.bin");
106
107 bool use_train = false;
108 bool use_test = false;
109 bool use_unlabeled = false;
110
111 if (usage_ == "train") {
112 use_train = true;
113 } else if (usage_ == "test") {
114 use_test = true;
115 } else if (usage_ == "unlabeled") {
116 use_unlabeled = true;
117 } else if (usage_ == "train+unlabeled") {
118 use_train = true;
119 use_unlabeled = true;
120 } else if (usage_ == "all") {
121 use_train = true;
122 use_test = true;
123 use_unlabeled = true;
124 } else {
125 RETURN_STATUS_UNEXPECTED(
126 "Invalid parameter, usage should be \"train\", \"test\", \"unlabeled\", "
127 "\"train+unlabeled\", \"all\", got " +
128 usage_);
129 }
130
131 if (use_train) {
132 Path train_data_path = root_dir / train_data_file;
133 Path train_label_path = root_dir / train_label_file;
134 CHECK_FAIL_RETURN_UNEXPECTED(
135 train_data_path.Exists() && !train_data_path.IsDirectory(),
136 "Invalid file, failed to find STL10 " + usage_ + " data file: " + train_data_path.ToString());
137 CHECK_FAIL_RETURN_UNEXPECTED(
138 train_label_path.Exists() && !train_label_path.IsDirectory(),
139 "Invalid file, failed to find STL10 " + usage_ + " label file: " + train_label_path.ToString());
140 image_names_.push_back(train_data_path.ToString());
141 label_names_.push_back(train_label_path.ToString());
142 MS_LOG(INFO) << "STL10 operator found train data file " << train_data_path.ToString() << ".";
143 MS_LOG(INFO) << "STL10 operator found train label file " << train_label_path.ToString() << ".";
144 }
145
146 if (use_test) {
147 Path test_data_path = root_dir / test_data_file;
148 Path test_label_path = root_dir / test_label_file;
149 CHECK_FAIL_RETURN_UNEXPECTED(
150 test_data_path.Exists() && !test_data_path.IsDirectory(),
151 "Invalid file, failed to find STL10 " + usage_ + " data file: " + test_data_path.ToString());
152 CHECK_FAIL_RETURN_UNEXPECTED(
153 test_label_path.Exists() && !test_label_path.IsDirectory(),
154 "Invalid file, failed to find STL10 " + usage_ + " label file: " + test_label_path.ToString());
155 image_names_.push_back(test_data_path.ToString());
156 label_names_.push_back(test_label_path.ToString());
157 MS_LOG(INFO) << "STL10 operator found test data file " << test_data_path.ToString() << ".";
158 MS_LOG(INFO) << "STL10 operator found test label file " << test_label_path.ToString() << ".";
159 }
160
161 if (use_unlabeled) {
162 Path unlabeled_data_path = root_dir / unlabeled_data_file;
163 CHECK_FAIL_RETURN_UNEXPECTED(
164 unlabeled_data_path.Exists() && !unlabeled_data_path.IsDirectory(),
165 "Invalid file, failed to find STL10 " + usage_ + " data file: " + unlabeled_data_path.ToString());
166 image_names_.push_back(unlabeled_data_path.ToString());
167 MS_LOG(INFO) << "STL10 operator found unlabeled data file " << unlabeled_data_path.ToString() << ".";
168 }
169
170 std::sort(image_names_.begin(), image_names_.end());
171 std::sort(label_names_.begin(), label_names_.end());
172
173 return Status::OK();
174 }
175
ParseSTLData()176 Status STL10Op::ParseSTLData() {
177 // STL10 contains 5 files, *_X.bin are image files, *_y.bin are labels.
178 // training files contain 5k images and testing files contain 8K examples.
179 // unlabeled file contain 10k images and they DO NOT have labels (i.e. no "unlabeled_y.bin" file).
180 for (size_t i = 0; i < image_names_.size(); ++i) {
181 std::ifstream image_reader, label_reader;
182 if (image_names_[i].find("unlabeled") == std::string::npos) {
183 image_reader.open(image_names_[i], std::ios::binary | std::ios::ate);
184 label_reader.open(label_names_[i], std::ios::binary | std::ios::ate);
185
186 Status s = ReadImageAndLabel(&image_reader, &label_reader, i);
187 // Close the readers.
188 image_reader.close();
189 label_reader.close();
190
191 RETURN_IF_NOT_OK(s);
192 } else { // unlabeled data -> no labels.
193 image_reader.open(image_names_[i], std::ios::binary | std::ios::ate);
194
195 Status s = ReadImageAndLabel(&image_reader, nullptr, i);
196 // Close the readers.
197 image_reader.close();
198
199 RETURN_IF_NOT_OK(s);
200 }
201 }
202 stl10_image_label_pairs_.shrink_to_fit();
203 num_rows_ = stl10_image_label_pairs_.size();
204 if (num_rows_ == 0) {
205 RETURN_STATUS_UNEXPECTED(
206 "Invalid data, no valid data matching the dataset API STL10Dataset. Please check file path or dataset API.");
207 }
208
209 return Status::OK();
210 }
211
ReadImageAndLabel(std::ifstream * image_reader,std::ifstream * label_reader,size_t index)212 Status STL10Op::ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index) {
213 RETURN_UNEXPECTED_IF_NULL(image_reader);
214
215 Path image_path(image_names_[index]);
216 bool has_label_file = image_path.Basename().find("unlabeled") == std::string::npos;
217
218 std::streamsize image_size = image_reader->tellg();
219
220 image_reader->seekg(0, std::ios::beg);
221 auto images_buf = std::make_unique<char[]>(image_size);
222 auto labels_buf = std::make_unique<char[]>(0);
223
224 if (images_buf == nullptr) {
225 std::string err_msg = "Failed to allocate memory for STL10 buffer.";
226 MS_LOG(ERROR) << err_msg.c_str();
227 RETURN_STATUS_UNEXPECTED(err_msg);
228 }
229
230 uint64_t num_images = static_cast<uint64_t>(image_size / kSTLImageSize);
231 (void)image_reader->read(images_buf.get(), image_size);
232 if (image_reader->fail()) {
233 RETURN_STATUS_UNEXPECTED("Invalid file, failed to read image: " + image_names_[index] +
234 ", size:" + std::to_string(kSTLImageSize * num_images));
235 }
236
237 if (has_label_file) {
238 RETURN_UNEXPECTED_IF_NULL(label_reader);
239 std::streamsize label_size = label_reader->tellg();
240 if (static_cast<uint64_t>(label_size) != num_images) {
241 RETURN_STATUS_UNEXPECTED("Invalid file, error in " + label_names_[index] +
242 ": the number of labels is not equal to the number of images in " + image_names_[index] +
243 "! Please check the file integrity!");
244 }
245
246 label_reader->seekg(0, std::ios::beg);
247 labels_buf = std::make_unique<char[]>(label_size);
248 if (labels_buf == nullptr) {
249 std::string err_msg = "Failed to allocate memory for STL10 buffer.";
250 MS_LOG(ERROR) << err_msg.c_str();
251 RETURN_STATUS_UNEXPECTED(err_msg);
252 }
253
254 (void)label_reader->read(labels_buf.get(), label_size);
255 if (label_reader->fail()) {
256 RETURN_STATUS_UNEXPECTED("Invalid file, failed to read label:" + label_names_[index] +
257 ", size: " + std::to_string(num_images));
258 }
259 }
260
261 for (int64_t j = 0; j < num_images; ++j) {
262 int32_t label = (has_label_file ? labels_buf[j] - 1 : -1);
263
264 std::shared_ptr<Tensor> image_tensor;
265 RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({kSTLImageRows, kSTLImageCols, kSTLImageChannel}),
266 data_schema_->Column(0).Type(), &image_tensor));
267
268 auto iter = image_tensor->begin<uint8_t>();
269 uint64_t total_pix = kSTLImageRows * kSTLImageCols;
270 // stl10: Column major order.
271 for (uint64_t count = 0, pix = 0; count < total_pix; count++) {
272 if (count % kSTLImageRows == 0) {
273 pix = count / kSTLImageRows;
274 }
275
276 for (int ch = 0; ch < kSTLImageChannel; ch++) {
277 *iter = images_buf[j * kSTLImageSize + ch * total_pix + pix];
278 iter++;
279 }
280 pix += kSTLImageRows;
281 }
282
283 (void)stl10_image_label_pairs_.emplace_back(std::make_pair(image_tensor, label));
284 image_path_.push_back(image_names_[index]);
285 label_path_.push_back(has_label_file ? label_names_[index] : "no label");
286 }
287
288 return Status::OK();
289 }
290
PrepareData()291 Status STL10Op::PrepareData() {
292 RETURN_IF_NOT_OK(this->WalkAllFiles());
293 RETURN_IF_NOT_OK(this->ParseSTLData()); // Parse stl10 data and get num rows, blocking.
294
295 return Status::OK();
296 }
297
CountTotalRows(const std::string & dir,const std::string & usage,int64_t * count)298 Status STL10Op::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) {
299 RETURN_UNEXPECTED_IF_NULL(count);
300 // the logic of counting the number of samples is copied from ParseSTLData().
301 const int64_t num_samples = 0;
302 const int64_t start_index = 0;
303 auto sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
304 auto schema = std::make_unique<DataSchema>();
305 RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
306 TensorShape scalar = TensorShape::CreateScalar();
307 RETURN_IF_NOT_OK(
308 schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
309 std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
310 int32_t num_workers = cfg->num_parallel_workers();
311 int32_t op_connect_size = cfg->op_connector_size();
312 auto op = std::make_shared<STL10Op>(usage, num_workers, dir, op_connect_size, std::move(schema), std::move(sampler));
313
314 RETURN_IF_NOT_OK(op->WalkAllFiles());
315
316 bool use_train = false;
317 bool use_test = false;
318 bool use_unlabeled = false;
319
320 if (usage == "train") {
321 use_train = true;
322 } else if (usage == "test") {
323 use_test = true;
324 } else if (usage == "unlabeled") {
325 use_unlabeled = true;
326 } else if (usage == "train+unlabeled") {
327 use_train = true;
328 use_unlabeled = true;
329 } else if (usage == "all") {
330 use_train = true;
331 use_test = true;
332 use_unlabeled = true;
333 } else {
334 RETURN_STATUS_UNEXPECTED(
335 "Invalid parameter, usage should be \"train\", \"test\", \"unlabeled\", "
336 "\"train+unlabeled\", \"all\", got " +
337 usage);
338 }
339
340 *count = 0;
341 uint64_t num_stl10_records = 0;
342 uint64_t total_image_size = 0;
343
344 if (use_train) {
345 uint32_t index = (usage == "all" ? 1 : 0);
346 Path train_image_path(op->image_names_[index]);
347 CHECK_FAIL_RETURN_UNEXPECTED(train_image_path.Exists() && !train_image_path.IsDirectory(),
348 "Invalid file, failed to open stl10 file: " + train_image_path.ToString());
349
350 std::ifstream train_image_file(train_image_path.ToString(), std::ios::binary | std::ios::ate);
351 CHECK_FAIL_RETURN_UNEXPECTED(train_image_file.is_open(),
352 "Invalid file, failed to open stl10 file: " + train_image_path.ToString());
353 total_image_size += static_cast<uint64_t>(train_image_file.tellg());
354
355 train_image_file.close();
356 }
357
358 if (use_test) {
359 uint32_t index = 0;
360 Path test_image_path(op->image_names_[index]);
361 CHECK_FAIL_RETURN_UNEXPECTED(test_image_path.Exists() && !test_image_path.IsDirectory(),
362 "Invalid file, failed to open stl10 file: " + test_image_path.ToString());
363
364 std::ifstream test_image_file(test_image_path.ToString(), std::ios::binary | std::ios::ate);
365 CHECK_FAIL_RETURN_UNEXPECTED(test_image_file.is_open(),
366 "Invalid file, failed to open stl10 file: " + test_image_path.ToString());
367 total_image_size += static_cast<uint64_t>(test_image_file.tellg());
368
369 test_image_file.close();
370 }
371
372 if (use_unlabeled) {
373 uint32_t index = (usage == "unlabeled" ? 0 : (usage == "train+unlabeled" ? 1 : 2));
374 Path unlabeled_image_path(op->image_names_[index]);
375 CHECK_FAIL_RETURN_UNEXPECTED(unlabeled_image_path.Exists() && !unlabeled_image_path.IsDirectory(),
376 "Invalid file, failed to open stl10 file: " + unlabeled_image_path.ToString());
377
378 std::ifstream unlabeled_image_file(unlabeled_image_path.ToString(), std::ios::binary | std::ios::ate);
379 CHECK_FAIL_RETURN_UNEXPECTED(unlabeled_image_file.is_open(),
380 "Invalid file, failed to open stl10 file: " + unlabeled_image_path.ToString());
381 total_image_size += static_cast<uint64_t>(unlabeled_image_file.tellg());
382
383 unlabeled_image_file.close();
384 }
385
386 num_stl10_records = static_cast<uint64_t>(total_image_size / kSTLImageSize);
387
388 *count = *count + num_stl10_records;
389
390 return Status::OK();
391 }
392
ComputeColMap()393 Status STL10Op::ComputeColMap() {
394 // set the column Name map (base class field).
395 if (column_name_id_map_.empty()) {
396 for (uint32_t i = 0; i < data_schema_->NumColumns(); ++i) {
397 column_name_id_map_[data_schema_->Column(i).Name()] = i;
398 }
399 } else {
400 MS_LOG(WARNING) << "Column Name map is already set!";
401 }
402 return Status::OK();
403 }
404 } // namespace dataset
405 } // namespace mindspore
406