1 /*
2 * Copyright (c) 2016-2020, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under both the BSD-style license (found in the
6 * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7 * in the COPYING file in the root directory of this source tree).
8 * You may select, at your option, one of the above-listed licenses.
9 */
10
11 #include "method.h"
12
13 #include <stdio.h>
14 #include <stdlib.h>
15
16 #define ZSTD_STATIC_LINKING_ONLY
17 #include <zstd.h>
18
#define MIN(x, y) ((x) < (y) ? (x) : (y))

/** Path of the zstd CLI binary used by cli_compress(); NULL until set. */
static char const* g_zstdcli = NULL;

/**
 * Records the path of the zstd command-line binary used by cli_compress().
 * The pointer is stored as-is (not copied), so it must stay valid while the
 * "zstdcli" method runs.
 */
void method_set_zstdcli(char const* zstdcli) {
    char const* const path = zstdcli;
    g_zstdcli = path;
}
26
/**
 * Macro to get a pointer of type, given ptr, which is a member variable with
 * the given name, member. A NULL ptr yields NULL.
 *
 *   method_state_t* base = ...;
 *   buffer_state_t* state = container_of(base, buffer_state_t, base);
 *
 * Every use of ptr is parenthesized so arbitrary expressions (e.g. a
 * conditional) expand correctly. Note that ptr is evaluated twice, so it must
 * not have side effects.
 */
#define container_of(ptr, type, member) \
    ((type*)((ptr) == NULL ? NULL : (char*)(ptr) - offsetof(type, member)))
36
/**
 * State to reuse the same buffers between compression calls.
 *
 * Embeds method_state_t as `base`: the methods in this file receive a
 * method_state_t* and recover the enclosing buffer_state_t with container_of().
 */
typedef struct {
    method_state_t base;
    data_buffers_t inputs; /**< The input buffer for each file. */
    data_buffer_t dictionary; /**< The dictionary (from data_buffer_get_dict). */
    data_buffer_t compressed; /**< Compressed data; ZSTD_compressBound(largest input) bytes. */
    data_buffer_t decompressed; /**< Decompressed data; sized to the largest input. */
} buffer_state_t;
45
/** Returns the size of the largest buffer in `buffers`, or 0 when it is empty. */
static size_t buffers_max_size(data_buffers_t buffers) {
    size_t largest = 0;
    size_t idx = buffers.size;
    while (idx-- > 0) {
        size_t const candidate = buffers.buffers[idx].size;
        if (candidate > largest)
            largest = candidate;
    }
    return largest;
}
54
buffer_state_create(data_t const * data)55 static method_state_t* buffer_state_create(data_t const* data) {
56 buffer_state_t* state = (buffer_state_t*)calloc(1, sizeof(buffer_state_t));
57 if (state == NULL)
58 return NULL;
59 state->base.data = data;
60 state->inputs = data_buffers_get(data);
61 state->dictionary = data_buffer_get_dict(data);
62 size_t const max_size = buffers_max_size(state->inputs);
63 state->compressed = data_buffer_create(ZSTD_compressBound(max_size));
64 state->decompressed = data_buffer_create(max_size);
65 return &state->base;
66 }
67
/**
 * Frees a state created by buffer_state_create(); a NULL base is ignored.
 *
 * NOTE(review): only the buffer_state_t struct is freed. The inputs,
 * dictionary, compressed and decompressed buffers it holds are not released
 * here and appear to leak. Confirm whether data.h offers a way to free them.
 */
static void buffer_state_destroy(method_state_t* base) {
    if (base != NULL) {
        free(container_of(base, buffer_state_t, base));
    }
}
74
/**
 * Validates a buffer state before use. On the first problem found, it prints
 * a diagnostic to stderr.
 *
 * Returns 1 in any of these cases, otherwise 0:
 *   - state is NULL;
 *   - state has no inputs;
 *   - the compressed or decompressed buffer is missing;
 *   - config requires a dictionary that is missing.
 */
