1 //
2 // Copyright (C) 2020 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16
#include "cow_decompress.h"

#include <algorithm>
#include <memory>
#include <utility>

#include <android-base/logging.h>
#include <brotli/decode.h>
#include <zlib.h>
24
25 namespace android {
26 namespace snapshot {
27
28 class NoDecompressor final : public IDecompressor {
29 public:
30 bool Decompress(size_t) override;
31 };
32
Decompress(size_t)33 bool NoDecompressor::Decompress(size_t) {
34 size_t stream_remaining = stream_->Size();
35 while (stream_remaining) {
36 size_t buffer_size = stream_remaining;
37 uint8_t* buffer = reinterpret_cast<uint8_t*>(sink_->GetBuffer(buffer_size, &buffer_size));
38 if (!buffer) {
39 LOG(ERROR) << "Could not acquire buffer from sink";
40 return false;
41 }
42
43 // Read until we can fill the buffer.
44 uint8_t* buffer_pos = buffer;
45 size_t bytes_to_read = std::min(buffer_size, stream_remaining);
46 while (bytes_to_read) {
47 size_t read;
48 if (!stream_->Read(buffer_pos, bytes_to_read, &read)) {
49 return false;
50 }
51 if (!read) {
52 LOG(ERROR) << "Stream ended prematurely";
53 return false;
54 }
55 if (!sink_->ReturnData(buffer_pos, read)) {
56 LOG(ERROR) << "Could not return buffer to sink";
57 return false;
58 }
59 buffer_pos += read;
60 bytes_to_read -= read;
61 stream_remaining -= read;
62 }
63 }
64 return true;
65 }
66
Uncompressed()67 std::unique_ptr<IDecompressor> IDecompressor::Uncompressed() {
68 return std::unique_ptr<IDecompressor>(new NoDecompressor());
69 }
70
71 // Read chunks of the COW and incrementally stream them to the decoder.
72 class StreamDecompressor : public IDecompressor {
73 public:
74 bool Decompress(size_t output_bytes) override;
75
76 virtual bool Init() = 0;
77 virtual bool DecompressInput(const uint8_t* data, size_t length) = 0;
78 virtual bool Done() = 0;
79
80 protected:
81 bool GetFreshBuffer();
82
83 size_t output_bytes_;
84 size_t stream_remaining_;
85 uint8_t* output_buffer_ = nullptr;
86 size_t output_buffer_remaining_ = 0;
87 };
88
// Size in bytes of the input chunk buffer used by StreamDecompressor::Decompress()
// and the cap on the buffer size requested in GetFreshBuffer().
static constexpr size_t kChunkSize = 4 * 1024;
90
Decompress(size_t output_bytes)91 bool StreamDecompressor::Decompress(size_t output_bytes) {
92 if (!Init()) {
93 return false;
94 }
95
96 stream_remaining_ = stream_->Size();
97 output_bytes_ = output_bytes;
98
99 uint8_t chunk[kChunkSize];
100 while (stream_remaining_) {
101 size_t read = std::min(stream_remaining_, sizeof(chunk));
102 if (!stream_->Read(chunk, read, &read)) {
103 return false;
104 }
105 if (!read) {
106 LOG(ERROR) << "Stream ended prematurely";
107 return false;
108 }
109 if (!DecompressInput(chunk, read)) {
110 return false;
111 }
112
113 stream_remaining_ -= read;
114
115 if (stream_remaining_ && Done()) {
116 LOG(ERROR) << "Decompressor terminated early";
117 return false;
118 }
119 }
120 if (!Done()) {
121 LOG(ERROR) << "Decompressor expected more bytes";
122 return false;
123 }
124 return true;
125 }
126
GetFreshBuffer()127 bool StreamDecompressor::GetFreshBuffer() {
128 size_t request_size = std::min(output_bytes_, kChunkSize);
129 output_buffer_ =
130 reinterpret_cast<uint8_t*>(sink_->GetBuffer(request_size, &output_buffer_remaining_));
131 if (!output_buffer_) {
132 LOG(ERROR) << "Could not acquire buffer from sink";
133 return false;
134 }
135 return true;
136 }
137
138 class GzDecompressor final : public StreamDecompressor {
139 public:
140 ~GzDecompressor();
141
142 bool Init() override;
143 bool DecompressInput(const uint8_t* data, size_t length) override;
Done()144 bool Done() override { return ended_; }
145
146 private:
147 z_stream z_ = {};
148 bool ended_ = false;
149 };
150
Init()151 bool GzDecompressor::Init() {
152 if (int rv = inflateInit(&z_); rv != Z_OK) {
153 LOG(ERROR) << "inflateInit returned error code " << rv;
154 return false;
155 }
156 return true;
157 }
158
~GzDecompressor()159 GzDecompressor::~GzDecompressor() {
160 inflateEnd(&z_);
161 }
162
DecompressInput(const uint8_t * data,size_t length)163 bool GzDecompressor::DecompressInput(const uint8_t* data, size_t length) {
164 z_.next_in = reinterpret_cast<Bytef*>(const_cast<uint8_t*>(data));
165 z_.avail_in = length;
166
167 while (z_.avail_in) {
168 // If no more output buffer, grab a new buffer.
169 if (z_.avail_out == 0) {
170 if (!GetFreshBuffer()) {
171 return false;
172 }
173 z_.next_out = reinterpret_cast<Bytef*>(output_buffer_);
174 z_.avail_out = output_buffer_remaining_;
175 }
176
177 // Remember the position of the output buffer so we can call ReturnData.
178 auto avail_out = z_.avail_out;
179
180 // Decompress.
181 int rv = inflate(&z_, Z_NO_FLUSH);
182 if (rv != Z_OK && rv != Z_STREAM_END) {
183 LOG(ERROR) << "inflate returned error code " << rv;
184 return false;
185 }
186
187 size_t returned = avail_out - z_.avail_out;
188 if (!sink_->ReturnData(output_buffer_, returned)) {
189 LOG(ERROR) << "Could not return buffer to sink";
190 return false;
191 }
192 output_buffer_ += returned;
193 output_buffer_remaining_ -= returned;
194
195 if (rv == Z_STREAM_END) {
196 if (z_.avail_in) {
197 LOG(ERROR) << "Gz stream ended prematurely";
198 return false;
199 }
200 ended_ = true;
201 return true;
202 }
203 }
204 return true;
205 }
206
Gz()207 std::unique_ptr<IDecompressor> IDecompressor::Gz() {
208 return std::unique_ptr<IDecompressor>(new GzDecompressor());
209 }
210
211 class BrotliDecompressor final : public StreamDecompressor {
212 public:
213 ~BrotliDecompressor();
214
215 bool Init() override;
216 bool DecompressInput(const uint8_t* data, size_t length) override;
Done()217 bool Done() override { return BrotliDecoderIsFinished(decoder_); }
218
219 private:
220 BrotliDecoderState* decoder_ = nullptr;
221 };
222
Init()223 bool BrotliDecompressor::Init() {
224 decoder_ = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr);
225 return true;
226 }
227
~BrotliDecompressor()228 BrotliDecompressor::~BrotliDecompressor() {
229 if (decoder_) {
230 BrotliDecoderDestroyInstance(decoder_);
231 }
232 }
233
DecompressInput(const uint8_t * data,size_t length)234 bool BrotliDecompressor::DecompressInput(const uint8_t* data, size_t length) {
235 size_t available_in = length;
236 const uint8_t* next_in = data;
237
238 bool needs_more_output = false;
239 while (available_in || needs_more_output) {
240 if (!output_buffer_remaining_ && !GetFreshBuffer()) {
241 return false;
242 }
243
244 auto output_buffer = output_buffer_;
245 auto r = BrotliDecoderDecompressStream(decoder_, &available_in, &next_in,
246 &output_buffer_remaining_, &output_buffer_, nullptr);
247 if (r == BROTLI_DECODER_RESULT_ERROR) {
248 LOG(ERROR) << "brotli decode failed";
249 return false;
250 }
251 if (!sink_->ReturnData(output_buffer, output_buffer_ - output_buffer)) {
252 return false;
253 }
254 needs_more_output = (r == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT);
255 }
256 return true;
257 }
258
Brotli()259 std::unique_ptr<IDecompressor> IDecompressor::Brotli() {
260 return std::unique_ptr<IDecompressor>(new BrotliDecompressor());
261 }
262
263 } // namespace snapshot
264 } // namespace android
265