1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// Functions to read and write audio in WAV format.
17
#include <math.h>
#include <string.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

#include "absl/base/casts.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/wav/wav_io.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
29
30 namespace tensorflow {
31 namespace wav {
32 namespace {
33
// First 12 bytes of a WAV file: the RIFF chunk header followed by the form
// type. Every field is a raw byte array so the struct has no alignment
// requirements; TF_PACKED plus the static_assert below guarantee there is no
// padding, so the struct can be overlaid directly on the output buffer.
// The numeric field is filled with core::EncodeFixed32 by the encoder.
struct TF_PACKED RiffChunk {
  char chunk_id[4];         // "RIFF".
  char chunk_data_size[4];  // Whole file size minus the 8 bytes of
                            // chunk_id + chunk_data_size.
  char riff_type[4];        // "WAVE".
};
static_assert(sizeof(RiffChunk) == 12, "TF_PACKED does not work.");
40
// The "fmt " chunk describing the sample encoding. As with RiffChunk, all
// fields are raw bytes (filled via core::EncodeFixed16/EncodeFixed32) so the
// packed layout is exactly 24 bytes, checked by the static_assert below.
struct TF_PACKED FormatChunk {
  char chunk_id[4];          // "fmt " (note the trailing space).
  char chunk_data_size[4];   // Size of the fields after this one; the encoder
                             // writes 16.
  char compression_code[2];  // 1 means uncompressed PCM.
  char channel_numbers[2];   // Number of interleaved channels.
  char sample_rate[4];       // Sample rate in Hz.
  char bytes_per_second[4];  // sample_rate * bytes_per_frame.
  char bytes_per_frame[2];   // Bytes for one sample of every channel.
  char bits_per_sample[2];   // Bits for a single channel's sample (16 here).
};
static_assert(sizeof(FormatChunk) == 24, "TF_PACKED does not work.");
52
// Header of the "data" chunk; the encoded samples follow immediately after it.
struct TF_PACKED DataChunk {
  char chunk_id[4];         // "data".
  char chunk_data_size[4];  // Number of sample bytes that follow.
};
static_assert(sizeof(DataChunk) == 8, "TF_PACKED does not work.");
58
// The complete header written by the encoder: RIFF, fmt and data chunk headers
// back to back. The static_assert checks that packing left no gaps between the
// chunks, so sizeof(WavHeader) is the byte offset at which samples begin.
struct TF_PACKED WavHeader {
  RiffChunk riff_chunk;
  FormatChunk format_chunk;
  DataChunk data_chunk;
};
static_assert(sizeof(WavHeader) ==
                  sizeof(RiffChunk) + sizeof(FormatChunk) + sizeof(DataChunk),
              "TF_PACKED does not work.");
67
// Four-character codes that identify WAV header fields. Each array also holds
// a trailing NUL: the encoder copies only the first 4 bytes with memcpy, while
// the decoder uses them as C strings when comparing against parsed text.
constexpr char kRiffChunkId[] = "RIFF";
constexpr char kRiffType[] = "WAVE";
constexpr char kFormatChunkId[] = "fmt ";  // The trailing space is required.
constexpr char kDataChunkId[] = "data";
72
FloatToInt16Sample(float data)73 inline int16 FloatToInt16Sample(float data) {
74 constexpr float kMultiplier = 1.0f * (1 << 15);
75 return std::min<float>(std::max<float>(roundf(data * kMultiplier), kint16min),
76 kint16max);
77 }
78
Int16SampleToFloat(int16 data)79 inline float Int16SampleToFloat(int16 data) {
80 constexpr float kMultiplier = 1.0f / (1 << 15);
81 return data * kMultiplier;
82 }
83
84 } // namespace
85
86 // Handles moving the data index forward, validating the arguments, and avoiding
87 // overflow or underflow.
IncrementOffset(int old_offset,size_t increment,size_t max_size,int * new_offset)88 Status IncrementOffset(int old_offset, size_t increment, size_t max_size,
89 int* new_offset) {
90 if (old_offset < 0) {
91 return errors::InvalidArgument("Negative offsets are not allowed: ",
92 old_offset);
93 }
94 if (old_offset > max_size) {
95 return errors::InvalidArgument("Initial offset is outside data range: ",
96 old_offset);
97 }
98 *new_offset = old_offset + increment;
99 if (*new_offset > max_size) {
100 return errors::InvalidArgument("Data too short when trying to read string");
101 }
102 // See above for the check that the input offset is positive. If it's negative
103 // here then it means that there's been an overflow in the arithmetic.
104 if (*new_offset < 0) {
105 return errors::InvalidArgument("Offset too large, overflowed: ",
106 *new_offset);
107 }
108 return Status::OK();
109 }
110
ExpectText(const string & data,const string & expected_text,int * offset)111 Status ExpectText(const string& data, const string& expected_text,
112 int* offset) {
113 int new_offset;
114 TF_RETURN_IF_ERROR(
115 IncrementOffset(*offset, expected_text.size(), data.size(), &new_offset));
116 const string found_text(data.begin() + *offset, data.begin() + new_offset);
117 if (found_text != expected_text) {
118 return errors::InvalidArgument("Header mismatch: Expected ", expected_text,
119 " but found ", found_text);
120 }
121 *offset = new_offset;
122 return Status::OK();
123 }
124
ReadString(const string & data,int expected_length,string * value,int * offset)125 Status ReadString(const string& data, int expected_length, string* value,
126 int* offset) {
127 int new_offset;
128 TF_RETURN_IF_ERROR(
129 IncrementOffset(*offset, expected_length, data.size(), &new_offset));
130 *value = string(data.begin() + *offset, data.begin() + new_offset);
131 *offset = new_offset;
132 return Status::OK();
133 }
134
135 template <typename T>
EncodeAudioAsS16LEWav(const float * audio,size_t sample_rate,size_t num_channels,size_t num_frames,T * wav_string)136 Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate,
137 size_t num_channels, size_t num_frames,
138 T* wav_string) {
139 constexpr size_t kFormatChunkSize = 16;
140 constexpr size_t kCompressionCodePcm = 1;
141 constexpr size_t kBitsPerSample = 16;
142 constexpr size_t kBytesPerSample = kBitsPerSample / 8;
143 constexpr size_t kHeaderSize = sizeof(WavHeader);
144
145 if (audio == nullptr) {
146 return errors::InvalidArgument("audio is null");
147 }
148 if (wav_string == nullptr) {
149 return errors::InvalidArgument("wav_string is null");
150 }
151 if (sample_rate == 0 || sample_rate > kuint32max) {
152 return errors::InvalidArgument("sample_rate must be in (0, 2^32), got: ",
153 sample_rate);
154 }
155 if (num_channels == 0 || num_channels > kuint16max) {
156 return errors::InvalidArgument("num_channels must be in (0, 2^16), got: ",
157 num_channels);
158 }
159 if (num_frames == 0) {
160 return errors::InvalidArgument("num_frames must be positive.");
161 }
162
163 const size_t bytes_per_second = sample_rate * kBytesPerSample * num_channels;
164 const size_t num_samples = num_frames * num_channels;
165 const size_t data_size = num_samples * kBytesPerSample;
166 const size_t file_size = kHeaderSize + num_samples * kBytesPerSample;
167 const size_t bytes_per_frame = kBytesPerSample * num_channels;
168
169 // WAV represents the length of the file as a uint32 so file_size cannot
170 // exceed kuint32max.
171 if (file_size > kuint32max) {
172 return errors::InvalidArgument(
173 "Provided channels and frames cannot be encoded as a WAV.");
174 }
175
176 wav_string->resize(file_size);
177 char* data = &(*wav_string)[0];
178 WavHeader* header = absl::bit_cast<WavHeader*>(data);
179
180 // Fill RIFF chunk.
181 auto* riff_chunk = &header->riff_chunk;
182 memcpy(riff_chunk->chunk_id, kRiffChunkId, 4);
183 core::EncodeFixed32(riff_chunk->chunk_data_size, file_size - 8);
184 memcpy(riff_chunk->riff_type, kRiffType, 4);
185
186 // Fill format chunk.
187 auto* format_chunk = &header->format_chunk;
188 memcpy(format_chunk->chunk_id, kFormatChunkId, 4);
189 core::EncodeFixed32(format_chunk->chunk_data_size, kFormatChunkSize);
190 core::EncodeFixed16(format_chunk->compression_code, kCompressionCodePcm);
191 core::EncodeFixed16(format_chunk->channel_numbers, num_channels);
192 core::EncodeFixed32(format_chunk->sample_rate, sample_rate);
193 core::EncodeFixed32(format_chunk->bytes_per_second, bytes_per_second);
194 core::EncodeFixed16(format_chunk->bytes_per_frame, bytes_per_frame);
195 core::EncodeFixed16(format_chunk->bits_per_sample, kBitsPerSample);
196
197 // Fill data chunk.
198 auto* data_chunk = &header->data_chunk;
199 memcpy(data_chunk->chunk_id, kDataChunkId, 4);
200 core::EncodeFixed32(data_chunk->chunk_data_size, data_size);
201
202 // Write the audio.
203 data += kHeaderSize;
204 for (size_t i = 0; i < num_samples; ++i) {
205 int16 sample = FloatToInt16Sample(audio[i]);
206 core::EncodeFixed16(&data[i * kBytesPerSample],
207 static_cast<uint16>(sample));
208 }
209 return Status::OK();
210 }
211
// Explicit instantiations of EncodeAudioAsS16LEWav for the two output string
// types it supports. The template body is defined in this .cc file, so code in
// other translation units links against these specializations (presumably
// wav_io.h only declares the template — confirm).
template Status EncodeAudioAsS16LEWav<string>(const float* audio,
                                              size_t sample_rate,
                                              size_t num_channels,
                                              size_t num_frames,
                                              string* wav_string);
template Status EncodeAudioAsS16LEWav<tstring>(const float* audio,
                                               size_t sample_rate,
                                               size_t num_channels,
                                               size_t num_frames,
                                               tstring* wav_string);
222
DecodeLin16WaveAsFloatVector(const string & wav_string,std::vector<float> * float_values,uint32 * sample_count,uint16 * channel_count,uint32 * sample_rate)223 Status DecodeLin16WaveAsFloatVector(const string& wav_string,
224 std::vector<float>* float_values,
225 uint32* sample_count, uint16* channel_count,
226 uint32* sample_rate) {
227 int offset = 0;
228 TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffChunkId, &offset));
229 uint32 total_file_size;
230 TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &total_file_size, &offset));
231 TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffType, &offset));
232 TF_RETURN_IF_ERROR(ExpectText(wav_string, kFormatChunkId, &offset));
233 uint32 format_chunk_size;
234 TF_RETURN_IF_ERROR(
235 ReadValue<uint32>(wav_string, &format_chunk_size, &offset));
236 if ((format_chunk_size != 16) && (format_chunk_size != 18)) {
237 return errors::InvalidArgument(
238 "Bad format chunk size for WAV: Expected 16 or 18, but got",
239 format_chunk_size);
240 }
241 uint16 audio_format;
242 TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &audio_format, &offset));
243 if (audio_format != 1) {
244 return errors::InvalidArgument(
245 "Bad audio format for WAV: Expected 1 (PCM), but got", audio_format);
246 }
247 TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, channel_count, &offset));
248 if (*channel_count < 1) {
249 return errors::InvalidArgument(
250 "Bad number of channels for WAV: Expected at least 1, but got ",
251 *channel_count);
252 }
253 TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, sample_rate, &offset));
254 uint32 bytes_per_second;
255 TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &bytes_per_second, &offset));
256 uint16 bytes_per_sample;
257 TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &bytes_per_sample, &offset));
258 // Confusingly, bits per sample is defined as holding the number of bits for
259 // one channel, unlike the definition of sample used elsewhere in the WAV
260 // spec. For example, bytes per sample is the memory needed for all channels
261 // for one point in time.
262 uint16 bits_per_sample;
263 TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &bits_per_sample, &offset));
264 if (bits_per_sample != 16) {
265 return errors::InvalidArgument(
266 "Can only read 16-bit WAV files, but received ", bits_per_sample);
267 }
268 const uint32 expected_bytes_per_sample =
269 ((bits_per_sample * *channel_count) + 7) / 8;
270 if (bytes_per_sample != expected_bytes_per_sample) {
271 return errors::InvalidArgument(
272 "Bad bytes per sample in WAV header: Expected ",
273 expected_bytes_per_sample, " but got ", bytes_per_sample);
274 }
275 const uint32 expected_bytes_per_second = bytes_per_sample * *sample_rate;
276 if (bytes_per_second != expected_bytes_per_second) {
277 return errors::InvalidArgument(
278 "Bad bytes per second in WAV header: Expected ",
279 expected_bytes_per_second, " but got ", bytes_per_second,
280 " (sample_rate=", *sample_rate, ", bytes_per_sample=", bytes_per_sample,
281 ")");
282 }
283 if (format_chunk_size == 18) {
284 // Skip over this unused section.
285 offset += 2;
286 }
287
288 bool was_data_found = false;
289 while (offset < wav_string.size()) {
290 string chunk_id;
291 TF_RETURN_IF_ERROR(ReadString(wav_string, 4, &chunk_id, &offset));
292 uint32 chunk_size;
293 TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &chunk_size, &offset));
294 if (chunk_size > std::numeric_limits<int32>::max()) {
295 return errors::InvalidArgument(
296 "WAV data chunk '", chunk_id, "' is too large: ", chunk_size,
297 " bytes, but the limit is ", std::numeric_limits<int32>::max());
298 }
299 if (chunk_id == kDataChunkId) {
300 if (was_data_found) {
301 return errors::InvalidArgument("More than one data chunk found in WAV");
302 }
303 was_data_found = true;
304 *sample_count = chunk_size / bytes_per_sample;
305 const uint32 data_count = *sample_count * *channel_count;
306 int unused_new_offset = 0;
307 // Validate that the data exists before allocating space for it
308 // (prevent easy OOM errors).
309 TF_RETURN_IF_ERROR(IncrementOffset(offset, sizeof(int16) * data_count,
310 wav_string.size(),
311 &unused_new_offset));
312 float_values->resize(data_count);
313 for (int i = 0; i < data_count; ++i) {
314 int16 single_channel_value = 0;
315 TF_RETURN_IF_ERROR(
316 ReadValue<int16>(wav_string, &single_channel_value, &offset));
317 (*float_values)[i] = Int16SampleToFloat(single_channel_value);
318 }
319 } else {
320 offset += chunk_size;
321 }
322 }
323 if (!was_data_found) {
324 return errors::InvalidArgument("No data chunk found in WAV");
325 }
326 return Status::OK();
327 }
328
329 } // namespace wav
330 } // namespace tensorflow
331