static int buffer_state_bad(
    buffer_state_t const* state,
    config_t const* config) {
    char const* problem = NULL;

    if (state == NULL)
        problem = "buffer_state_t is NULL\n";
    else if (state->inputs.size == 0 || state->compressed.data == NULL ||
             state->decompressed.data == NULL)
        problem = "buffer state allocation failure\n";
    else if (config->use_dictionary && state->dictionary.data == NULL)
        problem = "dictionary loading failed\n";

    if (problem == NULL)
        return 0;
    fputs(problem, stderr);
    return 1;
}
93
simple_compress(method_state_t * base,config_t const * config)94 static result_t simple_compress(method_state_t* base, config_t const* config) {
95 buffer_state_t* state = container_of(base, buffer_state_t, base);
96
97 if (buffer_state_bad(state, config))
98 return result_error(result_error_system_error);
99
100 /* Keep the tests short by skipping directories, since behavior shouldn't
101 * change.
102 */
103 if (base->data->type != data_type_file)
104 return result_error(result_error_skip);
105
106 if (config->use_dictionary || config->no_pledged_src_size)
107 return result_error(result_error_skip);
108
109 /* If the config doesn't specify a level, skip. */
110 int const level = config_get_level(config);
111 if (level == CONFIG_NO_LEVEL)
112 return result_error(result_error_skip);
113
114 data_buffer_t const input = state->inputs.buffers[0];
115
116 /* Compress, decompress, and check the result. */
117 state->compressed.size = ZSTD_compress(
118 state->compressed.data,
119 state->compressed.capacity,
120 input.data,
121 input.size,
122 level);
123 if (ZSTD_isError(state->compressed.size))
124 return result_error(result_error_compression_error);
125
126 state->decompressed.size = ZSTD_decompress(
127 state->decompressed.data,
128 state->decompressed.capacity,
129 state->compressed.data,
130 state->compressed.size);
131 if (ZSTD_isError(state->decompressed.size))
132 return result_error(result_error_decompression_error);
133 if (data_buffer_compare(input, state->decompressed))
134 return result_error(result_error_round_trip_error);
135
136 result_data_t data;
137 data.total_size = state->compressed.size;
138 return result_data(data);
139 }
140
compress_cctx_compress(method_state_t * base,config_t const * config)141 static result_t compress_cctx_compress(
142 method_state_t* base,
143 config_t const* config) {
144 buffer_state_t* state = container_of(base, buffer_state_t, base);
145
146 if (buffer_state_bad(state, config))
147 return result_error(result_error_system_error);
148
149 if (config->no_pledged_src_size)
150 return result_error(result_error_skip);
151
152 if (base->data->type != data_type_dir)
153 return result_error(result_error_skip);
154
155 int const level = config_get_level(config);
156
157 ZSTD_CCtx* cctx = ZSTD_createCCtx();
158 ZSTD_DCtx* dctx = ZSTD_createDCtx();
159 if (cctx == NULL || dctx == NULL) {
160 fprintf(stderr, "context creation failed\n");
161 return result_error(result_error_system_error);
162 }
163
164 result_t result;
165 result_data_t data = {.total_size = 0};
166 for (size_t i = 0; i < state->inputs.size; ++i) {
167 data_buffer_t const input = state->inputs.buffers[i];
168 ZSTD_parameters const params =
169 config_get_zstd_params(config, input.size, state->dictionary.size);
170
171 if (level == CONFIG_NO_LEVEL)
172 state->compressed.size = ZSTD_compress_advanced(
173 cctx,
174 state->compressed.data,
175 state->compressed.capacity,
176 input.data,
177 input.size,
178 config->use_dictionary ? state->dictionary.data : NULL,
179 config->use_dictionary ? state->dictionary.size : 0,
180 params);
181 else if (config->use_dictionary)
182 state->compressed.size = ZSTD_compress_usingDict(
183 cctx,
184 state->compressed.data,
185 state->compressed.capacity,
186 input.data,
187 input.size,
188 state->dictionary.data,
189 state->dictionary.size,
190 level);
191 else
192 state->compressed.size = ZSTD_compressCCtx(
193 cctx,
194 state->compressed.data,
195 state->compressed.capacity,
196 input.data,
197 input.size,
198 level);
199
200 if (ZSTD_isError(state->compressed.size)) {
201 result = result_error(result_error_compression_error);
202 goto out;
203 }
204
205 if (config->use_dictionary)
206 state->decompressed.size = ZSTD_decompress_usingDict(
207 dctx,
208 state->decompressed.data,
209 state->decompressed.capacity,
210 state->compressed.data,
211 state->compressed.size,
212 state->dictionary.data,
213 state->dictionary.size);
214 else
215 state->decompressed.size = ZSTD_decompressDCtx(
216 dctx,
217 state->decompressed.data,
218 state->decompressed.capacity,
219 state->compressed.data,
220 state->compressed.size);
221 if (ZSTD_isError(state->decompressed.size)) {
222 result = result_error(result_error_decompression_error);
223 goto out;
224 }
225 if (data_buffer_compare(input, state->decompressed)) {
226 result = result_error(result_error_round_trip_error);
227 goto out;
228 }
229
230 data.total_size += state->compressed.size;
231 }
232
233 result = result_data(data);
234 out:
235 ZSTD_freeCCtx(cctx);
236 ZSTD_freeDCtx(dctx);
237 return result;
238 }
239
240 /** Generic state creation function. */
method_state_create(data_t const * data)241 static method_state_t* method_state_create(data_t const* data) {
242 method_state_t* state = (method_state_t*)malloc(sizeof(method_state_t));
243 if (state == NULL)
244 return NULL;
245 state->data = data;
246 return state;
247 }
248
/** Releases a state from method_state_create(); a NULL state is a no-op. */
static void method_state_destroy(method_state_t* state) {
    void* const memory = state;
    free(memory);
}
252
cli_compress(method_state_t * state,config_t const * config)253 static result_t cli_compress(method_state_t* state, config_t const* config) {
254 if (config->cli_args == NULL)
255 return result_error(result_error_skip);
256
257 /* We don't support no pledged source size with directories. Too slow. */
258 if (state->data->type == data_type_dir && config->no_pledged_src_size)
259 return result_error(result_error_skip);
260
261 if (g_zstdcli == NULL)
262 return result_error(result_error_system_error);
263
264 /* '<zstd>' -cqr <args> [-D '<dict>'] '<file/dir>' */
265 char cmd[1024];
266 size_t const cmd_size = snprintf(
267 cmd,
268 sizeof(cmd),
269 "'%s' -cqr %s %s%s%s %s '%s'",
270 g_zstdcli,
271 config->cli_args,
272 config->use_dictionary ? "-D '" : "",
273 config->use_dictionary ? state->data->dict.path : "",
274 config->use_dictionary ? "'" : "",
275 config->no_pledged_src_size ? "<" : "",
276 state->data->data.path);
277 if (cmd_size >= sizeof(cmd)) {
278 fprintf(stderr, "command too large: %s\n", cmd);
279 return result_error(result_error_system_error);
280 }
281 FILE* zstd = popen(cmd, "r");
282 if (zstd == NULL) {
283 fprintf(stderr, "failed to popen command: %s\n", cmd);
284 return result_error(result_error_system_error);
285 }
286
287 char out[4096];
288 size_t total_size = 0;
289 while (1) {
290 size_t const size = fread(out, 1, sizeof(out), zstd);
291 total_size += size;
292 if (size != sizeof(out))
293 break;
294 }
295 if (ferror(zstd) || pclose(zstd) != 0) {
296 fprintf(stderr, "zstd failed with command: %s\n", cmd);
297 return result_error(result_error_compression_error);
298 }
299
300 result_data_t const data = {.total_size = total_size};
301 return result_data(data);
302 }
303
/**
 * Resets cctx (session and parameters), applies every advanced parameter
 * listed in config->param_values, and loads the state's dictionary when
 * config->use_dictionary is set.
 *
 * Returns 0 on success, or 1 if the reset, any ZSTD_CCtx_setParameter() call,
 * or the dictionary load fails.
 */
