Added ZSTD_get_decompressed_size

Since this implementation handles multiple concatenated frames,
to determine decompressed size we must traverse the entire input,
checking each frame's frame_content_size field
This commit is contained in:
Sean Purcell 2017-01-30 14:42:21 -08:00
parent 9700f92583
commit 5657e0e07d
3 changed files with 293 additions and 133 deletions

View File

@ -5,8 +5,8 @@
typedef unsigned char u8;
// There's no good way to determine output size without decompressing
// For this example assume we'll never decompress at a ratio larger than 16
// If the data doesn't have decompressed size with it, fallback on assuming the
// compression ratio is at most 16
#define MAX_COMPRESSION_RATIO (16)
u8 *input;
@ -14,80 +14,89 @@ u8 *output;
u8 *dict;
size_t read_file(const char *path, u8 **ptr) {
FILE *f = fopen(path, "rb");
if (!f) {
fprintf(stderr, "failed to open file %s\n", path);
exit(1);
}
fseek(f, 0L, SEEK_END);
size_t size = ftell(f);
rewind(f);
*ptr = malloc(size);
if (!ptr) {
fprintf(stderr, "failed to allocate memory to hold %s\n", path);
exit(1);
}
size_t pos = 0;
while (!feof(f)) {
size_t read = fread(&(*ptr)[pos], 1, size, f);
if (ferror(f)) {
fprintf(stderr, "error while reading file %s\n", path);
exit(1);
FILE *f = fopen(path, "rb");
if (!f) {
fprintf(stderr, "failed to open file %s\n", path);
exit(1);
}
pos += read;
}
fclose(f);
fseek(f, 0L, SEEK_END);
size_t size = ftell(f);
rewind(f);
return pos;
*ptr = malloc(size);
if (!ptr) {
fprintf(stderr, "failed to allocate memory to hold %s\n", path);
exit(1);
}
size_t pos = 0;
while (!feof(f)) {
size_t read = fread(&(*ptr)[pos], 1, size, f);
if (ferror(f)) {
fprintf(stderr, "error while reading file %s\n", path);
exit(1);
}
pos += read;
}
fclose(f);
return pos;
}
void write_file(const char *path, const u8 *ptr, size_t size) {
FILE *f = fopen(path, "wb");
FILE *f = fopen(path, "wb");
size_t written = 0;
while (written < size) {
written += fwrite(&ptr[written], 1, size, f);
if (ferror(f)) {
fprintf(stderr, "error while writing file %s\n", path);
exit(1);
size_t written = 0;
while (written < size) {
written += fwrite(&ptr[written], 1, size, f);
if (ferror(f)) {
fprintf(stderr, "error while writing file %s\n", path);
exit(1);
}
}
}
fclose(f);
fclose(f);
}
int main(int argc, char **argv) {
if (argc < 3) {
fprintf(stderr, "usage: %s <file.zst> <out_path> [dictionary]\n", argv[0]);
if (argc < 3) {
fprintf(stderr, "usage: %s <file.zst> <out_path> [dictionary]\n",
argv[0]);
return 1;
}
return 1;
}
size_t input_size = read_file(argv[1], &input);
size_t dict_size = 0;
if (argc >= 4) {
dict_size = read_file(argv[3], &dict);
}
size_t input_size = read_file(argv[1], &input);
size_t dict_size = 0;
if (argc >= 4) {
dict_size = read_file(argv[3], &dict);
}
output = malloc(MAX_COMPRESSION_RATIO * input_size);
if (!output) {
fprintf(stderr, "failed to allocate memory\n");
return 1;
}
size_t decompressed_size = ZSTD_get_decompressed_size(input, input_size);
if (decompressed_size == -1) {
decompressed_size = MAX_COMPRESSION_RATIO * input_size;
fprintf(stderr, "WARNING: Compressed data does contain decompressed "
"size, going to assume the compression ratio is at "
"most %d (decompressed size of at most %lld\n",
MAX_COMPRESSION_RATIO, decompressed_size);
}
output = malloc(decompressed_size);
if (!output) {
fprintf(stderr, "failed to allocate memory\n");
return 1;
}
size_t decompressed =
ZSTD_decompress_with_dict(output, input_size * MAX_COMPRESSION_RATIO,
input, input_size, dict, dict_size);
size_t decompressed =
ZSTD_decompress_with_dict(output, input_size * MAX_COMPRESSION_RATIO,
input, input_size, dict, dict_size);
write_file(argv[2], output, decompressed);
write_file(argv[2], output, decompressed);
free(input);
free(output);
free(dict);
input = output = dict = NULL;
free(input);
free(output);
free(dict);
input = output = dict = NULL;
}

View File

@ -16,6 +16,10 @@ size_t ZSTD_decompress_with_dict(void *dst, size_t dst_len, const void *src,
size_t src_len, const void *dict,
size_t dict_len);
/// Get the decompressed size of an input stream so memory can be allocated in
/// advance
size_t ZSTD_get_decompressed_size(const void *src, size_t src_len);
/******* UTILITY MACROS AND TYPES *********************************************/
#define MAX_WINDOW_SIZE ((size_t)512 << 20)
// Max block size decompressed size is 128 KB and literal blocks must be smaller
@ -232,10 +236,30 @@ typedef struct {
size_t src_len;
} io_streams_t;
/// A small structure that can be reused in various places that need to access
/// frame header information
typedef struct {
// The size of window that we need to be able to contiguously store for
// references
size_t window_size;
// The total output size of this compressed frame
size_t frame_content_size;
// The dictionary id if this frame uses one
u32 dictionary_id;
// Whether or not the content of this frame has a checksum
int content_checksum_flag;
// Whether or not the output for this frame is in a single segment
int single_segment_flag;
// The size in bytes of this header
int header_size;
} frame_header_t;
/// The context needed to decode blocks in a frame
typedef struct {
size_t window_size;
size_t frame_content_size;
frame_header_t header;
// The total amount of data available for backreferences, to determine if an
// offset too large to be correct
@ -255,12 +279,6 @@ typedef struct {
// The last 3 offsets for the special "repeat offsets". Array size is 4 so
// that previous_offsets[1] corresponds to the most recent offset
u64 previous_offsets[4];
// The dictionary id for this frame if one exists
u32 dictionary_id;
int single_segment_flag;
int content_checksum_flag;
} frame_context_t;
/// The decoded contents of a dictionary so that it doesn't have to be repeated
@ -364,10 +382,11 @@ size_t ZSTD_decompress_with_dict(void *dst, size_t dst_len, const void *src,
/******* FRAME DECODING ******************************************************/
static void decode_data_frame(io_streams_t *streams, dictionary_t *dict);
static void init_frame_context(frame_context_t *context);
static void free_frame_context(frame_context_t *context);
static void parse_frame_header(io_streams_t *streams, frame_context_t *ctx,
static void init_frame_context(io_streams_t *streams, frame_context_t *context,
dictionary_t *dict);
static void free_frame_context(frame_context_t *context);
static void parse_frame_header(frame_header_t *header, const u8 *src,
size_t src_len);
static void frame_context_apply_dict(frame_context_t *ctx, dictionary_t *dict);
static void decompress_data(io_streams_t *streams, frame_context_t *ctx);
@ -411,12 +430,10 @@ static void decode_data_frame(io_streams_t *streams, dictionary_t *dict) {
frame_context_t ctx;
// Initialize the context that needs to be carried from block to block
init_frame_context(&ctx);
parse_frame_header(streams, &ctx, dict);
frame_context_apply_dict(&ctx, dict);
init_frame_context(streams, &ctx, dict);
if (ctx.frame_content_size != 0 &&
ctx.frame_content_size > streams->dst_len) {
if (ctx.header.frame_content_size != 0 &&
ctx.header.frame_content_size > streams->dst_len) {
OUT_SIZE();
}
@ -425,13 +442,40 @@ static void decode_data_frame(io_streams_t *streams, dictionary_t *dict) {
free_frame_context(&ctx);
}
static void init_frame_context(frame_context_t *context) {
/// Takes the information provided in the header and dictionary, and initializes
/// the context for this frame
static void init_frame_context(io_streams_t *streams, frame_context_t *context,
dictionary_t *dict) {
memset(context, 0x00, sizeof(frame_context_t));
// Parse data from the frame header
parse_frame_header(&context->header, streams->src, streams->src_len);
streams->src += context->header.header_size;
streams->src_len -= context->header.header_size;
// Set up the offset history for the repeat offset commands
context->previous_offsets[1] = 1;
context->previous_offsets[2] = 4;
context->previous_offsets[3] = 8;
{
// Allocate the window buffer
size_t buffer_size;
if (context->header.single_segment_flag) {
buffer_size = context->header.frame_content_size +
(dict ? dict->content_size : 0);
} else {
buffer_size = context->header.window_size;
}
if (buffer_size > MAX_WINDOW_SIZE) {
ERROR("Requested window size too large");
}
cbuf_init(&context->window, buffer_size);
}
// Apply details from the dict if it exists
frame_context_apply_dict(context, dict);
}
static void free_frame_context(frame_context_t *context) {
@ -446,13 +490,13 @@ static void free_frame_context(frame_context_t *context) {
memset(context, 0, sizeof(frame_context_t));
}
static void parse_frame_header(io_streams_t *streams, frame_context_t *ctx,
dictionary_t *dict) {
if (streams->src_len < 1) {
static void parse_frame_header(frame_header_t *header, const u8 *src,
size_t src_len) {
if (src_len < 1) {
INP_SIZE();
}
u8 descriptor = read_bits_LE(streams->src, 8, 0);
u8 descriptor = read_bits_LE(src, 8, 0);
// decode frame header descriptor into flags
u8 frame_content_size_flag = descriptor >> 6;
@ -465,30 +509,28 @@ static void parse_frame_header(io_streams_t *streams, frame_context_t *ctx,
CORRUPTION();
}
streams->src++;
streams->src_len--;
int header_size = 1;
ctx->single_segment_flag = single_segment_flag;
ctx->content_checksum_flag = content_checksum_flag;
header->single_segment_flag = single_segment_flag;
header->content_checksum_flag = content_checksum_flag;
// decode window size
if (!single_segment_flag) {
if (streams->src_len < 1) {
if (src_len < header_size + 1) {
INP_SIZE();
}
// Use the algorithm from the specification to compute window size
// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
u8 window_descriptor = read_bits_LE(streams->src, 8, 0);
u8 window_descriptor = src[header_size];
u8 exponent = window_descriptor >> 3;
u8 mantissa = window_descriptor & 7;
size_t window_base = (size_t)1 << (10 + exponent);
size_t window_add = (window_base / 8) * mantissa;
ctx->window_size = window_base + window_add;
header->window_size = window_base + window_add;
streams->src++;
streams->src_len--;
header_size++;
}
// decode dictionary id if it exists
@ -496,52 +538,40 @@ static void parse_frame_header(io_streams_t *streams, frame_context_t *ctx,
const int bytes_array[] = {0, 1, 2, 4};
const int bytes = bytes_array[dictionary_id_flag];
if (streams->src_len < bytes) {
if (src_len < header_size + bytes) {
INP_SIZE();
}
ctx->dictionary_id = read_bits_LE(streams->src, bytes * 8, 0);
streams->src += bytes;
streams->src_len -= bytes;
header->dictionary_id = read_bits_LE(src + header_size, bytes * 8, 0);
header_size += bytes;
} else {
ctx->dictionary_id = 0;
header->dictionary_id = 0;
}
// decode frame content size if it exists
if (single_segment_flag || frame_content_size_flag) {
// if frame_content_size_flag == 0 but single_segment_flag is set, we
// still
// have a 1 byte field
// still have a 1 byte field
const int bytes_array[] = {1, 2, 4, 8};
const int bytes = bytes_array[frame_content_size_flag];
if (streams->src_len < bytes) {
if (src_len < header_size + bytes) {
INP_SIZE();
}
ctx->frame_content_size = read_bits_LE(streams->src, bytes * 8, 0);
header->frame_content_size =
read_bits_LE(src + header_size, bytes * 8, 0);
if (bytes == 2) {
ctx->frame_content_size += 256;
header->frame_content_size += 256;
}
streams->src += bytes;
streams->src_len -= bytes;
header_size += bytes;
} else {
header->frame_content_size = 0;
}
if (single_segment_flag) {
ctx->window_size =
ctx->frame_content_size + (dict ? dict->content_size : 0);
// We need to allocate a buffer to write to of size at least output +
// dict
// size
size_t size = ctx->frame_content_size + (dict ? dict->content_size : 0);
}
// Allocate the window
if (ctx->window_size > MAX_WINDOW_SIZE) {
ERROR("Requested window size too large");
}
cbuf_init(&ctx->window, ctx->window_size);
header->header_size = header_size;
}
/// A dictionary acts as initializing values for the frame context before
@ -552,7 +582,7 @@ static void frame_context_apply_dict(frame_context_t *ctx, dictionary_t *dict) {
if (!dict || !dict->content)
return;
if (ctx->dictionary_id == 0 && dict->dictionary_id != 0) {
if (ctx->header.dictionary_id == 0 && dict->dictionary_id != 0) {
// The dictionary is unneeded, and shouldn't be used as it may interfere
// with the default offset history
return;
@ -560,7 +590,8 @@ static void frame_context_apply_dict(frame_context_t *ctx, dictionary_t *dict) {
// If the dictionary id is 0, it doesn't matter if we provide the wrong raw
// content dict, it won't change anything
if (ctx->dictionary_id != 0 && ctx->dictionary_id != dict->dictionary_id) {
if (ctx->header.dictionary_id != 0 &&
ctx->header.dictionary_id != dict->dictionary_id) {
ERROR("Wrong/no dictionary provided");
}
@ -575,8 +606,7 @@ static void frame_context_apply_dict(frame_context_t *ctx, dictionary_t *dict) {
// be used in the table repeat modes
if (dict->dictionary_id != 0) {
// Deep copy the entropy tables so they can be freed independently of
// the
// dictionary struct
// the dictionary struct
HUF_copy_dtable(&ctx->literals_dtable, &dict->literals_dtable);
FSE_copy_dtable(&ctx->ll_dtable, &dict->ll_dtable);
FSE_copy_dtable(&ctx->of_dtable, &dict->of_dtable);
@ -590,14 +620,14 @@ static void frame_context_apply_dict(frame_context_t *ctx, dictionary_t *dict) {
/// Decompress the data from a frame block by block
static void decompress_data(io_streams_t *streams, frame_context_t *ctx) {
u8 last_block = 0;
int last_block = 0;
do {
if (streams->src_len < 3) {
INP_SIZE();
}
// Parse the block header
last_block = streams->src[0] & 1;
u8 block_type = (streams->src[0] >> 1) & 3;
int block_type = (streams->src[0] >> 1) & 3;
size_t block_len = read_bits_LE(streams->src, 21, 3);
streams->src += 3;
@ -648,6 +678,10 @@ static void decompress_data(io_streams_t *streams, frame_context_t *ctx) {
// Compressed block, this is mode complex
decompress_block(streams, ctx, block_len);
break;
case 3:
// Reserved block type
CORRUPTION();
break;
}
} while (!last_block);
@ -656,10 +690,9 @@ static void decompress_data(io_streams_t *streams, frame_context_t *ctx) {
streams->dst += written;
streams->dst_len -= written;
if (ctx->content_checksum_flag) {
if (ctx->header.content_checksum_flag) {
// This program does not support checking the checksum, so skip over it
// if
// it's present
// if it's present
if (streams->src_len < 4) {
INP_SIZE();
}
@ -1312,6 +1345,126 @@ static size_t execute_sequences(io_streams_t *streams, frame_context_t *ctx,
}
/******* END SEQUENCE EXECUTION ***********************************************/
/******* OUTPUT SIZE COUNTING *************************************************/
size_t traverse_frame(frame_header_t *header, const u8 *src, size_t src_len);
/// Get the decompressed size of an input stream so memory can be allocated in
/// advance.
/// This is more complex than the implementation in the reference
/// implementation, as this API allows for the decompression of multiple
/// concatenated frames.
size_t ZSTD_get_decompressed_size(const void *src, size_t src_len) {
const u8 *ip = (const u8 *) src;
size_t dst_size = 0;
// Each frame header only gives us the size of its frame, so iterate over all
// frames
while (src_len > 0) {
if (src_len < 4) {
INP_SIZE();
}
u32 magic_number = read_bits_LE(ip, 32, 0);
ip += 4;
src_len -= 4;
if (magic_number >= 0x184D2A50U && magic_number <= 0x184D2A5F) {
// skippable frame, this has no impact on output size
if (src_len < 4) {
INP_SIZE();
}
size_t frame_size = read_bits_LE(ip, 32, 32);
if (src_len < 4 + frame_size) {
INP_SIZE();
}
// skip over frame
ip += 4 + frame_size;
src_len -= 4 + frame_size;
} else if (magic_number == 0xFD2FB528U) {
// ZSTD frame
frame_header_t header;
parse_frame_header(&header, ip, src_len);
if (header.frame_content_size == 0 && !header.single_segment_flag) {
// Content size not provided, we can't tell
return -1;
}
dst_size += header.frame_content_size;
// we need to traverse the frame to find when the next one starts
size_t traversed = traverse_frame(&header, ip, src_len);
ip += traversed;
src_len -= traversed;
} else {
// not a real frame
ERROR("Invalid magic number");
}
}
return dst_size;
}
/// Iterate over each block in a frame to find the end of it, to get to the
/// start of the next frame
size_t traverse_frame(frame_header_t *header, const u8 *src, size_t src_len) {
const u8 *const src_beg = src;
const u8 *const src_end = src + src_len;
src += header->header_size;
src_len += header->header_size;
int last_block = 0;
do {
if (src + 3 > src_end) {
INP_SIZE();
}
// Parse the block header
last_block = src[0] & 1;
int block_type = (src[0] >> 1) & 3;
size_t block_len = read_bits_LE(src, 21, 3);
src += 3;
switch (block_type) {
case 0: // Raw block, block_len bytes
if (src + block_len > src_end) {
INP_SIZE();
}
src += block_len;
break;
case 1: // RLE block, 1 byte
if (src + 1 > src_end) {
INP_SIZE();
}
src++;
break;
case 2: // Compressed block, compressed size is block_len
if (src + block_len > src_end) {
INP_SIZE();
}
src += block_len;
break;
case 3:
// Reserved block type
CORRUPTION();
break;
}
} while (!last_block);
if (header->content_checksum_flag) {
if (src + 4 > src_end) {
INP_SIZE();
}
src += 4;
}
return src - src_beg;
}
/******* END OUTPUT SIZE COUNTING *********************************************/
/******* DICTIONARY PARSING ***************************************************/
static void init_raw_content_dict(dictionary_t *dict, const u8 *src,
size_t src_len);
@ -1952,8 +2105,8 @@ static void FSE_init_dtable(FSE_dtable *dtable, const i16 *norm_freqs,
high_threshold); // Make sure we don't occupy a spot taken
// by the low prob symbols
// Note: no other collision checking is necessary as `step` is
// coprime to
// `size`, so the cycle will visit each position exactly once
// coprime to `size`, so the cycle will visit each position exactly
// once
}
}
if (pos != 0) {
@ -1964,13 +2117,11 @@ static void FSE_init_dtable(FSE_dtable *dtable, const i16 *norm_freqs,
for (int i = 0; i < size; i++) {
u8 symbol = dtable->symbols[i];
u16 next_state_desc = state_desc[symbol]++;
// Fills in the table appropriately
// next_state_desc increases by symbol over time, decreasing number of
// bits
// Fills in the table appropriately next_state_desc increases by symbol
// over time, decreasing number of bits
dtable->num_bits[i] = (u8)(accuracy_log - log2inf(next_state_desc));
// baseline increases until the bit threshold is passed, at which point
// it
// resets to 0
// it resets to 0
dtable->new_state_base[i] =
((u16)next_state_desc << dtable->num_bits[i]) - size;
}
@ -2057,8 +2208,7 @@ static void FSE_init_dtable_rle(FSE_dtable *dtable, u8 symb) {
dtable->new_state_base = malloc(sizeof(u16));
// This setup will always have a state of 0, always return symbol `symb`,
// and
// never consume any bits
// and never consume any bits
dtable->symbols[0] = symb;
dtable->num_bits[0] = 0;
dtable->new_state_base[0] = 0;

View File

@ -3,4 +3,5 @@ size_t ZSTD_decompress(void *dst, size_t dst_len, const void *src,
size_t ZSTD_decompress_with_dict(void *dst, size_t dst_len, const void *src,
size_t src_len, const void *dict,
size_t dict_len);
size_t ZSTD_get_decompressed_size(const void *src, size_t src_len);