• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Functions to write and read audio in WAV format.
17 
#include <math.h>
#include <string.h>
#include <algorithm>
#include <limits>
21 
22 #include "absl/base/casts.h"
23 #include "tensorflow/core/lib/core/coding.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/wav/wav_io.h"
26 #include "tensorflow/core/platform/byte_order.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/macros.h"
29 
30 namespace tensorflow {
31 namespace wav {
32 namespace {
33 
34 struct TF_PACKED RiffChunk {
35   char chunk_id[4];
36   char chunk_data_size[4];
37   char riff_type[4];
38 };
39 static_assert(sizeof(RiffChunk) == 12, "TF_PACKED does not work.");
40 
41 struct TF_PACKED FormatChunk {
42   char chunk_id[4];
43   char chunk_data_size[4];
44   char compression_code[2];
45   char channel_numbers[2];
46   char sample_rate[4];
47   char bytes_per_second[4];
48   char bytes_per_frame[2];
49   char bits_per_sample[2];
50 };
51 static_assert(sizeof(FormatChunk) == 24, "TF_PACKED does not work.");
52 
53 struct TF_PACKED DataChunk {
54   char chunk_id[4];
55   char chunk_data_size[4];
56 };
57 static_assert(sizeof(DataChunk) == 8, "TF_PACKED does not work.");
58 
59 struct TF_PACKED WavHeader {
60   RiffChunk riff_chunk;
61   FormatChunk format_chunk;
62   DataChunk data_chunk;
63 };
64 static_assert(sizeof(WavHeader) ==
65                   sizeof(RiffChunk) + sizeof(FormatChunk) + sizeof(DataChunk),
66               "TF_PACKED does not work.");
67 
// Four-character tags used in the WAV header. Each array also holds a
// trailing NUL, so code that copies a tag into a header field copies exactly
// 4 bytes.
constexpr char kRiffChunkId[] = "RIFF";    // Tag of the outer container.
constexpr char kRiffType[] = "WAVE";       // Form type inside the container.
constexpr char kFormatChunkId[] = "fmt ";  // Note the trailing space.
constexpr char kDataChunkId[] = "data";    // Tag of the sample payload.
72 
FloatToInt16Sample(float data)73 inline int16 FloatToInt16Sample(float data) {
74   constexpr float kMultiplier = 1.0f * (1 << 15);
75   return std::min<float>(std::max<float>(roundf(data * kMultiplier), kint16min),
76                          kint16max);
77 }
78 
Int16SampleToFloat(int16 data)79 inline float Int16SampleToFloat(int16 data) {
80   constexpr float kMultiplier = 1.0f / (1 << 15);
81   return data * kMultiplier;
82 }
83 
84 }  // namespace
85 
86 // Handles moving the data index forward, validating the arguments, and avoiding
87 // overflow or underflow.
IncrementOffset(int old_offset,size_t increment,size_t max_size,int * new_offset)88 Status IncrementOffset(int old_offset, size_t increment, size_t max_size,
89                        int* new_offset) {
90   if (old_offset < 0) {
91     return errors::InvalidArgument("Negative offsets are not allowed: ",
92                                    old_offset);
93   }
94   if (old_offset > max_size) {
95     return errors::InvalidArgument("Initial offset is outside data range: ",
96                                    old_offset);
97   }
98   *new_offset = old_offset + increment;
99   if (*new_offset > max_size) {
100     return errors::InvalidArgument("Data too short when trying to read string");
101   }
102   // See above for the check that the input offset is positive. If it's negative
103   // here then it means that there's been an overflow in the arithmetic.
104   if (*new_offset < 0) {
105     return errors::InvalidArgument("Offset too large, overflowed: ",
106                                    *new_offset);
107   }
108   return Status::OK();
109 }
110 
ExpectText(const string & data,const string & expected_text,int * offset)111 Status ExpectText(const string& data, const string& expected_text,
112                   int* offset) {
113   int new_offset;
114   TF_RETURN_IF_ERROR(
115       IncrementOffset(*offset, expected_text.size(), data.size(), &new_offset));
116   const string found_text(data.begin() + *offset, data.begin() + new_offset);
117   if (found_text != expected_text) {
118     return errors::InvalidArgument("Header mismatch: Expected ", expected_text,
119                                    " but found ", found_text);
120   }
121   *offset = new_offset;
122   return Status::OK();
123 }
124 
ReadString(const string & data,int expected_length,string * value,int * offset)125 Status ReadString(const string& data, int expected_length, string* value,
126                   int* offset) {
127   int new_offset;
128   TF_RETURN_IF_ERROR(
129       IncrementOffset(*offset, expected_length, data.size(), &new_offset));
130   *value = string(data.begin() + *offset, data.begin() + new_offset);
131   *offset = new_offset;
132   return Status::OK();
133 }
134 
135 template <typename T>
EncodeAudioAsS16LEWav(const float * audio,size_t sample_rate,size_t num_channels,size_t num_frames,T * wav_string)136 Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate,
137                              size_t num_channels, size_t num_frames,
138                              T* wav_string) {
139   constexpr size_t kFormatChunkSize = 16;
140   constexpr size_t kCompressionCodePcm = 1;
141   constexpr size_t kBitsPerSample = 16;
142   constexpr size_t kBytesPerSample = kBitsPerSample / 8;
143   constexpr size_t kHeaderSize = sizeof(WavHeader);
144 
145   if (audio == nullptr) {
146     return errors::InvalidArgument("audio is null");
147   }
148   if (wav_string == nullptr) {
149     return errors::InvalidArgument("wav_string is null");
150   }
151   if (sample_rate == 0 || sample_rate > kuint32max) {
152     return errors::InvalidArgument("sample_rate must be in (0, 2^32), got: ",
153                                    sample_rate);
154   }
155   if (num_channels == 0 || num_channels > kuint16max) {
156     return errors::InvalidArgument("num_channels must be in (0, 2^16), got: ",
157                                    num_channels);
158   }
159   if (num_frames == 0) {
160     return errors::InvalidArgument("num_frames must be positive.");
161   }
162 
163   const size_t bytes_per_second = sample_rate * kBytesPerSample * num_channels;
164   const size_t num_samples = num_frames * num_channels;
165   const size_t data_size = num_samples * kBytesPerSample;
166   const size_t file_size = kHeaderSize + num_samples * kBytesPerSample;
167   const size_t bytes_per_frame = kBytesPerSample * num_channels;
168 
169   // WAV represents the length of the file as a uint32 so file_size cannot
170   // exceed kuint32max.
171   if (file_size > kuint32max) {
172     return errors::InvalidArgument(
173         "Provided channels and frames cannot be encoded as a WAV.");
174   }
175 
176   wav_string->resize(file_size);
177   char* data = &(*wav_string)[0];
178   WavHeader* header = absl::bit_cast<WavHeader*>(data);
179 
180   // Fill RIFF chunk.
181   auto* riff_chunk = &header->riff_chunk;
182   memcpy(riff_chunk->chunk_id, kRiffChunkId, 4);
183   core::EncodeFixed32(riff_chunk->chunk_data_size, file_size - 8);
184   memcpy(riff_chunk->riff_type, kRiffType, 4);
185 
186   // Fill format chunk.
187   auto* format_chunk = &header->format_chunk;
188   memcpy(format_chunk->chunk_id, kFormatChunkId, 4);
189   core::EncodeFixed32(format_chunk->chunk_data_size, kFormatChunkSize);
190   core::EncodeFixed16(format_chunk->compression_code, kCompressionCodePcm);
191   core::EncodeFixed16(format_chunk->channel_numbers, num_channels);
192   core::EncodeFixed32(format_chunk->sample_rate, sample_rate);
193   core::EncodeFixed32(format_chunk->bytes_per_second, bytes_per_second);
194   core::EncodeFixed16(format_chunk->bytes_per_frame, bytes_per_frame);
195   core::EncodeFixed16(format_chunk->bits_per_sample, kBitsPerSample);
196 
197   // Fill data chunk.
198   auto* data_chunk = &header->data_chunk;
199   memcpy(data_chunk->chunk_id, kDataChunkId, 4);
200   core::EncodeFixed32(data_chunk->chunk_data_size, data_size);
201 
202   // Write the audio.
203   data += kHeaderSize;
204   for (size_t i = 0; i < num_samples; ++i) {
205     int16 sample = FloatToInt16Sample(audio[i]);
206     core::EncodeFixed16(&data[i * kBytesPerSample],
207                         static_cast<uint16>(sample));
208   }
209   return Status::OK();
210 }
211 
212 template Status EncodeAudioAsS16LEWav<string>(const float* audio,
213                                               size_t sample_rate,
214                                               size_t num_channels,
215                                               size_t num_frames,
216                                               string* wav_string);
217 template Status EncodeAudioAsS16LEWav<tstring>(const float* audio,
218                                                size_t sample_rate,
219                                                size_t num_channels,
220                                                size_t num_frames,
221                                                tstring* wav_string);
222 
DecodeLin16WaveAsFloatVector(const string & wav_string,std::vector<float> * float_values,uint32 * sample_count,uint16 * channel_count,uint32 * sample_rate)223 Status DecodeLin16WaveAsFloatVector(const string& wav_string,
224                                     std::vector<float>* float_values,
225                                     uint32* sample_count, uint16* channel_count,
226                                     uint32* sample_rate) {
227   int offset = 0;
228   TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffChunkId, &offset));
229   uint32 total_file_size;
230   TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &total_file_size, &offset));
231   TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffType, &offset));
232   TF_RETURN_IF_ERROR(ExpectText(wav_string, kFormatChunkId, &offset));
233   uint32 format_chunk_size;
234   TF_RETURN_IF_ERROR(
235       ReadValue<uint32>(wav_string, &format_chunk_size, &offset));
236   if ((format_chunk_size != 16) && (format_chunk_size != 18)) {
237     return errors::InvalidArgument(
238         "Bad format chunk size for WAV: Expected 16 or 18, but got",
239         format_chunk_size);
240   }
241   uint16 audio_format;
242   TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &audio_format, &offset));
243   if (audio_format != 1) {
244     return errors::InvalidArgument(
245         "Bad audio format for WAV: Expected 1 (PCM), but got", audio_format);
246   }
247   TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, channel_count, &offset));
248   if (*channel_count < 1) {
249     return errors::InvalidArgument(
250         "Bad number of channels for WAV: Expected at least 1, but got ",
251         *channel_count);
252   }
253   TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, sample_rate, &offset));
254   uint32 bytes_per_second;
255   TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &bytes_per_second, &offset));
256   uint16 bytes_per_sample;
257   TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &bytes_per_sample, &offset));
258   // Confusingly, bits per sample is defined as holding the number of bits for
259   // one channel, unlike the definition of sample used elsewhere in the WAV
260   // spec. For example, bytes per sample is the memory needed for all channels
261   // for one point in time.
262   uint16 bits_per_sample;
263   TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &bits_per_sample, &offset));
264   if (bits_per_sample != 16) {
265     return errors::InvalidArgument(
266         "Can only read 16-bit WAV files, but received ", bits_per_sample);
267   }
268   const uint32 expected_bytes_per_sample =
269       ((bits_per_sample * *channel_count) + 7) / 8;
270   if (bytes_per_sample != expected_bytes_per_sample) {
271     return errors::InvalidArgument(
272         "Bad bytes per sample in WAV header: Expected ",
273         expected_bytes_per_sample, " but got ", bytes_per_sample);
274   }
275   const uint32 expected_bytes_per_second = bytes_per_sample * *sample_rate;
276   if (bytes_per_second != expected_bytes_per_second) {
277     return errors::InvalidArgument(
278         "Bad bytes per second in WAV header: Expected ",
279         expected_bytes_per_second, " but got ", bytes_per_second,
280         " (sample_rate=", *sample_rate, ", bytes_per_sample=", bytes_per_sample,
281         ")");
282   }
283   if (format_chunk_size == 18) {
284     // Skip over this unused section.
285     offset += 2;
286   }
287 
288   bool was_data_found = false;
289   while (offset < wav_string.size()) {
290     string chunk_id;
291     TF_RETURN_IF_ERROR(ReadString(wav_string, 4, &chunk_id, &offset));
292     uint32 chunk_size;
293     TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &chunk_size, &offset));
294     if (chunk_size > std::numeric_limits<int32>::max()) {
295       return errors::InvalidArgument(
296           "WAV data chunk '", chunk_id, "' is too large: ", chunk_size,
297           " bytes, but the limit is ", std::numeric_limits<int32>::max());
298     }
299     if (chunk_id == kDataChunkId) {
300       if (was_data_found) {
301         return errors::InvalidArgument("More than one data chunk found in WAV");
302       }
303       was_data_found = true;
304       *sample_count = chunk_size / bytes_per_sample;
305       const uint32 data_count = *sample_count * *channel_count;
306       int unused_new_offset = 0;
307       // Validate that the data exists before allocating space for it
308       // (prevent easy OOM errors).
309       TF_RETURN_IF_ERROR(IncrementOffset(offset, sizeof(int16) * data_count,
310                                          wav_string.size(),
311                                          &unused_new_offset));
312       float_values->resize(data_count);
313       for (int i = 0; i < data_count; ++i) {
314         int16 single_channel_value = 0;
315         TF_RETURN_IF_ERROR(
316             ReadValue<int16>(wav_string, &single_channel_value, &offset));
317         (*float_values)[i] = Int16SampleToFloat(single_channel_value);
318       }
319     } else {
320       offset += chunk_size;
321     }
322   }
323   if (!was_data_found) {
324     return errors::InvalidArgument("No data chunk found in WAV");
325   }
326   return Status::OK();
327 }
328 
329 }  // namespace wav
330 }  // namespace tensorflow
331