static int advanced_config(
    ZSTD_CCtx* cctx,
    buffer_state_t* state,
    config_t const* config) {
    /* Start from defaults so parameters from a previous config don't leak in.
     * The return value is checked like every other zstd call here. */
    if (ZSTD_isError(ZSTD_CCtx_reset(cctx, ZSTD_reset_session_and_parameters)))
        return 1;

    for (size_t p = 0; p < config->param_values.size; ++p) {
        param_value_t const pv = config->param_values.data[p];
        if (ZSTD_isError(ZSTD_CCtx_setParameter(cctx, pv.param, pv.value))) {
            return 1;
        }
    }
    if (config->use_dictionary) {
        if (ZSTD_isError(ZSTD_CCtx_loadDictionary(
                cctx, state->dictionary.data, state->dictionary.size))) {
            return 1;
        }
    }
    return 0;
}
323
advanced_one_pass_compress_output_adjustment(method_state_t * base,config_t const * config,size_t const subtract)324 static result_t advanced_one_pass_compress_output_adjustment(
325 method_state_t* base,
326 config_t const* config,
327 size_t const subtract) {
328 buffer_state_t* state = container_of(base, buffer_state_t, base);
329
330 if (buffer_state_bad(state, config))
331 return result_error(result_error_system_error);
332
333 ZSTD_CCtx* cctx = ZSTD_createCCtx();
334 result_t result;
335
336 if (!cctx || advanced_config(cctx, state, config)) {
337 result = result_error(result_error_compression_error);
338 goto out;
339 }
340
341 result_data_t data = {.total_size = 0};
342 for (size_t i = 0; i < state->inputs.size; ++i) {
343 data_buffer_t const input = state->inputs.buffers[i];
344
345 if (!config->no_pledged_src_size) {
346 if (ZSTD_isError(ZSTD_CCtx_setPledgedSrcSize(cctx, input.size))) {
347 result = result_error(result_error_compression_error);
348 goto out;
349 }
350 }
351 size_t const size = ZSTD_compress2(
352 cctx,
353 state->compressed.data,
354 ZSTD_compressBound(input.size) - subtract,
355 input.data,
356 input.size);
357 if (ZSTD_isError(size)) {
358 result = result_error(result_error_compression_error);
359 goto out;
360 }
361 data.total_size += size;
362 }
363
364 result = result_data(data);
365 out:
366 ZSTD_freeCCtx(cctx);
367 return result;
368 }
369
advanced_one_pass_compress(method_state_t * base,config_t const * config)370 static result_t advanced_one_pass_compress(
371 method_state_t* base,
372 config_t const* config) {
373 return advanced_one_pass_compress_output_adjustment(base, config, 0);
374 }
375
advanced_one_pass_compress_small_output(method_state_t * base,config_t const * config)376 static result_t advanced_one_pass_compress_small_output(
377 method_state_t* base,
378 config_t const* config) {
379 return advanced_one_pass_compress_output_adjustment(base, config, 1);
380 }
381
advanced_streaming_compress(method_state_t * base,config_t const * config)382 static result_t advanced_streaming_compress(
383 method_state_t* base,
384 config_t const* config) {
385 buffer_state_t* state = container_of(base, buffer_state_t, base);
386
387 if (buffer_state_bad(state, config))
388 return result_error(result_error_system_error);
389
390 ZSTD_CCtx* cctx = ZSTD_createCCtx();
391 result_t result;
392
393 if (!cctx || advanced_config(cctx, state, config)) {
394 result = result_error(result_error_compression_error);
395 goto out;
396 }
397
398 result_data_t data = {.total_size = 0};
399 for (size_t i = 0; i < state->inputs.size; ++i) {
400 data_buffer_t input = state->inputs.buffers[i];
401
402 if (!config->no_pledged_src_size) {
403 if (ZSTD_isError(ZSTD_CCtx_setPledgedSrcSize(cctx, input.size))) {
404 result = result_error(result_error_compression_error);
405 goto out;
406 }
407 }
408
409 while (input.size > 0) {
410 ZSTD_inBuffer in = {input.data, MIN(input.size, 4096)};
411 input.data += in.size;
412 input.size -= in.size;
413 ZSTD_EndDirective const op =
414 input.size > 0 ? ZSTD_e_continue : ZSTD_e_end;
415 size_t ret = 0;
416 while (in.pos < in.size || (op == ZSTD_e_end && ret != 0)) {
417 ZSTD_outBuffer out = {state->compressed.data,
418 MIN(state->compressed.capacity, 1024)};
419 ret = ZSTD_compressStream2(cctx, &out, &in, op);
420 if (ZSTD_isError(ret)) {
421 result = result_error(result_error_compression_error);
422 goto out;
423 }
424 data.total_size += out.pos;
425 }
426 }
427 }
428
429 result = result_data(data);
430 out:
431 ZSTD_freeCCtx(cctx);
432 return result;
433 }
434
/**
 * Initializes zcs through the old ZSTD_initCStream*() API according to config.
 *
 * @param state     supplies the dictionary when one is used.
 * @param zcs       stream to initialize.
 * @param config    compression configuration.
 * @param advanced  non-zero: parameters come from
 *                  config_get_zstd_params(config, 0, 0) and the content size
 *                  is left unknown. Zero: the config's compression level is
 *                  used, and the call fails if the config has no level.
 * @param cdict     if non-NULL, a ZSTD_CDict is built from the state's
 *                  dictionary (requires config->use_dictionary) and stored in
 *                  *cdict. *cdict stays set even if the subsequent init fails,
 *                  so the caller must free it either way. In the advanced
 *                  case the CDict references the dictionary bytes
 *                  (ZSTD_dlm_byRef), so the state must outlive it.
 *                  If NULL, the dictionary (when enabled) is passed to the
 *                  init call directly.
 *
 * @return 0 on success, 1 on any failure.
 */
