• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright (C) 2020 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
#include "cow_decompress.h"

#include <algorithm>
#include <array>
#include <cstring>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

#include <android-base/logging.h>
#include <brotli/decode.h>
#include <lz4.h>
#include <zlib.h>
#include <zstd.h>
30 
31 namespace android {
32 namespace snapshot {
33 
ReadFully(void * buffer,size_t buffer_size)34 ssize_t IByteStream::ReadFully(void* buffer, size_t buffer_size) {
35     size_t stream_remaining = Size();
36 
37     char* buffer_start = reinterpret_cast<char*>(buffer);
38     char* buffer_pos = buffer_start;
39     size_t buffer_remaining = buffer_size;
40     while (stream_remaining) {
41         const size_t to_read = std::min(buffer_remaining, stream_remaining);
42         const ssize_t actual_read = Read(buffer_pos, to_read);
43         if (actual_read < 0) {
44             return -1;
45         }
46         if (!actual_read) {
47             LOG(ERROR) << "Stream ended prematurely";
48             return -1;
49         }
50         CHECK_LE(actual_read, to_read);
51 
52         stream_remaining -= actual_read;
53         buffer_pos += actual_read;
54         buffer_remaining -= actual_read;
55     }
56     return buffer_pos - buffer_start;
57 }
58 
FromString(std::string_view compressor)59 std::unique_ptr<IDecompressor> IDecompressor::FromString(std::string_view compressor) {
60     if (compressor == "lz4") {
61         return IDecompressor::Lz4();
62     } else if (compressor == "brotli") {
63         return IDecompressor::Brotli();
64     } else if (compressor == "gz") {
65         return IDecompressor::Gz();
66     } else if (compressor == "zstd") {
67         return IDecompressor::Zstd();
68     } else {
69         return nullptr;
70     }
71 }
72 
73 // Read chunks of the COW and incrementally stream them to the decoder.
74 class StreamDecompressor : public IDecompressor {
75   public:
76     ssize_t Decompress(void* buffer, size_t buffer_size, size_t decompressed_size,
77                        size_t ignore_bytes) override;
78 
79     virtual bool Init() = 0;
80     virtual bool PartialDecompress(const uint8_t* data, size_t length) = 0;
OutputFull() const81     bool OutputFull() const { return !ignore_bytes_ && !output_buffer_remaining_; }
82 
83   protected:
84     size_t stream_remaining_;
85     uint8_t* output_buffer_ = nullptr;
86     size_t output_buffer_remaining_ = 0;
87     size_t ignore_bytes_ = 0;
88     bool decompressor_ended_ = false;
89 };
90 
// Number of compressed bytes read from the stream per step; also the size of
// the scratch buffers used to sink ignored output bytes.
constexpr size_t kChunkSize = 4 * 1024;
92 
Decompress(void * buffer,size_t buffer_size,size_t,size_t ignore_bytes)93 ssize_t StreamDecompressor::Decompress(void* buffer, size_t buffer_size, size_t,
94                                        size_t ignore_bytes) {
95     if (!Init()) {
96         return false;
97     }
98 
99     stream_remaining_ = stream_->Size();
100     output_buffer_ = reinterpret_cast<uint8_t*>(buffer);
101     output_buffer_remaining_ = buffer_size;
102     ignore_bytes_ = ignore_bytes;
103 
104     uint8_t chunk[kChunkSize];
105     while (stream_remaining_ && output_buffer_remaining_ && !decompressor_ended_) {
106         size_t max_read = std::min(stream_remaining_, sizeof(chunk));
107         ssize_t read = stream_->Read(chunk, max_read);
108         if (read < 0) {
109             return -1;
110         }
111         if (!read) {
112             LOG(ERROR) << "Stream ended prematurely";
113             return -1;
114         }
115         if (!PartialDecompress(chunk, read)) {
116             return -1;
117         }
118         stream_remaining_ -= read;
119     }
120 
121     if (stream_remaining_) {
122         if (decompressor_ended_ && !OutputFull()) {
123             // If there's more input in the stream, but we haven't finished
124             // consuming ignored bytes or available output space yet, then
125             // something weird happened. Report it and fail.
126             LOG(ERROR) << "Decompressor terminated early";
127             return -1;
128         }
129     } else {
130         if (!decompressor_ended_ && !OutputFull()) {
131             // The stream ended, but the decoder doesn't think so, and there are
132             // more bytes in the output buffer.
133             LOG(ERROR) << "Decompressor expected more bytes";
134             return -1;
135         }
136     }
137     return buffer_size - output_buffer_remaining_;
138 }
139 
140 class GzDecompressor final : public StreamDecompressor {
141   public:
142     ~GzDecompressor();
143 
144     bool Init() override;
145     bool PartialDecompress(const uint8_t* data, size_t length) override;
146 
147   private:
148     z_stream z_ = {};
149 };
150 
Init()151 bool GzDecompressor::Init() {
152     if (int rv = inflateInit(&z_); rv != Z_OK) {
153         LOG(ERROR) << "inflateInit returned error code " << rv;
154         return false;
155     }
156     return true;
157 }
158 
~GzDecompressor()159 GzDecompressor::~GzDecompressor() {
160     inflateEnd(&z_);
161 }
162 
PartialDecompress(const uint8_t * data,size_t length)163 bool GzDecompressor::PartialDecompress(const uint8_t* data, size_t length) {
164     z_.next_in = reinterpret_cast<Bytef*>(const_cast<uint8_t*>(data));
165     z_.avail_in = length;
166 
167     // If we're asked to ignore starting bytes, we sink those into the output
168     // repeatedly until there is nothing left to ignore.
169     while (ignore_bytes_ && z_.avail_in) {
170         std::array<Bytef, kChunkSize> ignore_buffer;
171         size_t max_ignore = std::min(ignore_bytes_, ignore_buffer.size());
172         z_.next_out = ignore_buffer.data();
173         z_.avail_out = max_ignore;
174 
175         int rv = inflate(&z_, Z_NO_FLUSH);
176         if (rv != Z_OK && rv != Z_STREAM_END) {
177             LOG(ERROR) << "inflate returned error code " << rv;
178             return false;
179         }
180 
181         size_t returned = max_ignore - z_.avail_out;
182         CHECK_LE(returned, ignore_bytes_);
183 
184         ignore_bytes_ -= returned;
185 
186         if (rv == Z_STREAM_END) {
187             decompressor_ended_ = true;
188             return true;
189         }
190     }
191 
192     z_.next_out = reinterpret_cast<Bytef*>(output_buffer_);
193     z_.avail_out = output_buffer_remaining_;
194 
195     while (z_.avail_in && z_.avail_out) {
196         // Decompress.
197         int rv = inflate(&z_, Z_NO_FLUSH);
198         if (rv != Z_OK && rv != Z_STREAM_END) {
199             LOG(ERROR) << "inflate returned error code " << rv;
200             return false;
201         }
202 
203         size_t returned = output_buffer_remaining_ - z_.avail_out;
204         CHECK_LE(returned, output_buffer_remaining_);
205 
206         output_buffer_ += returned;
207         output_buffer_remaining_ -= returned;
208 
209         if (rv == Z_STREAM_END) {
210             decompressor_ended_ = true;
211             return true;
212         }
213     }
214     return true;
215 }
216 
217 class BrotliDecompressor final : public StreamDecompressor {
218   public:
219     ~BrotliDecompressor();
220 
221     bool Init() override;
222     bool PartialDecompress(const uint8_t* data, size_t length) override;
223 
224   private:
225     BrotliDecoderState* decoder_ = nullptr;
226 };
227 
Init()228 bool BrotliDecompressor::Init() {
229     decoder_ = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr);
230     return true;
231 }
232 
~BrotliDecompressor()233 BrotliDecompressor::~BrotliDecompressor() {
234     if (decoder_) {
235         BrotliDecoderDestroyInstance(decoder_);
236     }
237 }
238 
PartialDecompress(const uint8_t * data,size_t length)239 bool BrotliDecompressor::PartialDecompress(const uint8_t* data, size_t length) {
240     size_t available_in = length;
241     const uint8_t* next_in = data;
242 
243     while (available_in && ignore_bytes_ && !BrotliDecoderIsFinished(decoder_)) {
244         std::array<uint8_t, kChunkSize> ignore_buffer;
245         size_t max_ignore = std::min(ignore_bytes_, ignore_buffer.size());
246         size_t ignore_size = max_ignore;
247 
248         uint8_t* ignore_buffer_ptr = ignore_buffer.data();
249         auto r = BrotliDecoderDecompressStream(decoder_, &available_in, &next_in, &ignore_size,
250                                                &ignore_buffer_ptr, nullptr);
251         if (r == BROTLI_DECODER_RESULT_ERROR) {
252             LOG(ERROR) << "brotli decode failed";
253             return false;
254         } else if (r == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT && available_in) {
255             LOG(ERROR) << "brotli unexpected needs more input";
256             return false;
257         }
258         ignore_bytes_ -= max_ignore - ignore_size;
259     }
260 
261     while (available_in && !BrotliDecoderIsFinished(decoder_)) {
262         auto r = BrotliDecoderDecompressStream(decoder_, &available_in, &next_in,
263                                                &output_buffer_remaining_, &output_buffer_, nullptr);
264         if (r == BROTLI_DECODER_RESULT_ERROR) {
265             LOG(ERROR) << "brotli decode failed";
266             return false;
267         } else if (r == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT && available_in) {
268             LOG(ERROR) << "brotli unexpected needs more input";
269             return false;
270         }
271     }
272 
273     decompressor_ended_ = BrotliDecoderIsFinished(decoder_);
274     return true;
275 }
276 
277 class Lz4Decompressor final : public IDecompressor {
278   public:
279     ~Lz4Decompressor() override = default;
280 
Decompress(void * buffer,size_t buffer_size,size_t decompressed_size,size_t ignore_bytes)281     ssize_t Decompress(void* buffer, size_t buffer_size, size_t decompressed_size,
282                        size_t ignore_bytes) override {
283         std::string input_buffer(stream_->Size(), '\0');
284         ssize_t streamed_in = stream_->ReadFully(input_buffer.data(), input_buffer.size());
285         if (streamed_in < 0) {
286             return -1;
287         }
288         CHECK_EQ(streamed_in, stream_->Size());
289 
290         char* decode_buffer = reinterpret_cast<char*>(buffer);
291         size_t decode_buffer_size = buffer_size;
292 
293         // It's unclear if LZ4 can exactly satisfy a partial decode request, so
294         // if we get one, create a temporary buffer.
295         std::string temp;
296         if (buffer_size < decompressed_size) {
297             temp.resize(decompressed_size, '\0');
298             decode_buffer = temp.data();
299             decode_buffer_size = temp.size();
300         }
301 
302         const int bytes_decompressed = LZ4_decompress_safe(input_buffer.data(), decode_buffer,
303                                                            input_buffer.size(), decode_buffer_size);
304         if (bytes_decompressed < 0) {
305             LOG(ERROR) << "Failed to decompress LZ4 block, code: " << bytes_decompressed;
306             return -1;
307         }
308         if (bytes_decompressed != decompressed_size) {
309             LOG(ERROR) << "Failed to decompress LZ4 block, expected output size: "
310                        << bytes_decompressed << ", actual: " << bytes_decompressed;
311             return -1;
312         }
313         CHECK_LE(bytes_decompressed, decode_buffer_size);
314 
315         if (ignore_bytes > bytes_decompressed) {
316             LOG(ERROR) << "Ignoring more bytes than exist in stream (ignoring " << ignore_bytes
317                        << ", got " << bytes_decompressed << ")";
318             return -1;
319         }
320 
321         if (temp.empty()) {
322             // LZ4's API has no way to sink out the first N bytes of decoding,
323             // so we read them all in and memmove() to drop the partial read.
324             if (ignore_bytes) {
325                 memmove(decode_buffer, decode_buffer + ignore_bytes,
326                         bytes_decompressed - ignore_bytes);
327             }
328             return bytes_decompressed - ignore_bytes;
329         }
330 
331         size_t max_copy = std::min(bytes_decompressed - ignore_bytes, buffer_size);
332         memcpy(buffer, temp.data() + ignore_bytes, max_copy);
333         return max_copy;
334     }
335 };
336 
337 class ZstdDecompressor final : public IDecompressor {
338   public:
Decompress(void * buffer,size_t buffer_size,size_t decompressed_size,size_t ignore_bytes=0)339     ssize_t Decompress(void* buffer, size_t buffer_size, size_t decompressed_size,
340                        size_t ignore_bytes = 0) override {
341         if (buffer_size < decompressed_size - ignore_bytes) {
342             LOG(INFO) << "buffer size " << buffer_size
343                       << " is not large enough to hold decompressed data. Decompressed size "
344                       << decompressed_size << ", ignore_bytes " << ignore_bytes;
345             return -1;
346         }
347         if (ignore_bytes == 0) {
348             if (!Decompress(buffer, decompressed_size)) {
349                 return -1;
350             }
351             return decompressed_size;
352         }
353         std::vector<unsigned char> ignore_buf(decompressed_size);
354         if (!Decompress(ignore_buf.data(), decompressed_size)) {
355             return -1;
356         }
357         memcpy(buffer, ignore_buf.data() + ignore_bytes, buffer_size);
358         return decompressed_size;
359     }
Decompress(void * output_buffer,const size_t output_size)360     bool Decompress(void* output_buffer, const size_t output_size) {
361         std::string input_buffer;
362         input_buffer.resize(stream_->Size());
363         size_t bytes_read = stream_->Read(input_buffer.data(), input_buffer.size());
364         if (bytes_read != input_buffer.size()) {
365             LOG(ERROR) << "Failed to read all input at once. Expected: " << input_buffer.size()
366                        << " actual: " << bytes_read;
367             return false;
368         }
369         const auto bytes_decompressed = ZSTD_decompress(output_buffer, output_size,
370                                                         input_buffer.data(), input_buffer.size());
371         if (bytes_decompressed != output_size) {
372             LOG(ERROR) << "Failed to decompress ZSTD block, expected output size: " << output_size
373                        << ", actual: " << bytes_decompressed;
374             return false;
375         }
376         return true;
377     }
378 };
379 
Brotli()380 std::unique_ptr<IDecompressor> IDecompressor::Brotli() {
381     return std::make_unique<BrotliDecompressor>();
382 }
383 
Gz()384 std::unique_ptr<IDecompressor> IDecompressor::Gz() {
385     return std::make_unique<GzDecompressor>();
386 }
387 
Lz4()388 std::unique_ptr<IDecompressor> IDecompressor::Lz4() {
389     return std::make_unique<Lz4Decompressor>();
390 }
391 
Zstd()392 std::unique_ptr<IDecompressor> IDecompressor::Zstd() {
393     return std::make_unique<ZstdDecompressor>();
394 }
395 
396 }  // namespace snapshot
397 }  // namespace android
398