• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// Functions to read and write audio in WAV format.
17 
#include "tensorflow/core/lib/wav/wav_io.h"

#include <math.h>
#include <string.h>

#include <algorithm>
#include <cmath>
#include <limits>

#include "absl/base/casts.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
29 
30 namespace tensorflow {
31 namespace wav {
32 namespace {
33 
34 struct TF_PACKED RiffChunk {
35   char chunk_id[4];
36   char chunk_data_size[4];
37   char riff_type[4];
38 };
39 static_assert(sizeof(RiffChunk) == 12, "TF_PACKED does not work.");
40 
41 struct TF_PACKED FormatChunk {
42   char chunk_id[4];
43   char chunk_data_size[4];
44   char compression_code[2];
45   char channel_numbers[2];
46   char sample_rate[4];
47   char bytes_per_second[4];
48   char bytes_per_frame[2];
49   char bits_per_sample[2];
50 };
51 static_assert(sizeof(FormatChunk) == 24, "TF_PACKED does not work.");
52 
53 struct TF_PACKED DataChunk {
54   char chunk_id[4];
55   char chunk_data_size[4];
56 };
57 static_assert(sizeof(DataChunk) == 8, "TF_PACKED does not work.");
58 
59 struct TF_PACKED WavHeader {
60   RiffChunk riff_chunk;
61   FormatChunk format_chunk;
62   DataChunk data_chunk;
63 };
64 static_assert(sizeof(WavHeader) ==
65                   sizeof(RiffChunk) + sizeof(FormatChunk) + sizeof(DataChunk),
66               "TF_PACKED does not work.");
67 
// Four-character chunk tags used by the RIFF/WAVE format. Each array also
// carries a trailing NUL, so writers copy only the first 4 bytes.
constexpr char kRiffChunkId[5] = "RIFF";
constexpr char kRiffType[5] = "WAVE";
constexpr char kFormatChunkId[5] = "fmt ";
constexpr char kDataChunkId[5] = "data";
72 
FloatToInt16Sample(float data)73 inline int16 FloatToInt16Sample(float data) {
74   constexpr float kMultiplier = 1.0f * (1 << 15);
75   return std::min<float>(std::max<float>(roundf(data * kMultiplier), kint16min),
76                          kint16max);
77 }
78 
// Converts a signed 16-bit PCM sample to float by scaling with 2^-15, so that
// -32768 maps to exactly -1.0f and 32767 to just below +1.0f.
inline float Int16SampleToFloat(int16_t data) {
  constexpr float kScale = 1.0f / 32768.0f;
  return static_cast<float>(data) * kScale;
}
83 
84 }  // namespace
85 
86 // Handles moving the data index forward, validating the arguments, and avoiding
87 // overflow or underflow.
IncrementOffset(int old_offset,size_t increment,size_t max_size,int * new_offset)88 Status IncrementOffset(int old_offset, size_t increment, size_t max_size,
89                        int* new_offset) {
90   if (old_offset < 0) {
91     return errors::InvalidArgument("Negative offsets are not allowed: ",
92                                    old_offset);
93   }
94   if (old_offset > max_size) {
95     return errors::InvalidArgument("Initial offset is outside data range: ",
96                                    old_offset);
97   }
98   *new_offset = old_offset + increment;
99   if (*new_offset > max_size) {
100     return errors::InvalidArgument("Data too short when trying to read string");
101   }
102   // See above for the check that the input offset is positive. If it's negative
103   // here then it means that there's been an overflow in the arithmetic.
104   if (*new_offset < 0) {
105     return errors::InvalidArgument("Offset too large, overflowed: ",
106                                    *new_offset);
107   }
108   return Status::OK();
109 }
110 
ExpectText(const std::string & data,const std::string & expected_text,int * offset)111 Status ExpectText(const std::string& data, const std::string& expected_text,
112                   int* offset) {
113   int new_offset;
114   TF_RETURN_IF_ERROR(
115       IncrementOffset(*offset, expected_text.size(), data.size(), &new_offset));
116   const std::string found_text(data.begin() + *offset,
117                                data.begin() + new_offset);
118   if (found_text != expected_text) {
119     return errors::InvalidArgument("Header mismatch: Expected ", expected_text,
120                                    " but found ", found_text);
121   }
122   *offset = new_offset;
123   return Status::OK();
124 }
125 
ReadString(const std::string & data,int expected_length,std::string * value,int * offset)126 Status ReadString(const std::string& data, int expected_length,
127                   std::string* value, int* offset) {
128   int new_offset;
129   TF_RETURN_IF_ERROR(
130       IncrementOffset(*offset, expected_length, data.size(), &new_offset));
131   *value = std::string(data.begin() + *offset, data.begin() + new_offset);
132   *offset = new_offset;
133   return Status::OK();
134 }
135 
136 template <typename T>
EncodeAudioAsS16LEWav(const float * audio,size_t sample_rate,size_t num_channels,size_t num_frames,T * wav_string)137 Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate,
138                              size_t num_channels, size_t num_frames,
139                              T* wav_string) {
140   constexpr size_t kFormatChunkSize = 16;
141   constexpr size_t kCompressionCodePcm = 1;
142   constexpr size_t kBitsPerSample = 16;
143   constexpr size_t kBytesPerSample = kBitsPerSample / 8;
144   constexpr size_t kHeaderSize = sizeof(WavHeader);
145 
146   // If num_frames is zero, audio can be nullptr.
147   if (audio == nullptr && num_frames > 0) {
148     return errors::InvalidArgument("audio is null");
149   }
150   if (wav_string == nullptr) {
151     return errors::InvalidArgument("wav_string is null");
152   }
153   if (sample_rate == 0 || sample_rate > kuint32max) {
154     return errors::InvalidArgument("sample_rate must be in (0, 2^32), got: ",
155                                    sample_rate);
156   }
157   if (num_channels == 0 || num_channels > kuint16max) {
158     return errors::InvalidArgument("num_channels must be in (0, 2^16), got: ",
159                                    num_channels);
160   }
161 
162   const size_t bytes_per_second = sample_rate * kBytesPerSample * num_channels;
163   const size_t num_samples = num_frames * num_channels;
164   const size_t data_size = num_samples * kBytesPerSample;
165   const size_t file_size = kHeaderSize + num_samples * kBytesPerSample;
166   const size_t bytes_per_frame = kBytesPerSample * num_channels;
167 
168   // WAV represents the length of the file as a uint32 so file_size cannot
169   // exceed kuint32max.
170   if (file_size > kuint32max) {
171     return errors::InvalidArgument(
172         "Provided channels and frames cannot be encoded as a WAV.");
173   }
174 
175   wav_string->resize(file_size);
176   char* data = &(*wav_string)[0];
177   WavHeader* header = absl::bit_cast<WavHeader*>(data);
178 
179   // Fill RIFF chunk.
180   auto* riff_chunk = &header->riff_chunk;
181   memcpy(riff_chunk->chunk_id, kRiffChunkId, 4);
182   core::EncodeFixed32(riff_chunk->chunk_data_size, file_size - 8);
183   memcpy(riff_chunk->riff_type, kRiffType, 4);
184 
185   // Fill format chunk.
186   auto* format_chunk = &header->format_chunk;
187   memcpy(format_chunk->chunk_id, kFormatChunkId, 4);
188   core::EncodeFixed32(format_chunk->chunk_data_size, kFormatChunkSize);
189   core::EncodeFixed16(format_chunk->compression_code, kCompressionCodePcm);
190   core::EncodeFixed16(format_chunk->channel_numbers, num_channels);
191   core::EncodeFixed32(format_chunk->sample_rate, sample_rate);
192   core::EncodeFixed32(format_chunk->bytes_per_second, bytes_per_second);
193   core::EncodeFixed16(format_chunk->bytes_per_frame, bytes_per_frame);
194   core::EncodeFixed16(format_chunk->bits_per_sample, kBitsPerSample);
195 
196   // Fill data chunk.
197   auto* data_chunk = &header->data_chunk;
198   memcpy(data_chunk->chunk_id, kDataChunkId, 4);
199   core::EncodeFixed32(data_chunk->chunk_data_size, data_size);
200 
201   // Write the audio.
202   data += kHeaderSize;
203   for (size_t i = 0; i < num_samples; ++i) {
204     int16_t sample = FloatToInt16Sample(audio[i]);
205     core::EncodeFixed16(&data[i * kBytesPerSample],
206                         static_cast<uint16>(sample));
207   }
208   return Status::OK();
209 }
210 
211 template Status EncodeAudioAsS16LEWav<std::string>(const float* audio,
212                                                    size_t sample_rate,
213                                                    size_t num_channels,
214                                                    size_t num_frames,
215                                                    std::string* wav_string);
216 template Status EncodeAudioAsS16LEWav<tstring>(const float* audio,
217                                                size_t sample_rate,
218                                                size_t num_channels,
219                                                size_t num_frames,
220                                                tstring* wav_string);
221 
DecodeLin16WaveAsFloatVector(const std::string & wav_string,std::vector<float> * float_values,uint32 * sample_count,uint16 * channel_count,uint32 * sample_rate)222 Status DecodeLin16WaveAsFloatVector(const std::string& wav_string,
223                                     std::vector<float>* float_values,
224                                     uint32* sample_count, uint16* channel_count,
225                                     uint32* sample_rate) {
226   int offset = 0;
227   TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffChunkId, &offset));
228   uint32 total_file_size;
229   TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &total_file_size, &offset));
230   TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffType, &offset));
231   TF_RETURN_IF_ERROR(ExpectText(wav_string, kFormatChunkId, &offset));
232   uint32 format_chunk_size;
233   TF_RETURN_IF_ERROR(
234       ReadValue<uint32>(wav_string, &format_chunk_size, &offset));
235   if ((format_chunk_size != 16) && (format_chunk_size != 18)) {
236     return errors::InvalidArgument(
237         "Bad format chunk size for WAV: Expected 16 or 18, but got",
238         format_chunk_size);
239   }
240   uint16 audio_format;
241   TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &audio_format, &offset));
242   if (audio_format != 1) {
243     return errors::InvalidArgument(
244         "Bad audio format for WAV: Expected 1 (PCM), but got", audio_format);
245   }
246   TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, channel_count, &offset));
247   if (*channel_count < 1) {
248     return errors::InvalidArgument(
249         "Bad number of channels for WAV: Expected at least 1, but got ",
250         *channel_count);
251   }
252   TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, sample_rate, &offset));
253   uint32 bytes_per_second;
254   TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &bytes_per_second, &offset));
255   uint16 bytes_per_sample;
256   TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &bytes_per_sample, &offset));
257   // Confusingly, bits per sample is defined as holding the number of bits for
258   // one channel, unlike the definition of sample used elsewhere in the WAV
259   // spec. For example, bytes per sample is the memory needed for all channels
260   // for one point in time.
261   uint16 bits_per_sample;
262   TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &bits_per_sample, &offset));
263   if (bits_per_sample != 16) {
264     return errors::InvalidArgument(
265         "Can only read 16-bit WAV files, but received ", bits_per_sample);
266   }
267   const uint32 expected_bytes_per_sample =
268       ((bits_per_sample * *channel_count) + 7) / 8;
269   if (bytes_per_sample != expected_bytes_per_sample) {
270     return errors::InvalidArgument(
271         "Bad bytes per sample in WAV header: Expected ",
272         expected_bytes_per_sample, " but got ", bytes_per_sample);
273   }
274   const uint32 expected_bytes_per_second = bytes_per_sample * *sample_rate;
275   if (bytes_per_second != expected_bytes_per_second) {
276     return errors::InvalidArgument(
277         "Bad bytes per second in WAV header: Expected ",
278         expected_bytes_per_second, " but got ", bytes_per_second,
279         " (sample_rate=", *sample_rate, ", bytes_per_sample=", bytes_per_sample,
280         ")");
281   }
282   if (format_chunk_size == 18) {
283     // Skip over this unused section.
284     offset += 2;
285   }
286 
287   bool was_data_found = false;
288   while (offset < wav_string.size()) {
289     std::string chunk_id;
290     TF_RETURN_IF_ERROR(ReadString(wav_string, 4, &chunk_id, &offset));
291     uint32 chunk_size;
292     TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &chunk_size, &offset));
293     if (chunk_size > std::numeric_limits<int32>::max()) {
294       return errors::InvalidArgument(
295           "WAV data chunk '", chunk_id, "' is too large: ", chunk_size,
296           " bytes, but the limit is ", std::numeric_limits<int32>::max());
297     }
298     if (chunk_id == kDataChunkId) {
299       if (was_data_found) {
300         return errors::InvalidArgument("More than one data chunk found in WAV");
301       }
302       was_data_found = true;
303       *sample_count = chunk_size / bytes_per_sample;
304       const uint32 data_count = *sample_count * *channel_count;
305       int unused_new_offset = 0;
306       // Validate that the data exists before allocating space for it
307       // (prevent easy OOM errors).
308       TF_RETURN_IF_ERROR(IncrementOffset(offset, sizeof(int16) * data_count,
309                                          wav_string.size(),
310                                          &unused_new_offset));
311       float_values->resize(data_count);
312       for (int i = 0; i < data_count; ++i) {
313         int16_t single_channel_value = 0;
314         TF_RETURN_IF_ERROR(
315             ReadValue<int16>(wav_string, &single_channel_value, &offset));
316         (*float_values)[i] = Int16SampleToFloat(single_channel_value);
317       }
318     } else {
319       offset += chunk_size;
320     }
321   }
322   if (!was_data_found) {
323     return errors::InvalidArgument("No data chunk found in WAV");
324   }
325   return Status::OK();
326 }
327 
328 }  // namespace wav
329 }  // namespace tensorflow
330