static int init_cstream(
    buffer_state_t* state,
    ZSTD_CStream* zcs,
    config_t const* config,
    int const advanced,
    ZSTD_CDict** cdict)
{
    size_t zret;
    if (advanced) {
        ZSTD_parameters const params = config_get_zstd_params(config, 0, 0);
        if (cdict) {
            if (!config->use_dictionary)
                return 1;
            *cdict = ZSTD_createCDict_advanced(
                state->dictionary.data,
                state->dictionary.size,
                ZSTD_dlm_byRef,
                ZSTD_dct_auto,
                params.cParams,
                ZSTD_defaultCMem);
            if (!*cdict) {
                return 1;
            }
            zret = ZSTD_initCStream_usingCDict_advanced(
                zcs, *cdict, params.fParams, ZSTD_CONTENTSIZE_UNKNOWN);
        } else {
            zret = ZSTD_initCStream_advanced(
                zcs,
                config->use_dictionary ? state->dictionary.data : NULL,
                config->use_dictionary ? state->dictionary.size : 0,
                params,
                ZSTD_CONTENTSIZE_UNKNOWN);
        }
    } else {
        int const level = config_get_level(config);
        if (level == CONFIG_NO_LEVEL)
            return 1;
        if (cdict) {
            if (!config->use_dictionary)
                return 1;
            *cdict = ZSTD_createCDict(
                state->dictionary.data,
                state->dictionary.size,
                level);
            if (!*cdict) {
                return 1;
            }
            zret = ZSTD_initCStream_usingCDict(zcs, *cdict);
        } else if (config->use_dictionary) {
            zret = ZSTD_initCStream_usingDict(
                zcs,
                state->dictionary.data,
                state->dictionary.size,
                level);
        } else {
            zret = ZSTD_initCStream(zcs, level);
        }
    }
    if (ZSTD_isError(zret)) {
        return 1;
    }
    return 0;
}
499
/**
 * Compresses every input with the old streaming API (ZSTD_compressStream() /
 * ZSTD_endStream()), and returns the summed compressed size.
 *
 * @param base      method state; must be a buffer_state_t.
 * @param config    compression configuration.
 * @param advanced  non-zero to initialize from config_get_zstd_params()
 *                  instead of the compression level (see init_cstream()).
 * @param cdict     non-zero to compress with a ZSTD_CDict built from the
 *                  dictionary.
 *
 * Skips when the non-advanced variant has no level, or when a CDict is
 * requested without a dictionary.
 *
 * The stream is initialized once, then reset with ZSTD_resetCStream() before
 * each input. Each reset pledges the input size unless
 * config->no_pledged_src_size is set.
 *
 * Any zstd failure yields result_error_compression_error.
 */
