• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright (C) 2020 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
17 #include "cow_decompress.h"
18 
19 #include <utility>
20 
21 #include <android-base/logging.h>
22 #include <brotli/decode.h>
23 #include <lz4.h>
24 #include <zlib.h>
25 
26 namespace android {
27 namespace snapshot {
28 
29 class NoDecompressor final : public IDecompressor {
30   public:
31     bool Decompress(size_t) override;
32 };
33 
Decompress(size_t)34 bool NoDecompressor::Decompress(size_t) {
35     size_t stream_remaining = stream_->Size();
36     while (stream_remaining) {
37         size_t buffer_size = stream_remaining;
38         uint8_t* buffer = reinterpret_cast<uint8_t*>(sink_->GetBuffer(buffer_size, &buffer_size));
39         if (!buffer) {
40             LOG(ERROR) << "Could not acquire buffer from sink";
41             return false;
42         }
43 
44         // Read until we can fill the buffer.
45         uint8_t* buffer_pos = buffer;
46         size_t bytes_to_read = std::min(buffer_size, stream_remaining);
47         while (bytes_to_read) {
48             size_t read;
49             if (!stream_->Read(buffer_pos, bytes_to_read, &read)) {
50                 return false;
51             }
52             if (!read) {
53                 LOG(ERROR) << "Stream ended prematurely";
54                 return false;
55             }
56             if (!sink_->ReturnData(buffer_pos, read)) {
57                 LOG(ERROR) << "Could not return buffer to sink";
58                 return false;
59             }
60             buffer_pos += read;
61             bytes_to_read -= read;
62             stream_remaining -= read;
63         }
64     }
65     return true;
66 }
67 
Uncompressed()68 std::unique_ptr<IDecompressor> IDecompressor::Uncompressed() {
69     return std::unique_ptr<IDecompressor>(new NoDecompressor());
70 }
71 
72 // Read chunks of the COW and incrementally stream them to the decoder.
73 class StreamDecompressor : public IDecompressor {
74   public:
75     bool Decompress(size_t output_bytes) override;
76 
77     virtual bool Init() = 0;
78     virtual bool DecompressInput(const uint8_t* data, size_t length) = 0;
79     virtual bool Done() = 0;
80 
81   protected:
82     bool GetFreshBuffer();
83 
84     size_t output_bytes_;
85     size_t stream_remaining_;
86     uint8_t* output_buffer_ = nullptr;
87     size_t output_buffer_remaining_ = 0;
88 };
89 
// Size of the on-stack buffer used to stage compressed input, and the upper
// bound on each output-buffer request made to the sink (4 KiB).
static constexpr size_t kChunkSize = 4 * 1024;
91 
Decompress(size_t output_bytes)92 bool StreamDecompressor::Decompress(size_t output_bytes) {
93     if (!Init()) {
94         return false;
95     }
96 
97     stream_remaining_ = stream_->Size();
98     output_bytes_ = output_bytes;
99 
100     uint8_t chunk[kChunkSize];
101     while (stream_remaining_) {
102         size_t read = std::min(stream_remaining_, sizeof(chunk));
103         if (!stream_->Read(chunk, read, &read)) {
104             return false;
105         }
106         if (!read) {
107             LOG(ERROR) << "Stream ended prematurely";
108             return false;
109         }
110         if (!DecompressInput(chunk, read)) {
111             return false;
112         }
113 
114         stream_remaining_ -= read;
115 
116         if (stream_remaining_ && Done()) {
117             LOG(ERROR) << "Decompressor terminated early";
118             return false;
119         }
120     }
121     if (!Done()) {
122         LOG(ERROR) << "Decompressor expected more bytes";
123         return false;
124     }
125     return true;
126 }
127 
GetFreshBuffer()128 bool StreamDecompressor::GetFreshBuffer() {
129     size_t request_size = std::min(output_bytes_, kChunkSize);
130     output_buffer_ =
131             reinterpret_cast<uint8_t*>(sink_->GetBuffer(request_size, &output_buffer_remaining_));
132     if (!output_buffer_) {
133         LOG(ERROR) << "Could not acquire buffer from sink";
134         return false;
135     }
136     return true;
137 }
138 
139 class GzDecompressor final : public StreamDecompressor {
140   public:
141     ~GzDecompressor();
142 
143     bool Init() override;
144     bool DecompressInput(const uint8_t* data, size_t length) override;
Done()145     bool Done() override { return ended_; }
146 
147   private:
148     z_stream z_ = {};
149     bool ended_ = false;
150 };
151 
Init()152 bool GzDecompressor::Init() {
153     if (int rv = inflateInit(&z_); rv != Z_OK) {
154         LOG(ERROR) << "inflateInit returned error code " << rv;
155         return false;
156     }
157     return true;
158 }
159 
~GzDecompressor()160 GzDecompressor::~GzDecompressor() {
161     inflateEnd(&z_);
162 }
163 
DecompressInput(const uint8_t * data,size_t length)164 bool GzDecompressor::DecompressInput(const uint8_t* data, size_t length) {
165     z_.next_in = reinterpret_cast<Bytef*>(const_cast<uint8_t*>(data));
166     z_.avail_in = length;
167 
168     while (z_.avail_in) {
169         // If no more output buffer, grab a new buffer.
170         if (z_.avail_out == 0) {
171             if (!GetFreshBuffer()) {
172                 return false;
173             }
174             z_.next_out = reinterpret_cast<Bytef*>(output_buffer_);
175             z_.avail_out = output_buffer_remaining_;
176         }
177 
178         // Remember the position of the output buffer so we can call ReturnData.
179         auto avail_out = z_.avail_out;
180 
181         // Decompress.
182         int rv = inflate(&z_, Z_NO_FLUSH);
183         if (rv != Z_OK && rv != Z_STREAM_END) {
184             LOG(ERROR) << "inflate returned error code " << rv;
185             return false;
186         }
187 
188         size_t returned = avail_out - z_.avail_out;
189         if (!sink_->ReturnData(output_buffer_, returned)) {
190             LOG(ERROR) << "Could not return buffer to sink";
191             return false;
192         }
193         output_buffer_ += returned;
194         output_buffer_remaining_ -= returned;
195 
196         if (rv == Z_STREAM_END) {
197             if (z_.avail_in) {
198                 LOG(ERROR) << "Gz stream ended prematurely";
199                 return false;
200             }
201             ended_ = true;
202             return true;
203         }
204     }
205     return true;
206 }
207 
Gz()208 std::unique_ptr<IDecompressor> IDecompressor::Gz() {
209     return std::unique_ptr<IDecompressor>(new GzDecompressor());
210 }
211 
212 class BrotliDecompressor final : public StreamDecompressor {
213   public:
214     ~BrotliDecompressor();
215 
216     bool Init() override;
217     bool DecompressInput(const uint8_t* data, size_t length) override;
Done()218     bool Done() override { return BrotliDecoderIsFinished(decoder_); }
219 
220   private:
221     BrotliDecoderState* decoder_ = nullptr;
222 };
223 
Init()224 bool BrotliDecompressor::Init() {
225     decoder_ = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr);
226     return true;
227 }
228 
~BrotliDecompressor()229 BrotliDecompressor::~BrotliDecompressor() {
230     if (decoder_) {
231         BrotliDecoderDestroyInstance(decoder_);
232     }
233 }
234 
DecompressInput(const uint8_t * data,size_t length)235 bool BrotliDecompressor::DecompressInput(const uint8_t* data, size_t length) {
236     size_t available_in = length;
237     const uint8_t* next_in = data;
238 
239     bool needs_more_output = false;
240     while (available_in || needs_more_output) {
241         if (!output_buffer_remaining_ && !GetFreshBuffer()) {
242             return false;
243         }
244 
245         auto output_buffer = output_buffer_;
246         auto r = BrotliDecoderDecompressStream(decoder_, &available_in, &next_in,
247                                                &output_buffer_remaining_, &output_buffer_, nullptr);
248         if (r == BROTLI_DECODER_RESULT_ERROR) {
249             LOG(ERROR) << "brotli decode failed";
250             return false;
251         }
252         if (!sink_->ReturnData(output_buffer, output_buffer_ - output_buffer)) {
253             return false;
254         }
255         needs_more_output = (r == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT);
256     }
257     return true;
258 }
259 
Brotli()260 std::unique_ptr<IDecompressor> IDecompressor::Brotli() {
261     return std::unique_ptr<IDecompressor>(new BrotliDecompressor());
262 }
263 
264 class Lz4Decompressor final : public IDecompressor {
265   public:
266     ~Lz4Decompressor() override = default;
267 
Decompress(const size_t output_size)268     bool Decompress(const size_t output_size) override {
269         size_t actual_buffer_size = 0;
270         auto&& output_buffer = sink_->GetBuffer(output_size, &actual_buffer_size);
271         if (actual_buffer_size != output_size) {
272             LOG(ERROR) << "Failed to allocate buffer of size " << output_size << " only got "
273                        << actual_buffer_size << " bytes";
274             return false;
275         }
276         // If input size is same as output size, then input is uncompressed.
277         if (stream_->Size() == output_size) {
278             size_t bytes_read = 0;
279             stream_->Read(output_buffer, output_size, &bytes_read);
280             if (bytes_read != output_size) {
281                 LOG(ERROR) << "Failed to read all input at once. Expected: " << output_size
282                            << " actual: " << bytes_read;
283                 return false;
284             }
285             sink_->ReturnData(output_buffer, output_size);
286             return true;
287         }
288         std::string input_buffer;
289         input_buffer.resize(stream_->Size());
290         size_t bytes_read = 0;
291         stream_->Read(input_buffer.data(), input_buffer.size(), &bytes_read);
292         if (bytes_read != input_buffer.size()) {
293             LOG(ERROR) << "Failed to read all input at once. Expected: " << input_buffer.size()
294                        << " actual: " << bytes_read;
295             return false;
296         }
297         const int bytes_decompressed =
298                 LZ4_decompress_safe(input_buffer.data(), static_cast<char*>(output_buffer),
299                                     input_buffer.size(), output_size);
300         if (bytes_decompressed != output_size) {
301             LOG(ERROR) << "Failed to decompress LZ4 block, expected output size: " << output_size
302                        << ", actual: " << bytes_decompressed;
303             return false;
304         }
305         sink_->ReturnData(output_buffer, output_size);
306         return true;
307     }
308 };
309 
Lz4()310 std::unique_ptr<IDecompressor> IDecompressor::Lz4() {
311     return std::make_unique<Lz4Decompressor>();
312 }
313 
314 }  // namespace snapshot
315 }  // namespace android
316