1 //
2 // Copyright (C) 2020 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16
#include "cow_decompress.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <string>
#include <utility>

#include <android-base/logging.h>
#include <brotli/decode.h>
#include <lz4.h>
#include <zlib.h>
25
26 namespace android {
27 namespace snapshot {
28
29 class NoDecompressor final : public IDecompressor {
30 public:
31 bool Decompress(size_t) override;
32 };
33
Decompress(size_t)34 bool NoDecompressor::Decompress(size_t) {
35 size_t stream_remaining = stream_->Size();
36 while (stream_remaining) {
37 size_t buffer_size = stream_remaining;
38 uint8_t* buffer = reinterpret_cast<uint8_t*>(sink_->GetBuffer(buffer_size, &buffer_size));
39 if (!buffer) {
40 LOG(ERROR) << "Could not acquire buffer from sink";
41 return false;
42 }
43
44 // Read until we can fill the buffer.
45 uint8_t* buffer_pos = buffer;
46 size_t bytes_to_read = std::min(buffer_size, stream_remaining);
47 while (bytes_to_read) {
48 size_t read;
49 if (!stream_->Read(buffer_pos, bytes_to_read, &read)) {
50 return false;
51 }
52 if (!read) {
53 LOG(ERROR) << "Stream ended prematurely";
54 return false;
55 }
56 if (!sink_->ReturnData(buffer_pos, read)) {
57 LOG(ERROR) << "Could not return buffer to sink";
58 return false;
59 }
60 buffer_pos += read;
61 bytes_to_read -= read;
62 stream_remaining -= read;
63 }
64 }
65 return true;
66 }
67
Uncompressed()68 std::unique_ptr<IDecompressor> IDecompressor::Uncompressed() {
69 return std::unique_ptr<IDecompressor>(new NoDecompressor());
70 }
71
72 // Read chunks of the COW and incrementally stream them to the decoder.
73 class StreamDecompressor : public IDecompressor {
74 public:
75 bool Decompress(size_t output_bytes) override;
76
77 virtual bool Init() = 0;
78 virtual bool DecompressInput(const uint8_t* data, size_t length) = 0;
79 virtual bool Done() = 0;
80
81 protected:
82 bool GetFreshBuffer();
83
84 size_t output_bytes_;
85 size_t stream_remaining_;
86 uint8_t* output_buffer_ = nullptr;
87 size_t output_buffer_remaining_ = 0;
88 };
89
// Granularity used both for reading compressed input and for sizing output buffer requests.
static constexpr size_t kChunkSize = 4 * 1024;
91
Decompress(size_t output_bytes)92 bool StreamDecompressor::Decompress(size_t output_bytes) {
93 if (!Init()) {
94 return false;
95 }
96
97 stream_remaining_ = stream_->Size();
98 output_bytes_ = output_bytes;
99
100 uint8_t chunk[kChunkSize];
101 while (stream_remaining_) {
102 size_t read = std::min(stream_remaining_, sizeof(chunk));
103 if (!stream_->Read(chunk, read, &read)) {
104 return false;
105 }
106 if (!read) {
107 LOG(ERROR) << "Stream ended prematurely";
108 return false;
109 }
110 if (!DecompressInput(chunk, read)) {
111 return false;
112 }
113
114 stream_remaining_ -= read;
115
116 if (stream_remaining_ && Done()) {
117 LOG(ERROR) << "Decompressor terminated early";
118 return false;
119 }
120 }
121 if (!Done()) {
122 LOG(ERROR) << "Decompressor expected more bytes";
123 return false;
124 }
125 return true;
126 }
127
GetFreshBuffer()128 bool StreamDecompressor::GetFreshBuffer() {
129 size_t request_size = std::min(output_bytes_, kChunkSize);
130 output_buffer_ =
131 reinterpret_cast<uint8_t*>(sink_->GetBuffer(request_size, &output_buffer_remaining_));
132 if (!output_buffer_) {
133 LOG(ERROR) << "Could not acquire buffer from sink";
134 return false;
135 }
136 return true;
137 }
138
139 class GzDecompressor final : public StreamDecompressor {
140 public:
141 ~GzDecompressor();
142
143 bool Init() override;
144 bool DecompressInput(const uint8_t* data, size_t length) override;
Done()145 bool Done() override { return ended_; }
146
147 private:
148 z_stream z_ = {};
149 bool ended_ = false;
150 };
151
Init()152 bool GzDecompressor::Init() {
153 if (int rv = inflateInit(&z_); rv != Z_OK) {
154 LOG(ERROR) << "inflateInit returned error code " << rv;
155 return false;
156 }
157 return true;
158 }
159
~GzDecompressor()160 GzDecompressor::~GzDecompressor() {
161 inflateEnd(&z_);
162 }
163
DecompressInput(const uint8_t * data,size_t length)164 bool GzDecompressor::DecompressInput(const uint8_t* data, size_t length) {
165 z_.next_in = reinterpret_cast<Bytef*>(const_cast<uint8_t*>(data));
166 z_.avail_in = length;
167
168 while (z_.avail_in) {
169 // If no more output buffer, grab a new buffer.
170 if (z_.avail_out == 0) {
171 if (!GetFreshBuffer()) {
172 return false;
173 }
174 z_.next_out = reinterpret_cast<Bytef*>(output_buffer_);
175 z_.avail_out = output_buffer_remaining_;
176 }
177
178 // Remember the position of the output buffer so we can call ReturnData.
179 auto avail_out = z_.avail_out;
180
181 // Decompress.
182 int rv = inflate(&z_, Z_NO_FLUSH);
183 if (rv != Z_OK && rv != Z_STREAM_END) {
184 LOG(ERROR) << "inflate returned error code " << rv;
185 return false;
186 }
187
188 size_t returned = avail_out - z_.avail_out;
189 if (!sink_->ReturnData(output_buffer_, returned)) {
190 LOG(ERROR) << "Could not return buffer to sink";
191 return false;
192 }
193 output_buffer_ += returned;
194 output_buffer_remaining_ -= returned;
195
196 if (rv == Z_STREAM_END) {
197 if (z_.avail_in) {
198 LOG(ERROR) << "Gz stream ended prematurely";
199 return false;
200 }
201 ended_ = true;
202 return true;
203 }
204 }
205 return true;
206 }
207
Gz()208 std::unique_ptr<IDecompressor> IDecompressor::Gz() {
209 return std::unique_ptr<IDecompressor>(new GzDecompressor());
210 }
211
212 class BrotliDecompressor final : public StreamDecompressor {
213 public:
214 ~BrotliDecompressor();
215
216 bool Init() override;
217 bool DecompressInput(const uint8_t* data, size_t length) override;
Done()218 bool Done() override { return BrotliDecoderIsFinished(decoder_); }
219
220 private:
221 BrotliDecoderState* decoder_ = nullptr;
222 };
223
Init()224 bool BrotliDecompressor::Init() {
225 decoder_ = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr);
226 return true;
227 }
228
~BrotliDecompressor()229 BrotliDecompressor::~BrotliDecompressor() {
230 if (decoder_) {
231 BrotliDecoderDestroyInstance(decoder_);
232 }
233 }
234
DecompressInput(const uint8_t * data,size_t length)235 bool BrotliDecompressor::DecompressInput(const uint8_t* data, size_t length) {
236 size_t available_in = length;
237 const uint8_t* next_in = data;
238
239 bool needs_more_output = false;
240 while (available_in || needs_more_output) {
241 if (!output_buffer_remaining_ && !GetFreshBuffer()) {
242 return false;
243 }
244
245 auto output_buffer = output_buffer_;
246 auto r = BrotliDecoderDecompressStream(decoder_, &available_in, &next_in,
247 &output_buffer_remaining_, &output_buffer_, nullptr);
248 if (r == BROTLI_DECODER_RESULT_ERROR) {
249 LOG(ERROR) << "brotli decode failed";
250 return false;
251 }
252 if (!sink_->ReturnData(output_buffer, output_buffer_ - output_buffer)) {
253 return false;
254 }
255 needs_more_output = (r == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT);
256 }
257 return true;
258 }
259
Brotli()260 std::unique_ptr<IDecompressor> IDecompressor::Brotli() {
261 return std::unique_ptr<IDecompressor>(new BrotliDecompressor());
262 }
263
264 class Lz4Decompressor final : public IDecompressor {
265 public:
266 ~Lz4Decompressor() override = default;
267
Decompress(const size_t output_size)268 bool Decompress(const size_t output_size) override {
269 size_t actual_buffer_size = 0;
270 auto&& output_buffer = sink_->GetBuffer(output_size, &actual_buffer_size);
271 if (actual_buffer_size != output_size) {
272 LOG(ERROR) << "Failed to allocate buffer of size " << output_size << " only got "
273 << actual_buffer_size << " bytes";
274 return false;
275 }
276 // If input size is same as output size, then input is uncompressed.
277 if (stream_->Size() == output_size) {
278 size_t bytes_read = 0;
279 stream_->Read(output_buffer, output_size, &bytes_read);
280 if (bytes_read != output_size) {
281 LOG(ERROR) << "Failed to read all input at once. Expected: " << output_size
282 << " actual: " << bytes_read;
283 return false;
284 }
285 sink_->ReturnData(output_buffer, output_size);
286 return true;
287 }
288 std::string input_buffer;
289 input_buffer.resize(stream_->Size());
290 size_t bytes_read = 0;
291 stream_->Read(input_buffer.data(), input_buffer.size(), &bytes_read);
292 if (bytes_read != input_buffer.size()) {
293 LOG(ERROR) << "Failed to read all input at once. Expected: " << input_buffer.size()
294 << " actual: " << bytes_read;
295 return false;
296 }
297 const int bytes_decompressed =
298 LZ4_decompress_safe(input_buffer.data(), static_cast<char*>(output_buffer),
299 input_buffer.size(), output_size);
300 if (bytes_decompressed != output_size) {
301 LOG(ERROR) << "Failed to decompress LZ4 block, expected output size: " << output_size
302 << ", actual: " << bytes_decompressed;
303 return false;
304 }
305 sink_->ReturnData(output_buffer, output_size);
306 return true;
307 }
308 };
309
Lz4()310 std::unique_ptr<IDecompressor> IDecompressor::Lz4() {
311 return std::make_unique<Lz4Decompressor>();
312 }
313
314 } // namespace snapshot
315 } // namespace android
316