static result_t old_streaming_compress_internal(
    method_state_t* base,
    config_t const* config,
    int const advanced,
    int const cdict) {
    buffer_state_t* state = container_of(base, buffer_state_t, base);

    if (buffer_state_bad(state, config))
        return result_error(result_error_system_error);


    ZSTD_CStream* zcs = ZSTD_createCStream();
    /* Owned here; freed at `out` (ZSTD_freeCDict accepts NULL). */
    ZSTD_CDict* cd = NULL;
    result_t result;
    if (zcs == NULL) {
        result = result_error(result_error_compression_error);
        goto out;
    }
    if (!advanced && config_get_level(config) == CONFIG_NO_LEVEL) {
        result = result_error(result_error_skip);
        goto out;
    }
    if (cdict && !config->use_dictionary) {
        result = result_error(result_error_skip);
        goto out;
    }
    if (init_cstream(state, zcs, config, advanced, cdict ? &cd : NULL)) {
        result = result_error(result_error_compression_error);
        goto out;
    }

    result_data_t data = {.total_size = 0};
    for (size_t i = 0; i < state->inputs.size; ++i) {
        data_buffer_t input = state->inputs.buffers[i];
        /* Start a new frame for this input, keeping the initialized
         * parameters/dictionary. */
        size_t zret = ZSTD_resetCStream(
            zcs,
            config->no_pledged_src_size ? ZSTD_CONTENTSIZE_UNKNOWN : input.size);
        if (ZSTD_isError(zret)) {
            result = result_error(result_error_compression_error);
            goto out;
        }

        /* NOTE(review): an empty input skips this loop entirely, so
         * ZSTD_endStream() is never called and no frame is counted for it.
         * Confirm this is intended. */
        while (input.size > 0) {
            /* Feed at most 4096 bytes per chunk. `op` is only used locally to
             * mark the final chunk; it is not passed to zstd. */
            ZSTD_inBuffer in = {input.data, MIN(input.size, 4096)};
            input.data += in.size;
            input.size -= in.size;
            ZSTD_EndDirective const op =
                input.size > 0 ? ZSTD_e_continue : ZSTD_e_end;
            zret = 0;
            while (in.pos < in.size || (op == ZSTD_e_end && zret != 0)) {
                /* Drain into at most 1024 bytes; the output is overwritten
                 * each time because only its size is accumulated. */
                ZSTD_outBuffer out = {state->compressed.data,
                    MIN(state->compressed.capacity, 1024)};
                /* Consume input with ZSTD_compressStream(). Once the final
                 * chunk is fully consumed, call ZSTD_endStream() until it
                 * returns 0 (nothing left to flush). */
                if (op == ZSTD_e_continue || in.pos < in.size)
                    zret = ZSTD_compressStream(zcs, &out, &in);
                else
                    zret = ZSTD_endStream(zcs, &out);
                if (ZSTD_isError(zret)) {
                    result = result_error(result_error_compression_error);
                    goto out;
                }
                data.total_size += out.pos;
            }
        }
    }

    result = result_data(data);
out:
    ZSTD_freeCStream(zcs);
    ZSTD_freeCDict(cd);
    return result;
}
571
old_streaming_compress(method_state_t * base,config_t const * config)572 static result_t old_streaming_compress(
573 method_state_t* base,
574 config_t const* config)
575 {
576 return old_streaming_compress_internal(
577 base, config, /* advanced */ 0, /* cdict */ 0);
578 }
579
old_streaming_compress_advanced(method_state_t * base,config_t const * config)580 static result_t old_streaming_compress_advanced(
581 method_state_t* base,
582 config_t const* config)
583 {
584 return old_streaming_compress_internal(
585 base, config, /* advanced */ 1, /* cdict */ 0);
586 }
587
old_streaming_compress_cdict(method_state_t * base,config_t const * config)588 static result_t old_streaming_compress_cdict(
589 method_state_t* base,
590 config_t const* config)
591 {
592 return old_streaming_compress_internal(
593 base, config, /* advanced */ 0, /* cdict */ 1);
594 }
595
old_streaming_compress_cdict_advanced(method_state_t * base,config_t const * config)596 static result_t old_streaming_compress_cdict_advanced(
597 method_state_t* base,
598 config_t const* config)
599 {
600 return old_streaming_compress_internal(
601 base, config, /* advanced */ 1, /* cdict */ 1);
602 }
603
604 method_t const simple = {
605 .name = "compress simple",
606 .create = buffer_state_create,
607 .compress = simple_compress,
608 .destroy = buffer_state_destroy,
609 };
610
611 method_t const compress_cctx = {
612 .name = "compress cctx",
613 .create = buffer_state_create,
614 .compress = compress_cctx_compress,
615 .destroy = buffer_state_destroy,
616 };
617
618 method_t const advanced_one_pass = {
619 .name = "advanced one pass",
620 .create = buffer_state_create,
621 .compress = advanced_one_pass_compress,
622 .destroy = buffer_state_destroy,
623 };
624
625 method_t const advanced_one_pass_small_out = {
626 .name = "advanced one pass small out",
627 .create = buffer_state_create,
628 .compress = advanced_one_pass_compress,
629 .destroy = buffer_state_destroy,
630 };
631
632 method_t const advanced_streaming = {
633 .name = "advanced streaming",
634 .create = buffer_state_create,
635 .compress = advanced_streaming_compress,
636 .destroy = buffer_state_destroy,
637 };
638
639 method_t const old_streaming = {
640 .name = "old streaming",
641 .create = buffer_state_create,
642 .compress = old_streaming_compress,
643 .destroy = buffer_state_destroy,
644 };
645
646 method_t const old_streaming_advanced = {
647 .name = "old streaming advanced",
648 .create = buffer_state_create,
649 .compress = old_streaming_compress_advanced,
650 .destroy = buffer_state_destroy,
651 };
652
653 method_t const old_streaming_cdict = {
654 .name = "old streaming cdcit",
655 .create = buffer_state_create,
656 .compress = old_streaming_compress_cdict,
657 .destroy = buffer_state_destroy,
658 };
659
660 method_t const old_streaming_advanced_cdict = {
661 .name = "old streaming advanced cdict",
662 .create = buffer_state_create,
663 .compress = old_streaming_compress_cdict_advanced,
664 .destroy = buffer_state_destroy,
665 };
666
667 method_t const cli = {
668 .name = "zstdcli",
669 .create = method_state_create,
670 .compress = cli_compress,
671 .destroy = method_state_destroy,
672 };
673
674 static method_t const* g_methods[] = {
675 &simple,
676 &compress_cctx,
677 &cli,
678 &advanced_one_pass,
679 &advanced_one_pass_small_out,
680 &advanced_streaming,
681 &old_streaming,
682 &old_streaming_advanced,
683 &old_streaming_cdict,
684 &old_streaming_advanced_cdict,
685 NULL,
686 };
687
688 method_t const* const* methods = g_methods;
689