1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// Functions to write (encode) and read (decode) audio in WAV format.
17
18 #include <math.h>
19 #include <string.h>
20 #include <algorithm>
21
22 #include "absl/base/casts.h"
23 #include "tensorflow/core/lib/core/coding.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/wav/wav_io.h"
26 #include "tensorflow/core/platform/byte_order.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/macros.h"
29
30 namespace tensorflow {
31 namespace wav {
32 namespace {
33
34 struct TF_PACKED RiffChunk {
35 char chunk_id[4];
36 char chunk_data_size[4];
37 char riff_type[4];
38 };
39 static_assert(sizeof(RiffChunk) == 12, "TF_PACKED does not work.");
40
41 struct TF_PACKED FormatChunk {
42 char chunk_id[4];
43 char chunk_data_size[4];
44 char compression_code[2];
45 char channel_numbers[2];
46 char sample_rate[4];
47 char bytes_per_second[4];
48 char bytes_per_frame[2];
49 char bits_per_sample[2];
50 };
51 static_assert(sizeof(FormatChunk) == 24, "TF_PACKED does not work.");
52
53 struct TF_PACKED DataChunk {
54 char chunk_id[4];
55 char chunk_data_size[4];
56 };
57 static_assert(sizeof(DataChunk) == 8, "TF_PACKED does not work.");
58
59 struct TF_PACKED WavHeader {
60 RiffChunk riff_chunk;
61 FormatChunk format_chunk;
62 DataChunk data_chunk;
63 };
64 static_assert(sizeof(WavHeader) ==
65 sizeof(RiffChunk) + sizeof(FormatChunk) + sizeof(DataChunk),
66 "TF_PACKED does not work.");
67
// Four-character tags identifying the pieces of a WAV file. The encoder copies
// exactly 4 bytes of each tag into the header, so the terminating NUL of the
// string literal is never written to the output.
constexpr char kRiffChunkId[] = "RIFF";
constexpr char kRiffType[] = "WAVE";
constexpr char kFormatChunkId[] = "fmt ";  // The trailing space is significant.
constexpr char kDataChunkId[] = "data";
static_assert(sizeof(kRiffChunkId) == 5 && sizeof(kRiffType) == 5 &&
                  sizeof(kFormatChunkId) == 5 && sizeof(kDataChunkId) == 5,
              "WAV chunk tags must be exactly four characters long.");
72
// Converts a float sample to a signed 16-bit PCM value.
//
// The input is scaled by 2^15 (so inputs in [-1.0, 1.0) span the whole int16
// range), rounded to the nearest integer and saturated to [-32768, 32767].
// Infinities therefore saturate as well. NaN is mapped to 0: a NaN passes
// through std::min/std::max unchanged, and converting it to an integer type
// would be undefined behavior.
inline int16_t FloatToInt16Sample(float data) {
  constexpr float kMultiplier = 1.0f * (1 << 15);
  constexpr float kMinSample = -32768.0f;
  constexpr float kMaxSample = 32767.0f;

  const float scaled = roundf(data * kMultiplier);
  // NaN is the only value that compares unequal to itself.
  if (scaled != scaled) {
    return 0;
  }
  // After clamping the value is integral and within range, so the cast is
  // exact.
  return static_cast<int16_t>(
      std::min(std::max(scaled, kMinSample), kMaxSample));
}
78
// Maps a signed 16-bit PCM sample onto a float by scaling it with 2^-15, which
// places the result in [-1.0, 1.0).
inline float Int16SampleToFloat(int16_t data) {
  constexpr float kInt16ToFloatScale = 1.0f / 32768.0f;
  const float sample = static_cast<float>(data);
  return sample * kInt16ToFloatScale;
}
83
84 } // namespace
85
86 // Handles moving the data index forward, validating the arguments, and avoiding
87 // overflow or underflow.
IncrementOffset(int old_offset,int64_t increment,size_t max_size,int * new_offset)88 Status IncrementOffset(int old_offset, int64_t increment, size_t max_size,
89 int* new_offset) {
90 if (old_offset < 0) {
91 return errors::InvalidArgument("Negative offsets are not allowed: ",
92 old_offset);
93 }
94 if (increment < 0) {
95 return errors::InvalidArgument("Negative increment is not allowed: ",
96 increment);
97 }
98 if (old_offset > max_size) {
99 return errors::InvalidArgument("Initial offset is outside data range: ",
100 old_offset);
101 }
102 *new_offset = old_offset + increment;
103 if (*new_offset > max_size) {
104 return errors::InvalidArgument("Data too short when trying to read string");
105 }
106 // See above for the check that the input offset is positive. If it's negative
107 // here then it means that there's been an overflow in the arithmetic.
108 if (*new_offset < 0) {
109 return errors::InvalidArgument("Offset too large, overflowed: ",
110 *new_offset);
111 }
112 return OkStatus();
113 }
114
ExpectText(const std::string & data,const std::string & expected_text,int * offset)115 Status ExpectText(const std::string& data, const std::string& expected_text,
116 int* offset) {
117 int new_offset;
118 TF_RETURN_IF_ERROR(
119 IncrementOffset(*offset, expected_text.size(), data.size(), &new_offset));
120 const std::string found_text(data.begin() + *offset,
121 data.begin() + new_offset);
122 if (found_text != expected_text) {
123 return errors::InvalidArgument("Header mismatch: Expected ", expected_text,
124 " but found ", found_text);
125 }
126 *offset = new_offset;
127 return OkStatus();
128 }
129
ReadString(const std::string & data,int expected_length,std::string * value,int * offset)130 Status ReadString(const std::string& data, int expected_length,
131 std::string* value, int* offset) {
132 int new_offset;
133 TF_RETURN_IF_ERROR(
134 IncrementOffset(*offset, expected_length, data.size(), &new_offset));
135 *value = std::string(data.begin() + *offset, data.begin() + new_offset);
136 *offset = new_offset;
137 return OkStatus();
138 }
139
140 template <typename T>
EncodeAudioAsS16LEWav(const float * audio,size_t sample_rate,size_t num_channels,size_t num_frames,T * wav_string)141 Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate,
142 size_t num_channels, size_t num_frames,
143 T* wav_string) {
144 constexpr size_t kFormatChunkSize = 16;
145 constexpr size_t kCompressionCodePcm = 1;
146 constexpr size_t kBitsPerSample = 16;
147 constexpr size_t kBytesPerSample = kBitsPerSample / 8;
148 constexpr size_t kHeaderSize = sizeof(WavHeader);
149
150 // If num_frames is zero, audio can be nullptr.
151 if (audio == nullptr && num_frames > 0) {
152 return errors::InvalidArgument("audio is null");
153 }
154 if (wav_string == nullptr) {
155 return errors::InvalidArgument("wav_string is null");
156 }
157 if (sample_rate == 0 || sample_rate > kuint32max) {
158 return errors::InvalidArgument("sample_rate must be in (0, 2^32), got: ",
159 sample_rate);
160 }
161 if (num_channels == 0 || num_channels > kuint16max) {
162 return errors::InvalidArgument("num_channels must be in (0, 2^16), got: ",
163 num_channels);
164 }
165
166 const size_t bytes_per_second = sample_rate * kBytesPerSample * num_channels;
167 const size_t num_samples = num_frames * num_channels;
168 const size_t data_size = num_samples * kBytesPerSample;
169 const size_t file_size = kHeaderSize + num_samples * kBytesPerSample;
170 const size_t bytes_per_frame = kBytesPerSample * num_channels;
171
172 // WAV represents the length of the file as a uint32 so file_size cannot
173 // exceed kuint32max.
174 if (file_size > kuint32max) {
175 return errors::InvalidArgument(
176 "Provided channels and frames cannot be encoded as a WAV.");
177 }
178
179 wav_string->resize(file_size);
180 char* data = &(*wav_string)[0];
181 WavHeader* header = absl::bit_cast<WavHeader*>(data);
182
183 // Fill RIFF chunk.
184 auto* riff_chunk = &header->riff_chunk;
185 memcpy(riff_chunk->chunk_id, kRiffChunkId, 4);
186 core::EncodeFixed32(riff_chunk->chunk_data_size, file_size - 8);
187 memcpy(riff_chunk->riff_type, kRiffType, 4);
188
189 // Fill format chunk.
190 auto* format_chunk = &header->format_chunk;
191 memcpy(format_chunk->chunk_id, kFormatChunkId, 4);
192 core::EncodeFixed32(format_chunk->chunk_data_size, kFormatChunkSize);
193 core::EncodeFixed16(format_chunk->compression_code, kCompressionCodePcm);
194 core::EncodeFixed16(format_chunk->channel_numbers, num_channels);
195 core::EncodeFixed32(format_chunk->sample_rate, sample_rate);
196 core::EncodeFixed32(format_chunk->bytes_per_second, bytes_per_second);
197 core::EncodeFixed16(format_chunk->bytes_per_frame, bytes_per_frame);
198 core::EncodeFixed16(format_chunk->bits_per_sample, kBitsPerSample);
199
200 // Fill data chunk.
201 auto* data_chunk = &header->data_chunk;
202 memcpy(data_chunk->chunk_id, kDataChunkId, 4);
203 core::EncodeFixed32(data_chunk->chunk_data_size, data_size);
204
205 // Write the audio.
206 data += kHeaderSize;
207 for (size_t i = 0; i < num_samples; ++i) {
208 int16_t sample = FloatToInt16Sample(audio[i]);
209 core::EncodeFixed16(&data[i * kBytesPerSample],
210 static_cast<uint16>(sample));
211 }
212 return OkStatus();
213 }
214
215 template Status EncodeAudioAsS16LEWav<std::string>(const float* audio,
216 size_t sample_rate,
217 size_t num_channels,
218 size_t num_frames,
219 std::string* wav_string);
220 template Status EncodeAudioAsS16LEWav<tstring>(const float* audio,
221 size_t sample_rate,
222 size_t num_channels,
223 size_t num_frames,
224 tstring* wav_string);
225
DecodeLin16WaveAsFloatVector(const std::string & wav_string,std::vector<float> * float_values,uint32 * sample_count,uint16 * channel_count,uint32 * sample_rate)226 Status DecodeLin16WaveAsFloatVector(const std::string& wav_string,
227 std::vector<float>* float_values,
228 uint32* sample_count, uint16* channel_count,
229 uint32* sample_rate) {
230 int offset = 0;
231 TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffChunkId, &offset));
232 uint32 total_file_size;
233 TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &total_file_size, &offset));
234 TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffType, &offset));
235 std::string found_text;
236 TF_RETURN_IF_ERROR(ReadString(wav_string, 4, &found_text, &offset));
237 while (found_text != kFormatChunkId) {
238 // Padding chunk may occur between "WAVE" and "fmt ".
239 // Skip JUNK/bext/etc field to support for WAV file with either JUNK Chunk,
240 // or broadcast WAV where additional tags might appear.
241 // Reference: the implementation of tfio in audio_video_wav_kernels.cc,
242 // https://www.daubnet.com/en/file-format-riff,
243 // https://en.wikipedia.org/wiki/Broadcast_Wave_Format
244 if (found_text != "JUNK" && found_text != "bext" && found_text != "iXML" &&
245 found_text != "qlty" && found_text != "mext" && found_text != "levl" &&
246 found_text != "link" && found_text != "axml") {
247 return errors::InvalidArgument("Unexpected field ", found_text);
248 }
249 uint32 size_of_chunk;
250 TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &size_of_chunk, &offset));
251 TF_RETURN_IF_ERROR(
252 IncrementOffset(offset, size_of_chunk, wav_string.size(), &offset));
253 TF_RETURN_IF_ERROR(ReadString(wav_string, 4, &found_text, &offset));
254 }
255 uint32 format_chunk_size;
256 TF_RETURN_IF_ERROR(
257 ReadValue<uint32>(wav_string, &format_chunk_size, &offset));
258 if ((format_chunk_size != 16) && (format_chunk_size != 18)) {
259 return errors::InvalidArgument(
260 "Bad format chunk size for WAV: Expected 16 or 18, but got",
261 format_chunk_size);
262 }
263 uint16 audio_format;
264 TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &audio_format, &offset));
265 if (audio_format != 1) {
266 return errors::InvalidArgument(
267 "Bad audio format for WAV: Expected 1 (PCM), but got", audio_format);
268 }
269 TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, channel_count, &offset));
270 if (*channel_count < 1) {
271 return errors::InvalidArgument(
272 "Bad number of channels for WAV: Expected at least 1, but got ",
273 *channel_count);
274 }
275 TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, sample_rate, &offset));
276 uint32 bytes_per_second;
277 TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &bytes_per_second, &offset));
278 uint16 bytes_per_sample;
279 TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &bytes_per_sample, &offset));
280 // Confusingly, bits per sample is defined as holding the number of bits for
281 // one channel, unlike the definition of sample used elsewhere in the WAV
282 // spec. For example, bytes per sample is the memory needed for all channels
283 // for one point in time.
284 uint16 bits_per_sample;
285 TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &bits_per_sample, &offset));
286 if (bits_per_sample != 16) {
287 return errors::InvalidArgument(
288 "Can only read 16-bit WAV files, but received ", bits_per_sample);
289 }
290 const uint32 expected_bytes_per_sample =
291 ((bits_per_sample * *channel_count) + 7) / 8;
292 if (bytes_per_sample != expected_bytes_per_sample) {
293 return errors::InvalidArgument(
294 "Bad bytes per sample in WAV header: Expected ",
295 expected_bytes_per_sample, " but got ", bytes_per_sample);
296 }
297 const uint32 expected_bytes_per_second = bytes_per_sample * *sample_rate;
298 if (bytes_per_second != expected_bytes_per_second) {
299 return errors::InvalidArgument(
300 "Bad bytes per second in WAV header: Expected ",
301 expected_bytes_per_second, " but got ", bytes_per_second,
302 " (sample_rate=", *sample_rate, ", bytes_per_sample=", bytes_per_sample,
303 ")");
304 }
305 if (format_chunk_size == 18) {
306 // Skip over this unused section.
307 offset += 2;
308 }
309
310 bool was_data_found = false;
311 while (offset < wav_string.size()) {
312 std::string chunk_id;
313 TF_RETURN_IF_ERROR(ReadString(wav_string, 4, &chunk_id, &offset));
314 uint32 chunk_size;
315 TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &chunk_size, &offset));
316 if (chunk_size > std::numeric_limits<int32>::max()) {
317 return errors::InvalidArgument(
318 "WAV data chunk '", chunk_id, "' is too large: ", chunk_size,
319 " bytes, but the limit is ", std::numeric_limits<int32>::max());
320 }
321 if (chunk_id == kDataChunkId) {
322 if (was_data_found) {
323 return errors::InvalidArgument("More than one data chunk found in WAV");
324 }
325 was_data_found = true;
326 *sample_count = chunk_size / bytes_per_sample;
327 const uint32 data_count = *sample_count * *channel_count;
328 int unused_new_offset = 0;
329 // Validate that the data exists before allocating space for it
330 // (prevent easy OOM errors).
331 TF_RETURN_IF_ERROR(IncrementOffset(offset, sizeof(int16) * data_count,
332 wav_string.size(),
333 &unused_new_offset));
334 float_values->resize(data_count);
335 for (int i = 0; i < data_count; ++i) {
336 int16_t single_channel_value = 0;
337 TF_RETURN_IF_ERROR(
338 ReadValue<int16>(wav_string, &single_channel_value, &offset));
339 (*float_values)[i] = Int16SampleToFloat(single_channel_value);
340 }
341 } else {
342 offset += chunk_size;
343 }
344 }
345 if (!was_data_found) {
346 return errors::InvalidArgument("No data chunk found in WAV");
347 }
348 return OkStatus();
349 }
350
351 } // namespace wav
352 } // namespace tensorflow
353