• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "minddata/dataset/engine/datasetops/source/speech_commands_op.h"
17 
18 #include <fstream>
19 #include <iomanip>
20 #include <regex>
21 
22 #include "minddata/dataset/audio/kernels/audio_utils.h"
23 #include "minddata/dataset/core/config_manager.h"
24 #include "minddata/dataset/core/tensor_shape.h"
25 #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
26 #include "minddata/dataset/engine/execution_tree.h"
27 #include "utils/file_utils.h"
28 #include "utils/ms_utils.h"
29 
30 namespace mindspore {
31 namespace dataset {
// File names of the split lists that ParseFileList looks up directly under the dataset directory.
constexpr char kTestFiles[] = "testing_list.txt";
constexpr char kValFiles[] = "validation_list.txt";
// Extension a file must have to be collected by WalkAllFiles.
constexpr char kExtension[] = ".wav";
// Path separator GetFileInfo uses to split a file path into folder and file name.
#ifdef _WIN32
constexpr char kSplitSymbol[] = "\\";
#else
constexpr char kSplitSymbol[] = "/";
#endif
40 
SpeechCommandsOp(const std::string & dataset_dir,const std::string & usage,int32_t num_workers,int32_t queue_size,std::unique_ptr<DataSchema> data_schema,std::shared_ptr<SamplerRT> sampler)41 SpeechCommandsOp::SpeechCommandsOp(const std::string &dataset_dir, const std::string &usage, int32_t num_workers,
42                                    int32_t queue_size, std::unique_ptr<DataSchema> data_schema,
43                                    std::shared_ptr<SamplerRT> sampler)
44     : MappableLeafOp(num_workers, queue_size, std::move(sampler)),
45       dataset_dir_(dataset_dir),
46       usage_(usage),
47       data_schema_(std::move(data_schema)) {}
48 
PrepareData()49 Status SpeechCommandsOp::PrepareData() {
50   // Get file lists according to usage.
51   // When usage == "train", need to get all filenames then subtract files of usage: "test" and "valid".
52   std::set<std::string> selected_files;
53   auto real_dataset_dir = FileUtils::GetRealPath(dataset_dir_.c_str());
54   if (!real_dataset_dir.has_value()) {
55     MS_LOG(ERROR) << "Get real path failed, path=" << dataset_dir_;
56     RETURN_STATUS_UNEXPECTED("Get real path failed, path=" + dataset_dir_);
57   }
58   std::string real_path = real_dataset_dir.value();
59   if (usage_ == "all") {
60     RETURN_IF_NOT_OK(WalkAllFiles(real_path));
61     selected_files = all_wave_files;
62   } else if (usage_ == "test" || usage_ == "valid") {
63     RETURN_IF_NOT_OK(ParseFileList(real_path, usage_));
64     selected_files = loaded_names;
65   } else {
66     RETURN_IF_NOT_OK(WalkAllFiles(real_path));
67     RETURN_IF_NOT_OK(ParseFileList(real_path, "test"));
68     RETURN_IF_NOT_OK(ParseFileList(real_path, "valid"));
69     set_difference(all_wave_files.begin(), all_wave_files.end(), loaded_names.begin(), loaded_names.end(),
70                    inserter(selected_files, selected_files.begin()));
71   }
72   selected_files_vec.assign(selected_files.begin(), selected_files.end());
73   num_rows_ = selected_files_vec.size();
74   return Status::OK();
75 }
76 
ParseFileList(const std::string & pf_path,const std::string & pf_usage)77 Status SpeechCommandsOp::ParseFileList(const std::string &pf_path, const std::string &pf_usage) {
78   std::string line;
79   std::string file_list = (pf_usage == "test" ? kTestFiles : kValFiles);
80   Path path(pf_path);
81   std::string list_path = (Path(pf_path) / Path(file_list)).ToString();
82   std::ifstream file_reader(list_path, std::ios::in);
83   while (getline(file_reader, line)) {
84     Path file_path(path / line);
85     loaded_names.insert(file_path.ToString());
86   }
87   file_reader.close();
88   return Status::OK();
89 }
90 
WalkAllFiles(const std::string & walk_path)91 Status SpeechCommandsOp::WalkAllFiles(const std::string &walk_path) {
92   Path dir(walk_path);
93   if (dir.IsDirectory() == false) {
94     RETURN_STATUS_UNEXPECTED("Invalid parameter, no folder found in: " + walk_path);
95   }
96   std::shared_ptr<Path::DirIterator> dir_itr = Path::DirIterator::OpenDirectory(&dir);
97   RETURN_UNEXPECTED_IF_NULL(dir_itr);
98   std::vector<std::string> folder_names;
99   while (dir_itr->HasNext()) {
100     Path sub_dir = dir_itr->Next();
101     if (sub_dir.IsDirectory() && (sub_dir.ToString().find("_background_noise_") == std::string::npos)) {
102       folder_names.emplace_back(sub_dir.ToString());
103     }
104   }
105   CHECK_FAIL_RETURN_UNEXPECTED(!folder_names.empty(), "Invalid file, failed to open directory: " + dir.ToString());
106   for (int i = 0; i < folder_names.size(); i++) {
107     Path folder_path(folder_names[i]);
108     if (folder_path.IsDirectory()) {
109       auto folder_it = Path::DirIterator::OpenDirectory(&folder_path);
110       CHECK_FAIL_RETURN_UNEXPECTED(folder_it != nullptr, "Invalid path, failed to open dir: " + folder_path.ToString() +
111                                                            ", not exists or permission denied.");
112       while (folder_it->HasNext()) {
113         Path file = folder_it->Next();
114         if (file.Extension() == kExtension) {
115           all_wave_files.insert(file.ToString());
116         }
117       }
118     } else {
119       RETURN_STATUS_UNEXPECTED("Invalid file, failed to open directory: " + folder_path.ToString());
120     }
121   }
122   CHECK_FAIL_RETURN_UNEXPECTED(!all_wave_files.empty(), "Invalid file, no .wav files found under " + dataset_dir_);
123   return Status::OK();
124 }
125 
LoadTensorRow(row_id_type row_id,TensorRow * trow)126 Status SpeechCommandsOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
127   RETURN_UNEXPECTED_IF_NULL(trow);
128   std::string file_name = selected_files_vec[row_id];
129   std::shared_ptr<Tensor> waveform, sample_rate_scalar, label_scalar, speaker_id_scalar, utterance_number_scalar;
130   std::string label, speaker_id;
131   int32_t utterance_number, sample_rate;
132   std::vector<float> waveform_vec;
133   RETURN_IF_NOT_OK(ReadWaveFile(file_name, &waveform_vec, &sample_rate));
134   RETURN_IF_NOT_OK(Tensor::CreateFromVector(waveform_vec, &waveform));
135   RETURN_IF_NOT_OK(waveform->ExpandDim(0));
136   RETURN_IF_NOT_OK(GetFileInfo(file_name, &label, &speaker_id, &utterance_number));
137   RETURN_IF_NOT_OK(Tensor::CreateScalar(sample_rate, &sample_rate_scalar));
138   RETURN_IF_NOT_OK(Tensor::CreateScalar(label, &label_scalar));
139   RETURN_IF_NOT_OK(Tensor::CreateScalar(speaker_id, &speaker_id_scalar));
140   RETURN_IF_NOT_OK(Tensor::CreateScalar(utterance_number, &utterance_number_scalar));
141   (*trow) = TensorRow(row_id, {waveform, sample_rate_scalar, label_scalar, speaker_id_scalar, utterance_number_scalar});
142   trow->setPath({file_name, file_name, file_name, file_name, file_name});
143   return Status::OK();
144 }
145 
Print(std::ostream & out,bool show_all) const146 void SpeechCommandsOp::Print(std::ostream &out, bool show_all) const {
147   if (!show_all) {
148     // Call the super class for displaying and common 1-liner info
149     ParallelOp::Print(out, show_all);
150     // Then show and custom derived-internal 1-liner info for this op
151     out << "\n";
152   } else {
153     // Call the super class for displaying any common detailed info
154     ParallelOp::Print(out, show_all);
155     // Then show any custom derived-internal stuff
156     out << "\nNumber of rows: " << num_rows_ << "\nSpeechCommands directory: " << dataset_dir_ << "\n\n";
157   }
158 }
159 
ComputeColMap()160 Status SpeechCommandsOp::ComputeColMap() {
161   if (column_name_id_map_.empty()) {
162     for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
163       column_name_id_map_[data_schema_->Column(i).Name()] = i;
164     }
165   } else {
166     MS_LOG(WARNING) << "Column name map is already set!";
167   }
168   return Status::OK();
169 }
170 
GetFileInfo(const std::string & file_path,std::string * label,std::string * speaker_id,int32_t * utterance_number)171 Status SpeechCommandsOp::GetFileInfo(const std::string &file_path, std::string *label, std::string *speaker_id,
172                                      int32_t *utterance_number) {
173   // Using regex to get wave infos from filename.
174   RETURN_UNEXPECTED_IF_NULL(label);
175   RETURN_UNEXPECTED_IF_NULL(speaker_id);
176   RETURN_UNEXPECTED_IF_NULL(utterance_number);
177   int32_t split_index = 0;
178   split_index = file_path.find_last_of(kSplitSymbol);
179   std::string label_string = file_path.substr(0, split_index);
180   *label = label_string.substr(label_string.find_last_of(kSplitSymbol) + 1);  // plus "1" for index start from 0.
181   std::string filename = file_path.substr(split_index + 1);
182   std::smatch result;
183   {
184     std::unique_lock<std::mutex> _lock(mux_);
185     (void)regex_match(filename, result, std::regex("(.*)_nohash_(\\d+)\\.wav"));
186   }
187   CHECK_FAIL_RETURN_UNEXPECTED(!(result[0] == "" || result[1] == ""),
188                                "Invalid file name, failed to get file info: " + filename);
189   *speaker_id = result[1];
190   std::string utt_id = result[2];
191   *utterance_number = atoi(utt_id.c_str());
192   return Status::OK();
193 }
194 
CountTotalRows(int64_t * num_rows)195 Status SpeechCommandsOp::CountTotalRows(int64_t *num_rows) {
196   RETURN_UNEXPECTED_IF_NULL(num_rows);
197   if (all_wave_files.size() == 0) {
198     auto real_path = FileUtils::GetRealPath(dataset_dir_.c_str());
199     if (!real_path.has_value()) {
200       MS_LOG(ERROR) << "Get real path failed, path=" << dataset_dir_;
201       RETURN_STATUS_UNEXPECTED("Get real path failed, path=" + dataset_dir_);
202     }
203     RETURN_IF_NOT_OK(WalkAllFiles(real_path.value()));
204   }
205   (*num_rows) = static_cast<int64_t>(all_wave_files.size());
206   return Status::OK();
207 }
208 }  // namespace dataset
209 }  // namespace mindspore
210