1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// Functions to write and read audio in WAV format.
17
#include <math.h>
#include <string.h>

#include <algorithm>
#include <cmath>
#include <limits>

#include "absl/base/casts.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/wav/wav_io.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
29
30 namespace tensorflow {
31 namespace wav {
32 namespace {
33
34 struct TF_PACKED RiffChunk {
35 char chunk_id[4];
36 char chunk_data_size[4];
37 char riff_type[4];
38 };
39 static_assert(sizeof(RiffChunk) == 12, "TF_PACKED does not work.");
40
41 struct TF_PACKED FormatChunk {
42 char chunk_id[4];
43 char chunk_data_size[4];
44 char compression_code[2];
45 char channel_numbers[2];
46 char sample_rate[4];
47 char bytes_per_second[4];
48 char bytes_per_frame[2];
49 char bits_per_sample[2];
50 };
51 static_assert(sizeof(FormatChunk) == 24, "TF_PACKED does not work.");
52
53 struct TF_PACKED DataChunk {
54 char chunk_id[4];
55 char chunk_data_size[4];
56 };
57 static_assert(sizeof(DataChunk) == 8, "TF_PACKED does not work.");
58
59 struct TF_PACKED WavHeader {
60 RiffChunk riff_chunk;
61 FormatChunk format_chunk;
62 DataChunk data_chunk;
63 };
64 static_assert(sizeof(WavHeader) ==
65 sizeof(RiffChunk) + sizeof(FormatChunk) + sizeof(DataChunk),
66 "TF_PACKED does not work.");
67
// Four-character tags used in the WAV header. Each array holds the four tag
// bytes plus the terminating NUL; only the first four bytes are copied into
// the header by the encoder.
constexpr char kRiffChunkId[5] = "RIFF";
constexpr char kRiffType[5] = "WAVE";
constexpr char kFormatChunkId[5] = "fmt ";
constexpr char kDataChunkId[5] = "data";
72
// Converts a float sample to a signed 16-bit PCM sample.
//
// The input is scaled by 2^15, rounded to the nearest integer and clamped to
// the int16 range, so 1.0f (and anything above, including +infinity) maps to
// 32767 and -1.0f (and anything below) maps to -32768.
//
// NaN input yields 0 (silence). Without the explicit check NaN would pass
// through both clamps unchanged and the float-to-integer conversion of NaN is
// undefined behaviour.
inline int16_t FloatToInt16Sample(float data) {
  constexpr float kMultiplier = 1.0f * (1 << 15);
  constexpr float kMinSample = std::numeric_limits<int16_t>::min();
  constexpr float kMaxSample = std::numeric_limits<int16_t>::max();

  const float scaled = std::round(data * kMultiplier);
  if (std::isnan(scaled)) {
    return 0;
  }
  return static_cast<int16_t>(
      std::min(std::max(scaled, kMinSample), kMaxSample));
}
78
// Converts a signed 16-bit PCM sample to a float by dividing by 2^15, so the
// result lies in [-1.0, 1.0).
inline float Int16SampleToFloat(int16_t data) {
  constexpr float kInverseScale = 1.0f / (1 << 15);
  const float sample = static_cast<float>(data);
  return sample * kInverseScale;
}
83
84 } // namespace
85
86 // Handles moving the data index forward, validating the arguments, and avoiding
87 // overflow or underflow.
IncrementOffset(int old_offset,size_t increment,size_t max_size,int * new_offset)88 Status IncrementOffset(int old_offset, size_t increment, size_t max_size,
89 int* new_offset) {
90 if (old_offset < 0) {
91 return errors::InvalidArgument("Negative offsets are not allowed: ",
92 old_offset);
93 }
94 if (old_offset > max_size) {
95 return errors::InvalidArgument("Initial offset is outside data range: ",
96 old_offset);
97 }
98 *new_offset = old_offset + increment;
99 if (*new_offset > max_size) {
100 return errors::InvalidArgument("Data too short when trying to read string");
101 }
102 // See above for the check that the input offset is positive. If it's negative
103 // here then it means that there's been an overflow in the arithmetic.
104 if (*new_offset < 0) {
105 return errors::InvalidArgument("Offset too large, overflowed: ",
106 *new_offset);
107 }
108 return Status::OK();
109 }
110
ExpectText(const std::string & data,const std::string & expected_text,int * offset)111 Status ExpectText(const std::string& data, const std::string& expected_text,
112 int* offset) {
113 int new_offset;
114 TF_RETURN_IF_ERROR(
115 IncrementOffset(*offset, expected_text.size(), data.size(), &new_offset));
116 const std::string found_text(data.begin() + *offset,
117 data.begin() + new_offset);
118 if (found_text != expected_text) {
119 return errors::InvalidArgument("Header mismatch: Expected ", expected_text,
120 " but found ", found_text);
121 }
122 *offset = new_offset;
123 return Status::OK();
124 }
125
ReadString(const std::string & data,int expected_length,std::string * value,int * offset)126 Status ReadString(const std::string& data, int expected_length,
127 std::string* value, int* offset) {
128 int new_offset;
129 TF_RETURN_IF_ERROR(
130 IncrementOffset(*offset, expected_length, data.size(), &new_offset));
131 *value = std::string(data.begin() + *offset, data.begin() + new_offset);
132 *offset = new_offset;
133 return Status::OK();
134 }
135
136 template <typename T>
EncodeAudioAsS16LEWav(const float * audio,size_t sample_rate,size_t num_channels,size_t num_frames,T * wav_string)137 Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate,
138 size_t num_channels, size_t num_frames,
139 T* wav_string) {
140 constexpr size_t kFormatChunkSize = 16;
141 constexpr size_t kCompressionCodePcm = 1;
142 constexpr size_t kBitsPerSample = 16;
143 constexpr size_t kBytesPerSample = kBitsPerSample / 8;
144 constexpr size_t kHeaderSize = sizeof(WavHeader);
145
146 // If num_frames is zero, audio can be nullptr.
147 if (audio == nullptr && num_frames > 0) {
148 return errors::InvalidArgument("audio is null");
149 }
150 if (wav_string == nullptr) {
151 return errors::InvalidArgument("wav_string is null");
152 }
153 if (sample_rate == 0 || sample_rate > kuint32max) {
154 return errors::InvalidArgument("sample_rate must be in (0, 2^32), got: ",
155 sample_rate);
156 }
157 if (num_channels == 0 || num_channels > kuint16max) {
158 return errors::InvalidArgument("num_channels must be in (0, 2^16), got: ",
159 num_channels);
160 }
161
162 const size_t bytes_per_second = sample_rate * kBytesPerSample * num_channels;
163 const size_t num_samples = num_frames * num_channels;
164 const size_t data_size = num_samples * kBytesPerSample;
165 const size_t file_size = kHeaderSize + num_samples * kBytesPerSample;
166 const size_t bytes_per_frame = kBytesPerSample * num_channels;
167
168 // WAV represents the length of the file as a uint32 so file_size cannot
169 // exceed kuint32max.
170 if (file_size > kuint32max) {
171 return errors::InvalidArgument(
172 "Provided channels and frames cannot be encoded as a WAV.");
173 }
174
175 wav_string->resize(file_size);
176 char* data = &(*wav_string)[0];
177 WavHeader* header = absl::bit_cast<WavHeader*>(data);
178
179 // Fill RIFF chunk.
180 auto* riff_chunk = &header->riff_chunk;
181 memcpy(riff_chunk->chunk_id, kRiffChunkId, 4);
182 core::EncodeFixed32(riff_chunk->chunk_data_size, file_size - 8);
183 memcpy(riff_chunk->riff_type, kRiffType, 4);
184
185 // Fill format chunk.
186 auto* format_chunk = &header->format_chunk;
187 memcpy(format_chunk->chunk_id, kFormatChunkId, 4);
188 core::EncodeFixed32(format_chunk->chunk_data_size, kFormatChunkSize);
189 core::EncodeFixed16(format_chunk->compression_code, kCompressionCodePcm);
190 core::EncodeFixed16(format_chunk->channel_numbers, num_channels);
191 core::EncodeFixed32(format_chunk->sample_rate, sample_rate);
192 core::EncodeFixed32(format_chunk->bytes_per_second, bytes_per_second);
193 core::EncodeFixed16(format_chunk->bytes_per_frame, bytes_per_frame);
194 core::EncodeFixed16(format_chunk->bits_per_sample, kBitsPerSample);
195
196 // Fill data chunk.
197 auto* data_chunk = &header->data_chunk;
198 memcpy(data_chunk->chunk_id, kDataChunkId, 4);
199 core::EncodeFixed32(data_chunk->chunk_data_size, data_size);
200
201 // Write the audio.
202 data += kHeaderSize;
203 for (size_t i = 0; i < num_samples; ++i) {
204 int16_t sample = FloatToInt16Sample(audio[i]);
205 core::EncodeFixed16(&data[i * kBytesPerSample],
206 static_cast<uint16>(sample));
207 }
208 return Status::OK();
209 }
210
211 template Status EncodeAudioAsS16LEWav<std::string>(const float* audio,
212 size_t sample_rate,
213 size_t num_channels,
214 size_t num_frames,
215 std::string* wav_string);
216 template Status EncodeAudioAsS16LEWav<tstring>(const float* audio,
217 size_t sample_rate,
218 size_t num_channels,
219 size_t num_frames,
220 tstring* wav_string);
221
DecodeLin16WaveAsFloatVector(const std::string & wav_string,std::vector<float> * float_values,uint32 * sample_count,uint16 * channel_count,uint32 * sample_rate)222 Status DecodeLin16WaveAsFloatVector(const std::string& wav_string,
223 std::vector<float>* float_values,
224 uint32* sample_count, uint16* channel_count,
225 uint32* sample_rate) {
226 int offset = 0;
227 TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffChunkId, &offset));
228 uint32 total_file_size;
229 TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &total_file_size, &offset));
230 TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffType, &offset));
231 TF_RETURN_IF_ERROR(ExpectText(wav_string, kFormatChunkId, &offset));
232 uint32 format_chunk_size;
233 TF_RETURN_IF_ERROR(
234 ReadValue<uint32>(wav_string, &format_chunk_size, &offset));
235 if ((format_chunk_size != 16) && (format_chunk_size != 18)) {
236 return errors::InvalidArgument(
237 "Bad format chunk size for WAV: Expected 16 or 18, but got",
238 format_chunk_size);
239 }
240 uint16 audio_format;
241 TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &audio_format, &offset));
242 if (audio_format != 1) {
243 return errors::InvalidArgument(
244 "Bad audio format for WAV: Expected 1 (PCM), but got", audio_format);
245 }
246 TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, channel_count, &offset));
247 if (*channel_count < 1) {
248 return errors::InvalidArgument(
249 "Bad number of channels for WAV: Expected at least 1, but got ",
250 *channel_count);
251 }
252 TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, sample_rate, &offset));
253 uint32 bytes_per_second;
254 TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &bytes_per_second, &offset));
255 uint16 bytes_per_sample;
256 TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &bytes_per_sample, &offset));
257 // Confusingly, bits per sample is defined as holding the number of bits for
258 // one channel, unlike the definition of sample used elsewhere in the WAV
259 // spec. For example, bytes per sample is the memory needed for all channels
260 // for one point in time.
261 uint16 bits_per_sample;
262 TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &bits_per_sample, &offset));
263 if (bits_per_sample != 16) {
264 return errors::InvalidArgument(
265 "Can only read 16-bit WAV files, but received ", bits_per_sample);
266 }
267 const uint32 expected_bytes_per_sample =
268 ((bits_per_sample * *channel_count) + 7) / 8;
269 if (bytes_per_sample != expected_bytes_per_sample) {
270 return errors::InvalidArgument(
271 "Bad bytes per sample in WAV header: Expected ",
272 expected_bytes_per_sample, " but got ", bytes_per_sample);
273 }
274 const uint32 expected_bytes_per_second = bytes_per_sample * *sample_rate;
275 if (bytes_per_second != expected_bytes_per_second) {
276 return errors::InvalidArgument(
277 "Bad bytes per second in WAV header: Expected ",
278 expected_bytes_per_second, " but got ", bytes_per_second,
279 " (sample_rate=", *sample_rate, ", bytes_per_sample=", bytes_per_sample,
280 ")");
281 }
282 if (format_chunk_size == 18) {
283 // Skip over this unused section.
284 offset += 2;
285 }
286
287 bool was_data_found = false;
288 while (offset < wav_string.size()) {
289 std::string chunk_id;
290 TF_RETURN_IF_ERROR(ReadString(wav_string, 4, &chunk_id, &offset));
291 uint32 chunk_size;
292 TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &chunk_size, &offset));
293 if (chunk_size > std::numeric_limits<int32>::max()) {
294 return errors::InvalidArgument(
295 "WAV data chunk '", chunk_id, "' is too large: ", chunk_size,
296 " bytes, but the limit is ", std::numeric_limits<int32>::max());
297 }
298 if (chunk_id == kDataChunkId) {
299 if (was_data_found) {
300 return errors::InvalidArgument("More than one data chunk found in WAV");
301 }
302 was_data_found = true;
303 *sample_count = chunk_size / bytes_per_sample;
304 const uint32 data_count = *sample_count * *channel_count;
305 int unused_new_offset = 0;
306 // Validate that the data exists before allocating space for it
307 // (prevent easy OOM errors).
308 TF_RETURN_IF_ERROR(IncrementOffset(offset, sizeof(int16) * data_count,
309 wav_string.size(),
310 &unused_new_offset));
311 float_values->resize(data_count);
312 for (int i = 0; i < data_count; ++i) {
313 int16_t single_channel_value = 0;
314 TF_RETURN_IF_ERROR(
315 ReadValue<int16>(wav_string, &single_channel_value, &offset));
316 (*float_values)[i] = Int16SampleToFloat(single_channel_value);
317 }
318 } else {
319 offset += chunk_size;
320 }
321 }
322 if (!was_data_found) {
323 return errors::InvalidArgument("No data chunk found in WAV");
324 }
325 return Status::OK();
326 }
327
328 } // namespace wav
329 } // namespace tensorflow
330