• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/kernels/spectrogram_test_utils.h"

#include <math.h>
#include <stddef.h>
#include <string.h>

#include <algorithm>
#include <utility>

#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/wav/wav_io.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
29 
30 namespace tensorflow {
31 
ReadWaveFileToVector(const string & file_name,std::vector<double> * data)32 bool ReadWaveFileToVector(const string& file_name, std::vector<double>* data) {
33   string wav_data;
34   if (!ReadFileToString(Env::Default(), file_name, &wav_data).ok()) {
35     LOG(ERROR) << "Wave file read failed for " << file_name;
36     return false;
37   }
38   std::vector<float> decoded_data;
39   uint32 decoded_sample_count;
40   uint16 decoded_channel_count;
41   uint32 decoded_sample_rate;
42   if (!wav::DecodeLin16WaveAsFloatVector(
43            wav_data, &decoded_data, &decoded_sample_count,
44            &decoded_channel_count, &decoded_sample_rate)
45            .ok()) {
46     return false;
47   }
48   // Convert from float to double for the output value.
49   data->resize(decoded_data.size());
50   for (int i = 0; i < decoded_data.size(); ++i) {
51     (*data)[i] = decoded_data[i];
52   }
53   return true;
54 }
55 
ReadRawFloatFileToComplexVector(const string & file_name,int row_length,std::vector<std::vector<std::complex<double>>> * data)56 bool ReadRawFloatFileToComplexVector(
57     const string& file_name, int row_length,
58     std::vector<std::vector<std::complex<double> > >* data) {
59   data->clear();
60   string data_string;
61   if (!ReadFileToString(Env::Default(), file_name, &data_string).ok()) {
62     LOG(ERROR) << "Failed to open file " << file_name;
63     return false;
64   }
65   float real_out;
66   float imag_out;
67   const int kBytesPerValue = 4;
68   CHECK_EQ(sizeof(real_out), kBytesPerValue);
69   std::vector<std::complex<double> > data_row;
70   int row_counter = 0;
71   int offset = 0;
72   const int end = data_string.size();
73   while (offset < end) {
74 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
75     char arr[4];
76     for (int i = 0; i < kBytesPerValue; ++i) {
77       arr[3 - i] = *(data_string.data() + offset + i);
78     }
79     memcpy(&real_out, arr, kBytesPerValue);
80     offset += kBytesPerValue;
81     for (int i = 0; i < kBytesPerValue; ++i) {
82       arr[3 - i] = *(data_string.data() + offset + i);
83     }
84     memcpy(&imag_out, arr, kBytesPerValue);
85     offset += kBytesPerValue;
86 #else
87     memcpy(&real_out, data_string.data() + offset, kBytesPerValue);
88     offset += kBytesPerValue;
89     memcpy(&imag_out, data_string.data() + offset, kBytesPerValue);
90     offset += kBytesPerValue;
91 #endif
92     if (row_counter >= row_length) {
93       data->push_back(data_row);
94       data_row.clear();
95       row_counter = 0;
96     }
97     data_row.push_back(std::complex<double>(real_out, imag_out));
98     ++row_counter;
99   }
100   if (row_counter >= row_length) {
101     data->push_back(data_row);
102   }
103   return true;
104 }
105 
ReadCSVFileToComplexVectorOrDie(const string & file_name,std::vector<std::vector<std::complex<double>>> * data)106 void ReadCSVFileToComplexVectorOrDie(
107     const string& file_name,
108     std::vector<std::vector<std::complex<double> > >* data) {
109   data->clear();
110   string data_string;
111   if (!ReadFileToString(Env::Default(), file_name, &data_string).ok()) {
112     LOG(FATAL) << "Failed to open file " << file_name;
113     return;
114   }
115   std::vector<string> lines = str_util::Split(data_string, '\n');
116   for (const string& line : lines) {
117     if (line.empty()) {
118       continue;
119     }
120     std::vector<std::complex<double> > data_line;
121     std::vector<string> values = str_util::Split(line, ',');
122     for (std::vector<string>::const_iterator i = values.begin();
123          i != values.end(); ++i) {
124       // each element of values may be in the form:
125       // 0.001+0.002i, 0.001, 0.001i, -1.2i, -1.2-3.2i, 1.5, 1.5e-03+21.0i
126       std::vector<string> parts;
127       // Find the first instance of + or - after the second character
128       // in the string, that does not immediately follow an 'e'.
129       size_t operator_index = i->find_first_of("+-", 2);
130       if (operator_index < i->size() &&
131           i->substr(operator_index - 1, 1) == "e") {
132         operator_index = i->find_first_of("+-", operator_index + 1);
133       }
134       parts.push_back(i->substr(0, operator_index));
135       if (operator_index < i->size()) {
136         parts.push_back(i->substr(operator_index, string::npos));
137       }
138 
139       double real_part = 0.0;
140       double imaginary_part = 0.0;
141       for (std::vector<string>::const_iterator j = parts.begin();
142            j != parts.end(); ++j) {
143         if (j->find_first_of("ij") != string::npos) {
144           strings::safe_strtod(*j, &imaginary_part);
145         } else {
146           strings::safe_strtod(*j, &real_part);
147         }
148       }
149       data_line.push_back(std::complex<double>(real_part, imaginary_part));
150     }
151     data->push_back(data_line);
152   }
153 }
154 
ReadCSVFileToArrayOrDie(const string & filename,std::vector<std::vector<float>> * array)155 void ReadCSVFileToArrayOrDie(const string& filename,
156                              std::vector<std::vector<float> >* array) {
157   string contents;
158   TF_CHECK_OK(ReadFileToString(Env::Default(), filename, &contents));
159   std::vector<string> lines = str_util::Split(contents, '\n');
160   contents.clear();
161 
162   array->clear();
163   std::vector<float> values;
164   for (int l = 0; l < lines.size(); ++l) {
165     values.clear();
166     std::vector<string> split_line = str_util::Split(lines[l], ",");
167     for (const string& token : split_line) {
168       float tmp;
169       CHECK(strings::safe_strtof(token, &tmp));
170       values.push_back(tmp);
171     }
172     array->push_back(values);
173   }
174 }
175 
WriteDoubleVectorToFile(const string & file_name,const std::vector<double> & data)176 bool WriteDoubleVectorToFile(const string& file_name,
177                              const std::vector<double>& data) {
178   std::unique_ptr<WritableFile> file;
179   if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
180     LOG(ERROR) << "Failed to open file " << file_name;
181     return false;
182   }
183   for (int i = 0; i < data.size(); ++i) {
184     if (!file->Append(StringPiece(reinterpret_cast<const char*>(&(data[i])),
185                                   sizeof(data[i])))
186              .ok()) {
187       LOG(ERROR) << "Failed to append to file " << file_name;
188       return false;
189     }
190   }
191   if (!file->Close().ok()) {
192     LOG(ERROR) << "Failed to close file " << file_name;
193     return false;
194   }
195   return true;
196 }
197 
WriteFloatVectorToFile(const string & file_name,const std::vector<float> & data)198 bool WriteFloatVectorToFile(const string& file_name,
199                             const std::vector<float>& data) {
200   std::unique_ptr<WritableFile> file;
201   if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
202     LOG(ERROR) << "Failed to open file " << file_name;
203     return false;
204   }
205   for (int i = 0; i < data.size(); ++i) {
206     if (!file->Append(StringPiece(reinterpret_cast<const char*>(&(data[i])),
207                                   sizeof(data[i])))
208              .ok()) {
209       LOG(ERROR) << "Failed to append to file " << file_name;
210       return false;
211     }
212   }
213   if (!file->Close().ok()) {
214     LOG(ERROR) << "Failed to close file " << file_name;
215     return false;
216   }
217   return true;
218 }
219 
WriteDoubleArrayToFile(const string & file_name,int size,const double * data)220 bool WriteDoubleArrayToFile(const string& file_name, int size,
221                             const double* data) {
222   std::unique_ptr<WritableFile> file;
223   if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
224     LOG(ERROR) << "Failed to open file " << file_name;
225     return false;
226   }
227   for (int i = 0; i < size; ++i) {
228     if (!file->Append(StringPiece(reinterpret_cast<const char*>(&(data[i])),
229                                   sizeof(data[i])))
230              .ok()) {
231       LOG(ERROR) << "Failed to append to file " << file_name;
232       return false;
233     }
234   }
235   if (!file->Close().ok()) {
236     LOG(ERROR) << "Failed to close file " << file_name;
237     return false;
238   }
239   return true;
240 }
241 
WriteFloatArrayToFile(const string & file_name,int size,const float * data)242 bool WriteFloatArrayToFile(const string& file_name, int size,
243                            const float* data) {
244   std::unique_ptr<WritableFile> file;
245   if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
246     LOG(ERROR) << "Failed to open file " << file_name;
247     return false;
248   }
249   for (int i = 0; i < size; ++i) {
250     if (!file->Append(StringPiece(reinterpret_cast<const char*>(&(data[i])),
251                                   sizeof(data[i])))
252              .ok()) {
253       LOG(ERROR) << "Failed to append to file " << file_name;
254       return false;
255     }
256   }
257   if (!file->Close().ok()) {
258     LOG(ERROR) << "Failed to close file " << file_name;
259     return false;
260   }
261   return true;
262 }
263 
WriteComplexVectorToRawFloatFile(const string & file_name,const std::vector<std::vector<std::complex<double>>> & data)264 bool WriteComplexVectorToRawFloatFile(
265     const string& file_name,
266     const std::vector<std::vector<std::complex<double> > >& data) {
267   std::unique_ptr<WritableFile> file;
268   if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
269     LOG(ERROR) << "Failed to open file " << file_name;
270     return false;
271   }
272   for (int i = 0; i < data.size(); ++i) {
273     for (int j = 0; j < data[i].size(); ++j) {
274       const float real_part(real(data[i][j]));
275       if (!file->Append(StringPiece(reinterpret_cast<const char*>(&real_part),
276                                     sizeof(real_part)))
277                .ok()) {
278         LOG(ERROR) << "Failed to append to file " << file_name;
279         return false;
280       }
281 
282       const float imag_part(imag(data[i][j]));
283       if (!file->Append(StringPiece(reinterpret_cast<const char*>(&imag_part),
284                                     sizeof(imag_part)))
285                .ok()) {
286         LOG(ERROR) << "Failed to append to file " << file_name;
287         return false;
288       }
289     }
290   }
291   if (!file->Close().ok()) {
292     LOG(ERROR) << "Failed to close file " << file_name;
293     return false;
294   }
295   return true;
296 }
297 
// Fills `data` with a unit-amplitude sine wave of `frequency` Hz, sampled at
// `sample_rate` Hz and starting at phase zero.
//
// The sample count is static_cast<int>(sample_rate * duration_seconds), i.e.
// truncated toward zero. If that count is not positive, `data` is left empty.
// Any previous contents of `data` are discarded.
void SineWave(int sample_rate, float frequency, float duration_seconds,
              std::vector<double>* data) {
  data->clear();
  // Compute the loop bound once instead of on every iteration.
  const int num_samples = static_cast<int>(sample_rate * duration_seconds);
  if (num_samples <= 0) {
    // Also keeps a negative count from reaching reserve() as a huge size_t.
    return;
  }
  data->reserve(static_cast<size_t>(num_samples));
  for (int i = 0; i < num_samples; ++i) {
    // The expression and its evaluation order are unchanged, so every sample
    // is bit-identical to the previous implementation.
    data->push_back(
        sin(2.0 * M_PI * i * frequency / static_cast<double>(sample_rate)));
  }
}
306 
307 }  // namespace tensorflow
308