• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// Functions to read and write audio in WAV format.
17 
#include <math.h>
#include <string.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

#include "absl/base/casts.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/wav/wav_io.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
29 
30 namespace tensorflow {
31 namespace wav {
32 namespace {
33 
34 struct TF_PACKED RiffChunk {
35   char chunk_id[4];
36   char chunk_data_size[4];
37   char riff_type[4];
38 };
39 static_assert(sizeof(RiffChunk) == 12, "TF_PACKED does not work.");
40 
41 struct TF_PACKED FormatChunk {
42   char chunk_id[4];
43   char chunk_data_size[4];
44   char compression_code[2];
45   char channel_numbers[2];
46   char sample_rate[4];
47   char bytes_per_second[4];
48   char bytes_per_frame[2];
49   char bits_per_sample[2];
50 };
51 static_assert(sizeof(FormatChunk) == 24, "TF_PACKED does not work.");
52 
53 struct TF_PACKED DataChunk {
54   char chunk_id[4];
55   char chunk_data_size[4];
56 };
57 static_assert(sizeof(DataChunk) == 8, "TF_PACKED does not work.");
58 
59 struct TF_PACKED WavHeader {
60   RiffChunk riff_chunk;
61   FormatChunk format_chunk;
62   DataChunk data_chunk;
63 };
64 static_assert(sizeof(WavHeader) ==
65                   sizeof(RiffChunk) + sizeof(FormatChunk) + sizeof(DataChunk),
66               "TF_PACKED does not work.");
67 
// Four-character identifiers used in the WAV header. Each array holds the four
// characters followed by a terminating NUL, so they can also be used as C
// strings.
constexpr char kRiffChunkId[5] = {'R', 'I', 'F', 'F', '\0'};
constexpr char kRiffType[5] = {'W', 'A', 'V', 'E', '\0'};
constexpr char kFormatChunkId[5] = {'f', 'm', 't', ' ', '\0'};
constexpr char kDataChunkId[5] = {'d', 'a', 't', 'a', '\0'};
72 
// Converts a float sample (expected to be roughly in [-1.0, 1.0]) to a signed
// 16-bit PCM sample. It scales by 2^15, rounds to the nearest integer (halfway
// cases away from zero) and saturates to [-32768, 32767], so out-of-range
// values and infinities are clipped.
//
// NaN is mapped to 0 (silence). Every comparison with NaN is false, so NaN
// would pass straight through std::min/std::max, and converting NaN to an
// integer type is undefined behavior.
inline int16_t FloatToInt16Sample(float data) {
  constexpr float kMultiplier = 1.0f * (1 << 15);
  if (std::isnan(data)) {
    return 0;
  }
  constexpr float kMin = std::numeric_limits<int16_t>::min();
  constexpr float kMax = std::numeric_limits<int16_t>::max();
  const float scaled = std::round(data * kMultiplier);
  return static_cast<int16_t>(std::min(std::max(scaled, kMin), kMax));
}
78 
// Converts a signed 16-bit PCM sample to a float by dividing by 2^15. -32768
// maps to exactly -1.0f and 32767 maps to 32767/32768, just below 1.0f.
// Dividing by a power of two is exact, so this matches multiplying by 2^-15.
inline float Int16SampleToFloat(int16_t data) {
  constexpr float kFullScale = static_cast<float>(1 << 15);
  return static_cast<float>(data) / kFullScale;
}
83 
84 }  // namespace
85 
86 // Handles moving the data index forward, validating the arguments, and avoiding
87 // overflow or underflow.
IncrementOffset(int old_offset,int64_t increment,size_t max_size,int * new_offset)88 Status IncrementOffset(int old_offset, int64_t increment, size_t max_size,
89                        int* new_offset) {
90   if (old_offset < 0) {
91     return errors::InvalidArgument("Negative offsets are not allowed: ",
92                                    old_offset);
93   }
94   if (increment < 0) {
95     return errors::InvalidArgument("Negative increment is not allowed: ",
96                                    increment);
97   }
98   if (old_offset > max_size) {
99     return errors::InvalidArgument("Initial offset is outside data range: ",
100                                    old_offset);
101   }
102   *new_offset = old_offset + increment;
103   if (*new_offset > max_size) {
104     return errors::InvalidArgument("Data too short when trying to read string");
105   }
106   // See above for the check that the input offset is positive. If it's negative
107   // here then it means that there's been an overflow in the arithmetic.
108   if (*new_offset < 0) {
109     return errors::InvalidArgument("Offset too large, overflowed: ",
110                                    *new_offset);
111   }
112   return OkStatus();
113 }
114 
ExpectText(const std::string & data,const std::string & expected_text,int * offset)115 Status ExpectText(const std::string& data, const std::string& expected_text,
116                   int* offset) {
117   int new_offset;
118   TF_RETURN_IF_ERROR(
119       IncrementOffset(*offset, expected_text.size(), data.size(), &new_offset));
120   const std::string found_text(data.begin() + *offset,
121                                data.begin() + new_offset);
122   if (found_text != expected_text) {
123     return errors::InvalidArgument("Header mismatch: Expected ", expected_text,
124                                    " but found ", found_text);
125   }
126   *offset = new_offset;
127   return OkStatus();
128 }
129 
ReadString(const std::string & data,int expected_length,std::string * value,int * offset)130 Status ReadString(const std::string& data, int expected_length,
131                   std::string* value, int* offset) {
132   int new_offset;
133   TF_RETURN_IF_ERROR(
134       IncrementOffset(*offset, expected_length, data.size(), &new_offset));
135   *value = std::string(data.begin() + *offset, data.begin() + new_offset);
136   *offset = new_offset;
137   return OkStatus();
138 }
139 
140 template <typename T>
EncodeAudioAsS16LEWav(const float * audio,size_t sample_rate,size_t num_channels,size_t num_frames,T * wav_string)141 Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate,
142                              size_t num_channels, size_t num_frames,
143                              T* wav_string) {
144   constexpr size_t kFormatChunkSize = 16;
145   constexpr size_t kCompressionCodePcm = 1;
146   constexpr size_t kBitsPerSample = 16;
147   constexpr size_t kBytesPerSample = kBitsPerSample / 8;
148   constexpr size_t kHeaderSize = sizeof(WavHeader);
149 
150   // If num_frames is zero, audio can be nullptr.
151   if (audio == nullptr && num_frames > 0) {
152     return errors::InvalidArgument("audio is null");
153   }
154   if (wav_string == nullptr) {
155     return errors::InvalidArgument("wav_string is null");
156   }
157   if (sample_rate == 0 || sample_rate > kuint32max) {
158     return errors::InvalidArgument("sample_rate must be in (0, 2^32), got: ",
159                                    sample_rate);
160   }
161   if (num_channels == 0 || num_channels > kuint16max) {
162     return errors::InvalidArgument("num_channels must be in (0, 2^16), got: ",
163                                    num_channels);
164   }
165 
166   const size_t bytes_per_second = sample_rate * kBytesPerSample * num_channels;
167   const size_t num_samples = num_frames * num_channels;
168   const size_t data_size = num_samples * kBytesPerSample;
169   const size_t file_size = kHeaderSize + num_samples * kBytesPerSample;
170   const size_t bytes_per_frame = kBytesPerSample * num_channels;
171 
172   // WAV represents the length of the file as a uint32 so file_size cannot
173   // exceed kuint32max.
174   if (file_size > kuint32max) {
175     return errors::InvalidArgument(
176         "Provided channels and frames cannot be encoded as a WAV.");
177   }
178 
179   wav_string->resize(file_size);
180   char* data = &(*wav_string)[0];
181   WavHeader* header = absl::bit_cast<WavHeader*>(data);
182 
183   // Fill RIFF chunk.
184   auto* riff_chunk = &header->riff_chunk;
185   memcpy(riff_chunk->chunk_id, kRiffChunkId, 4);
186   core::EncodeFixed32(riff_chunk->chunk_data_size, file_size - 8);
187   memcpy(riff_chunk->riff_type, kRiffType, 4);
188 
189   // Fill format chunk.
190   auto* format_chunk = &header->format_chunk;
191   memcpy(format_chunk->chunk_id, kFormatChunkId, 4);
192   core::EncodeFixed32(format_chunk->chunk_data_size, kFormatChunkSize);
193   core::EncodeFixed16(format_chunk->compression_code, kCompressionCodePcm);
194   core::EncodeFixed16(format_chunk->channel_numbers, num_channels);
195   core::EncodeFixed32(format_chunk->sample_rate, sample_rate);
196   core::EncodeFixed32(format_chunk->bytes_per_second, bytes_per_second);
197   core::EncodeFixed16(format_chunk->bytes_per_frame, bytes_per_frame);
198   core::EncodeFixed16(format_chunk->bits_per_sample, kBitsPerSample);
199 
200   // Fill data chunk.
201   auto* data_chunk = &header->data_chunk;
202   memcpy(data_chunk->chunk_id, kDataChunkId, 4);
203   core::EncodeFixed32(data_chunk->chunk_data_size, data_size);
204 
205   // Write the audio.
206   data += kHeaderSize;
207   for (size_t i = 0; i < num_samples; ++i) {
208     int16_t sample = FloatToInt16Sample(audio[i]);
209     core::EncodeFixed16(&data[i * kBytesPerSample],
210                         static_cast<uint16>(sample));
211   }
212   return OkStatus();
213 }
214 
215 template Status EncodeAudioAsS16LEWav<std::string>(const float* audio,
216                                                    size_t sample_rate,
217                                                    size_t num_channels,
218                                                    size_t num_frames,
219                                                    std::string* wav_string);
220 template Status EncodeAudioAsS16LEWav<tstring>(const float* audio,
221                                                size_t sample_rate,
222                                                size_t num_channels,
223                                                size_t num_frames,
224                                                tstring* wav_string);
225 
DecodeLin16WaveAsFloatVector(const std::string & wav_string,std::vector<float> * float_values,uint32 * sample_count,uint16 * channel_count,uint32 * sample_rate)226 Status DecodeLin16WaveAsFloatVector(const std::string& wav_string,
227                                     std::vector<float>* float_values,
228                                     uint32* sample_count, uint16* channel_count,
229                                     uint32* sample_rate) {
230   int offset = 0;
231   TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffChunkId, &offset));
232   uint32 total_file_size;
233   TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &total_file_size, &offset));
234   TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffType, &offset));
235   std::string found_text;
236   TF_RETURN_IF_ERROR(ReadString(wav_string, 4, &found_text, &offset));
237   while (found_text != kFormatChunkId) {
238     // Padding chunk may occur between "WAVE" and "fmt ".
239     // Skip JUNK/bext/etc field to support for WAV file with either JUNK Chunk,
240     // or broadcast WAV where additional tags might appear.
241     // Reference: the implementation of tfio in audio_video_wav_kernels.cc,
242     //            https://www.daubnet.com/en/file-format-riff,
243     //            https://en.wikipedia.org/wiki/Broadcast_Wave_Format
244     if (found_text != "JUNK" && found_text != "bext" && found_text != "iXML" &&
245         found_text != "qlty" && found_text != "mext" && found_text != "levl" &&
246         found_text != "link" && found_text != "axml") {
247       return errors::InvalidArgument("Unexpected field ", found_text);
248     }
249     uint32 size_of_chunk;
250     TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &size_of_chunk, &offset));
251     TF_RETURN_IF_ERROR(
252         IncrementOffset(offset, size_of_chunk, wav_string.size(), &offset));
253     TF_RETURN_IF_ERROR(ReadString(wav_string, 4, &found_text, &offset));
254   }
255   uint32 format_chunk_size;
256   TF_RETURN_IF_ERROR(
257       ReadValue<uint32>(wav_string, &format_chunk_size, &offset));
258   if ((format_chunk_size != 16) && (format_chunk_size != 18)) {
259     return errors::InvalidArgument(
260         "Bad format chunk size for WAV: Expected 16 or 18, but got",
261         format_chunk_size);
262   }
263   uint16 audio_format;
264   TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &audio_format, &offset));
265   if (audio_format != 1) {
266     return errors::InvalidArgument(
267         "Bad audio format for WAV: Expected 1 (PCM), but got", audio_format);
268   }
269   TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, channel_count, &offset));
270   if (*channel_count < 1) {
271     return errors::InvalidArgument(
272         "Bad number of channels for WAV: Expected at least 1, but got ",
273         *channel_count);
274   }
275   TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, sample_rate, &offset));
276   uint32 bytes_per_second;
277   TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &bytes_per_second, &offset));
278   uint16 bytes_per_sample;
279   TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &bytes_per_sample, &offset));
280   // Confusingly, bits per sample is defined as holding the number of bits for
281   // one channel, unlike the definition of sample used elsewhere in the WAV
282   // spec. For example, bytes per sample is the memory needed for all channels
283   // for one point in time.
284   uint16 bits_per_sample;
285   TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &bits_per_sample, &offset));
286   if (bits_per_sample != 16) {
287     return errors::InvalidArgument(
288         "Can only read 16-bit WAV files, but received ", bits_per_sample);
289   }
290   const uint32 expected_bytes_per_sample =
291       ((bits_per_sample * *channel_count) + 7) / 8;
292   if (bytes_per_sample != expected_bytes_per_sample) {
293     return errors::InvalidArgument(
294         "Bad bytes per sample in WAV header: Expected ",
295         expected_bytes_per_sample, " but got ", bytes_per_sample);
296   }
297   const uint32 expected_bytes_per_second = bytes_per_sample * *sample_rate;
298   if (bytes_per_second != expected_bytes_per_second) {
299     return errors::InvalidArgument(
300         "Bad bytes per second in WAV header: Expected ",
301         expected_bytes_per_second, " but got ", bytes_per_second,
302         " (sample_rate=", *sample_rate, ", bytes_per_sample=", bytes_per_sample,
303         ")");
304   }
305   if (format_chunk_size == 18) {
306     // Skip over this unused section.
307     offset += 2;
308   }
309 
310   bool was_data_found = false;
311   while (offset < wav_string.size()) {
312     std::string chunk_id;
313     TF_RETURN_IF_ERROR(ReadString(wav_string, 4, &chunk_id, &offset));
314     uint32 chunk_size;
315     TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &chunk_size, &offset));
316     if (chunk_size > std::numeric_limits<int32>::max()) {
317       return errors::InvalidArgument(
318           "WAV data chunk '", chunk_id, "' is too large: ", chunk_size,
319           " bytes, but the limit is ", std::numeric_limits<int32>::max());
320     }
321     if (chunk_id == kDataChunkId) {
322       if (was_data_found) {
323         return errors::InvalidArgument("More than one data chunk found in WAV");
324       }
325       was_data_found = true;
326       *sample_count = chunk_size / bytes_per_sample;
327       const uint32 data_count = *sample_count * *channel_count;
328       int unused_new_offset = 0;
329       // Validate that the data exists before allocating space for it
330       // (prevent easy OOM errors).
331       TF_RETURN_IF_ERROR(IncrementOffset(offset, sizeof(int16) * data_count,
332                                          wav_string.size(),
333                                          &unused_new_offset));
334       float_values->resize(data_count);
335       for (int i = 0; i < data_count; ++i) {
336         int16_t single_channel_value = 0;
337         TF_RETURN_IF_ERROR(
338             ReadValue<int16>(wav_string, &single_channel_value, &offset));
339         (*float_values)[i] = Int16SampleToFloat(single_channel_value);
340       }
341     } else {
342       offset += chunk_size;
343     }
344   }
345   if (!was_data_found) {
346     return errors::InvalidArgument("No data chunk found in WAV");
347   }
348   return OkStatus();
349 }
350 
351 }  // namespace wav
352 }  // namespace tensorflow
353