1 //
2 // Copyright (C) 2020 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16
17 #include "cow_decompress.h"
18
19 #include <array>
20 #include <cstring>
21 #include <memory>
22 #include <utility>
23 #include <vector>
24
25 #include <android-base/logging.h>
26 #include <brotli/decode.h>
27 #include <lz4.h>
28 #include <zlib.h>
29 #include <zstd.h>
30
31 namespace android {
32 namespace snapshot {
33
ReadFully(void * buffer,size_t buffer_size)34 ssize_t IByteStream::ReadFully(void* buffer, size_t buffer_size) {
35 size_t stream_remaining = Size();
36
37 char* buffer_start = reinterpret_cast<char*>(buffer);
38 char* buffer_pos = buffer_start;
39 size_t buffer_remaining = buffer_size;
40 while (stream_remaining) {
41 const size_t to_read = std::min(buffer_remaining, stream_remaining);
42 const ssize_t actual_read = Read(buffer_pos, to_read);
43 if (actual_read < 0) {
44 return -1;
45 }
46 if (!actual_read) {
47 LOG(ERROR) << "Stream ended prematurely";
48 return -1;
49 }
50 CHECK_LE(actual_read, to_read);
51
52 stream_remaining -= actual_read;
53 buffer_pos += actual_read;
54 buffer_remaining -= actual_read;
55 }
56 return buffer_pos - buffer_start;
57 }
58
FromString(std::string_view compressor)59 std::unique_ptr<IDecompressor> IDecompressor::FromString(std::string_view compressor) {
60 if (compressor == "lz4") {
61 return IDecompressor::Lz4();
62 } else if (compressor == "brotli") {
63 return IDecompressor::Brotli();
64 } else if (compressor == "gz") {
65 return IDecompressor::Gz();
66 } else if (compressor == "zstd") {
67 return IDecompressor::Zstd();
68 } else {
69 return nullptr;
70 }
71 }
72
73 // Read chunks of the COW and incrementally stream them to the decoder.
74 class StreamDecompressor : public IDecompressor {
75 public:
76 ssize_t Decompress(void* buffer, size_t buffer_size, size_t decompressed_size,
77 size_t ignore_bytes) override;
78
79 virtual bool Init() = 0;
80 virtual bool PartialDecompress(const uint8_t* data, size_t length) = 0;
OutputFull() const81 bool OutputFull() const { return !ignore_bytes_ && !output_buffer_remaining_; }
82
83 protected:
84 size_t stream_remaining_;
85 uint8_t* output_buffer_ = nullptr;
86 size_t output_buffer_remaining_ = 0;
87 size_t ignore_bytes_ = 0;
88 bool decompressor_ended_ = false;
89 };
90
91 static constexpr size_t kChunkSize = 4096;
92
Decompress(void * buffer,size_t buffer_size,size_t,size_t ignore_bytes)93 ssize_t StreamDecompressor::Decompress(void* buffer, size_t buffer_size, size_t,
94 size_t ignore_bytes) {
95 if (!Init()) {
96 return false;
97 }
98
99 stream_remaining_ = stream_->Size();
100 output_buffer_ = reinterpret_cast<uint8_t*>(buffer);
101 output_buffer_remaining_ = buffer_size;
102 ignore_bytes_ = ignore_bytes;
103
104 uint8_t chunk[kChunkSize];
105 while (stream_remaining_ && output_buffer_remaining_ && !decompressor_ended_) {
106 size_t max_read = std::min(stream_remaining_, sizeof(chunk));
107 ssize_t read = stream_->Read(chunk, max_read);
108 if (read < 0) {
109 return -1;
110 }
111 if (!read) {
112 LOG(ERROR) << "Stream ended prematurely";
113 return -1;
114 }
115 if (!PartialDecompress(chunk, read)) {
116 return -1;
117 }
118 stream_remaining_ -= read;
119 }
120
121 if (stream_remaining_) {
122 if (decompressor_ended_ && !OutputFull()) {
123 // If there's more input in the stream, but we haven't finished
124 // consuming ignored bytes or available output space yet, then
125 // something weird happened. Report it and fail.
126 LOG(ERROR) << "Decompressor terminated early";
127 return -1;
128 }
129 } else {
130 if (!decompressor_ended_ && !OutputFull()) {
131 // The stream ended, but the decoder doesn't think so, and there are
132 // more bytes in the output buffer.
133 LOG(ERROR) << "Decompressor expected more bytes";
134 return -1;
135 }
136 }
137 return buffer_size - output_buffer_remaining_;
138 }
139
140 class GzDecompressor final : public StreamDecompressor {
141 public:
142 ~GzDecompressor();
143
144 bool Init() override;
145 bool PartialDecompress(const uint8_t* data, size_t length) override;
146
147 private:
148 z_stream z_ = {};
149 };
150
Init()151 bool GzDecompressor::Init() {
152 if (int rv = inflateInit(&z_); rv != Z_OK) {
153 LOG(ERROR) << "inflateInit returned error code " << rv;
154 return false;
155 }
156 return true;
157 }
158
~GzDecompressor()159 GzDecompressor::~GzDecompressor() {
160 inflateEnd(&z_);
161 }
162
PartialDecompress(const uint8_t * data,size_t length)163 bool GzDecompressor::PartialDecompress(const uint8_t* data, size_t length) {
164 z_.next_in = reinterpret_cast<Bytef*>(const_cast<uint8_t*>(data));
165 z_.avail_in = length;
166
167 // If we're asked to ignore starting bytes, we sink those into the output
168 // repeatedly until there is nothing left to ignore.
169 while (ignore_bytes_ && z_.avail_in) {
170 std::array<Bytef, kChunkSize> ignore_buffer;
171 size_t max_ignore = std::min(ignore_bytes_, ignore_buffer.size());
172 z_.next_out = ignore_buffer.data();
173 z_.avail_out = max_ignore;
174
175 int rv = inflate(&z_, Z_NO_FLUSH);
176 if (rv != Z_OK && rv != Z_STREAM_END) {
177 LOG(ERROR) << "inflate returned error code " << rv;
178 return false;
179 }
180
181 size_t returned = max_ignore - z_.avail_out;
182 CHECK_LE(returned, ignore_bytes_);
183
184 ignore_bytes_ -= returned;
185
186 if (rv == Z_STREAM_END) {
187 decompressor_ended_ = true;
188 return true;
189 }
190 }
191
192 z_.next_out = reinterpret_cast<Bytef*>(output_buffer_);
193 z_.avail_out = output_buffer_remaining_;
194
195 while (z_.avail_in && z_.avail_out) {
196 // Decompress.
197 int rv = inflate(&z_, Z_NO_FLUSH);
198 if (rv != Z_OK && rv != Z_STREAM_END) {
199 LOG(ERROR) << "inflate returned error code " << rv;
200 return false;
201 }
202
203 size_t returned = output_buffer_remaining_ - z_.avail_out;
204 CHECK_LE(returned, output_buffer_remaining_);
205
206 output_buffer_ += returned;
207 output_buffer_remaining_ -= returned;
208
209 if (rv == Z_STREAM_END) {
210 decompressor_ended_ = true;
211 return true;
212 }
213 }
214 return true;
215 }
216
217 class BrotliDecompressor final : public StreamDecompressor {
218 public:
219 ~BrotliDecompressor();
220
221 bool Init() override;
222 bool PartialDecompress(const uint8_t* data, size_t length) override;
223
224 private:
225 BrotliDecoderState* decoder_ = nullptr;
226 };
227
Init()228 bool BrotliDecompressor::Init() {
229 decoder_ = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr);
230 return true;
231 }
232
~BrotliDecompressor()233 BrotliDecompressor::~BrotliDecompressor() {
234 if (decoder_) {
235 BrotliDecoderDestroyInstance(decoder_);
236 }
237 }
238
PartialDecompress(const uint8_t * data,size_t length)239 bool BrotliDecompressor::PartialDecompress(const uint8_t* data, size_t length) {
240 size_t available_in = length;
241 const uint8_t* next_in = data;
242
243 while (available_in && ignore_bytes_ && !BrotliDecoderIsFinished(decoder_)) {
244 std::array<uint8_t, kChunkSize> ignore_buffer;
245 size_t max_ignore = std::min(ignore_bytes_, ignore_buffer.size());
246 size_t ignore_size = max_ignore;
247
248 uint8_t* ignore_buffer_ptr = ignore_buffer.data();
249 auto r = BrotliDecoderDecompressStream(decoder_, &available_in, &next_in, &ignore_size,
250 &ignore_buffer_ptr, nullptr);
251 if (r == BROTLI_DECODER_RESULT_ERROR) {
252 LOG(ERROR) << "brotli decode failed";
253 return false;
254 } else if (r == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT && available_in) {
255 LOG(ERROR) << "brotli unexpected needs more input";
256 return false;
257 }
258 ignore_bytes_ -= max_ignore - ignore_size;
259 }
260
261 while (available_in && !BrotliDecoderIsFinished(decoder_)) {
262 auto r = BrotliDecoderDecompressStream(decoder_, &available_in, &next_in,
263 &output_buffer_remaining_, &output_buffer_, nullptr);
264 if (r == BROTLI_DECODER_RESULT_ERROR) {
265 LOG(ERROR) << "brotli decode failed";
266 return false;
267 } else if (r == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT && available_in) {
268 LOG(ERROR) << "brotli unexpected needs more input";
269 return false;
270 }
271 }
272
273 decompressor_ended_ = BrotliDecoderIsFinished(decoder_);
274 return true;
275 }
276
277 class Lz4Decompressor final : public IDecompressor {
278 public:
279 ~Lz4Decompressor() override = default;
280
Decompress(void * buffer,size_t buffer_size,size_t decompressed_size,size_t ignore_bytes)281 ssize_t Decompress(void* buffer, size_t buffer_size, size_t decompressed_size,
282 size_t ignore_bytes) override {
283 std::string input_buffer(stream_->Size(), '\0');
284 ssize_t streamed_in = stream_->ReadFully(input_buffer.data(), input_buffer.size());
285 if (streamed_in < 0) {
286 return -1;
287 }
288 CHECK_EQ(streamed_in, stream_->Size());
289
290 char* decode_buffer = reinterpret_cast<char*>(buffer);
291 size_t decode_buffer_size = buffer_size;
292
293 // It's unclear if LZ4 can exactly satisfy a partial decode request, so
294 // if we get one, create a temporary buffer.
295 std::string temp;
296 if (buffer_size < decompressed_size) {
297 temp.resize(decompressed_size, '\0');
298 decode_buffer = temp.data();
299 decode_buffer_size = temp.size();
300 }
301
302 const int bytes_decompressed = LZ4_decompress_safe(input_buffer.data(), decode_buffer,
303 input_buffer.size(), decode_buffer_size);
304 if (bytes_decompressed < 0) {
305 LOG(ERROR) << "Failed to decompress LZ4 block, code: " << bytes_decompressed;
306 return -1;
307 }
308 if (bytes_decompressed != decompressed_size) {
309 LOG(ERROR) << "Failed to decompress LZ4 block, expected output size: "
310 << bytes_decompressed << ", actual: " << bytes_decompressed;
311 return -1;
312 }
313 CHECK_LE(bytes_decompressed, decode_buffer_size);
314
315 if (ignore_bytes > bytes_decompressed) {
316 LOG(ERROR) << "Ignoring more bytes than exist in stream (ignoring " << ignore_bytes
317 << ", got " << bytes_decompressed << ")";
318 return -1;
319 }
320
321 if (temp.empty()) {
322 // LZ4's API has no way to sink out the first N bytes of decoding,
323 // so we read them all in and memmove() to drop the partial read.
324 if (ignore_bytes) {
325 memmove(decode_buffer, decode_buffer + ignore_bytes,
326 bytes_decompressed - ignore_bytes);
327 }
328 return bytes_decompressed - ignore_bytes;
329 }
330
331 size_t max_copy = std::min(bytes_decompressed - ignore_bytes, buffer_size);
332 memcpy(buffer, temp.data() + ignore_bytes, max_copy);
333 return max_copy;
334 }
335 };
336
337 class ZstdDecompressor final : public IDecompressor {
338 public:
Decompress(void * buffer,size_t buffer_size,size_t decompressed_size,size_t ignore_bytes=0)339 ssize_t Decompress(void* buffer, size_t buffer_size, size_t decompressed_size,
340 size_t ignore_bytes = 0) override {
341 if (buffer_size < decompressed_size - ignore_bytes) {
342 LOG(INFO) << "buffer size " << buffer_size
343 << " is not large enough to hold decompressed data. Decompressed size "
344 << decompressed_size << ", ignore_bytes " << ignore_bytes;
345 return -1;
346 }
347 if (ignore_bytes == 0) {
348 if (!Decompress(buffer, decompressed_size)) {
349 return -1;
350 }
351 return decompressed_size;
352 }
353 std::vector<unsigned char> ignore_buf(decompressed_size);
354 if (!Decompress(ignore_buf.data(), decompressed_size)) {
355 return -1;
356 }
357 memcpy(buffer, ignore_buf.data() + ignore_bytes, buffer_size);
358 return decompressed_size;
359 }
Decompress(void * output_buffer,const size_t output_size)360 bool Decompress(void* output_buffer, const size_t output_size) {
361 std::string input_buffer;
362 input_buffer.resize(stream_->Size());
363 size_t bytes_read = stream_->Read(input_buffer.data(), input_buffer.size());
364 if (bytes_read != input_buffer.size()) {
365 LOG(ERROR) << "Failed to read all input at once. Expected: " << input_buffer.size()
366 << " actual: " << bytes_read;
367 return false;
368 }
369 const auto bytes_decompressed = ZSTD_decompress(output_buffer, output_size,
370 input_buffer.data(), input_buffer.size());
371 if (bytes_decompressed != output_size) {
372 LOG(ERROR) << "Failed to decompress ZSTD block, expected output size: " << output_size
373 << ", actual: " << bytes_decompressed;
374 return false;
375 }
376 return true;
377 }
378 };
379
Brotli()380 std::unique_ptr<IDecompressor> IDecompressor::Brotli() {
381 return std::make_unique<BrotliDecompressor>();
382 }
383
Gz()384 std::unique_ptr<IDecompressor> IDecompressor::Gz() {
385 return std::make_unique<GzDecompressor>();
386 }
387
Lz4()388 std::unique_ptr<IDecompressor> IDecompressor::Lz4() {
389 return std::make_unique<Lz4Decompressor>();
390 }
391
Zstd()392 std::unique_ptr<IDecompressor> IDecompressor::Zstd() {
393 return std::make_unique<ZstdDecompressor>();
394 }
395
396 } // namespace snapshot
397 } // namespace android
398