diff --git a/contrib/educational_decoder/zstd_decompress.c b/contrib/educational_decoder/zstd_decompress.c index 79fd2685..90d4a522 100644 --- a/contrib/educational_decoder/zstd_decompress.c +++ b/contrib/educational_decoder/zstd_decompress.c @@ -30,10 +30,6 @@ size_t ZSTD_decompress_with_dict(void *dst, size_t dst_len, const void *src, size_t ZSTD_get_decompressed_size(const void *src, size_t src_len); /******* UTILITY MACROS AND TYPES *********************************************/ -// Specification recommends supporting at least 8MB. The maximum possible value -// is 1.875TB, but this implementation limits it to 512MB to avoid allocating -// too much memory. -#define MAX_WINDOW_SIZE ((size_t)512 * 1024 * 1024) // Max block size decompressed size is 128 KB and literal blocks must be smaller // than that #define MAX_LITERALS_SIZE ((size_t)128 * 1024) @@ -69,43 +65,6 @@ typedef int64_t i64; /// file. They implement low-level functionality needed for the higher level /// decompression functions. -/*** CIRCULAR BUFFER ******************/ -/// A standard circular buffer, used to facilitate back reference commands -typedef struct { - u8 *ptr; - size_t idx, last_flush, size; -} cbuf_t; - -/// Initialize a circular buffer -static void cbuf_init(cbuf_t *buf, size_t size); -static void cbuf_free(cbuf_t *buf); - -/// Copies up to `src_len` bytes from `src` into the buffer, stopping if it -/// would need to flush. -/// Returns the total amount of data copied. -static size_t cbuf_write_data(cbuf_t *buf, const u8 *src, size_t src_len); -/// Copies `len` bytes from `offset` back in the buffer, stopping if it would -/// need to flush. -/// Returns the number of bytes copied. -static size_t cbuf_copy_offset(cbuf_t *buf, size_t offset, size_t len); -/// Writes up to `len` copies of `byte`, stopping if would need to flush. -/// Returns the number of bytes copied. 
-static size_t cbuf_repeat_byte(cbuf_t *buf, u8 byte, size_t len); - -/// The `full` versions of the above functions write the full amount requested, -/// flushing to `out` when necessary. -/// They return the number of bytes flushed to `out`, if any. -static size_t cbuf_write_data_full(cbuf_t *buf, const u8 *src, size_t src_len, - u8 *out, size_t out_len); -static size_t cbuf_copy_offset_full(cbuf_t *buf, size_t offset, size_t len, - u8 *out, size_t out_len); -static size_t cbuf_repeat_byte_full(cbuf_t *buf, u8 byte, size_t len, u8 *out, - size_t out_len); - -/// Flushes any unflushed data to `dst` -static size_t cbuf_flush(cbuf_t *buf, u8 *dst, size_t dst_len); -/*** END CIRCULAR BUFFER **************/ - /*** BITSTREAM OPERATIONS *************/ /// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits static inline u64 read_bits_LE(const u8 *src, int num, size_t offset); @@ -277,12 +236,11 @@ typedef struct { // offset too large to be correct size_t current_total_output; - // A sliding window of the past `window_size` bytes decoded - cbuf_t window; + const u8 *dict_content; + size_t dict_content_len; // Entropy encoding tables so they can be repeated by future blocks instead - // of - // retransmitting + // of retransmitting HUF_dtable literals_dtable; FSE_dtable ll_dtable; FSE_dtable ml_dtable; @@ -470,22 +428,6 @@ static void init_frame_context(io_streams_t *streams, frame_context_t *context, context->previous_offsets[2] = 4; context->previous_offsets[3] = 8; - { - // Allocate the window buffer - size_t buffer_size; - if (context->header.single_segment_flag) { - buffer_size = context->header.frame_content_size + - (dict ? 
dict->content_size : 0); - } else { - buffer_size = context->header.window_size; - } - - if (buffer_size > MAX_WINDOW_SIZE) { - ERROR("Requested window size too large"); - } - cbuf_init(&context->window, buffer_size); - } - // Apply details from the dict if it exists frame_context_apply_dict(context, dict); } @@ -497,8 +439,6 @@ static void free_frame_context(frame_context_t *context) { FSE_free_dtable(&context->ml_dtable); FSE_free_dtable(&context->of_dtable); - cbuf_free(&context->window); - memset(context, 0, sizeof(frame_context_t)); } @@ -583,6 +523,13 @@ static void parse_frame_header(frame_header_t *header, const u8 *src, header->frame_content_size = 0; } + if (single_segment_flag) { + // In this case the effective window size is frame_content_size. This + // impacts sequence decoding, as we need to determine whether to fall + // back to the dictionary or not on large offsets. + header->window_size = header->frame_content_size; + } + header->header_size = header_size; } @@ -607,12 +554,9 @@ static void frame_context_apply_dict(frame_context_t *ctx, dictionary_t *dict) { ERROR("Wrong/no dictionary provided"); } - // Write the dict data in, and then flush to NULL so it's not sent to the - // output stream - cbuf_write_data_full(&ctx->window, dict->content, dict->content_size, NULL, - -1); - cbuf_flush(&ctx->window, NULL, -1); - ctx->current_total_output = dict->content_size; + // Copy the pointer in so we can reference it in sequence execution + ctx->dict_content = dict->content; + ctx->dict_content_len = dict->content_size; // If it's a formatted dict copy the precomputed tables in so they can // be used in the table repeat modes @@ -655,15 +599,16 @@ static void decompress_data(io_streams_t *streams, frame_context_t *ctx) { OUT_SIZE(); } - // Write the raw data into the window buffer - size_t written = - cbuf_write_data_full(&ctx->window, streams->src, block_len, - streams->dst, streams->dst_len); + // Copy the raw data into the output + memcpy(streams->dst,
streams->src, block_len); + streams->src += block_len; streams->src_len -= block_len; - streams->dst += written; - streams->dst_len -= written; + streams->dst += block_len; + streams->dst_len -= block_len; + + ctx->current_total_output += block_len; break; } case 1: { @@ -675,15 +620,16 @@ static void decompress_data(io_streams_t *streams, frame_context_t *ctx) { OUT_SIZE(); } - // Write streams->src[0] into the buffer block_len times - size_t written = - cbuf_repeat_byte_full(&ctx->window, streams->src[0], block_len, - streams->dst, streams->dst_len); - streams->dst += written; - streams->dst_len -= written; + // Copy `block_len` copies of `streams->src[0]` to the output + memset(streams->dst, streams->src[0], block_len); + + streams->dst += block_len; + streams->dst_len -= block_len; streams->src += 1; streams->src_len -= 1; + + ctx->current_total_output += block_len; break; } case 2: @@ -697,11 +643,6 @@ static void decompress_data(io_streams_t *streams, frame_context_t *ctx) { } } while (!last_block); - // Flush out anything left in the window buffer to the destination stream - size_t written = cbuf_flush(&ctx->window, streams->dst, streams->dst_len); - streams->dst += written; - streams->dst_len -= written; - if (ctx->header.content_checksum_flag) { // This program does not support checking the checksum, so skip over it // if it's present @@ -1277,20 +1218,19 @@ static size_t execute_sequences(io_streams_t *streams, frame_context_t *ctx, CORRUPTION(); } - { - // Copy literals to the buffer - size_t written = - cbuf_write_data_full(&ctx->window, literals, seq.literal_length, - streams->dst, streams->dst_len); - - literals += seq.literal_length; - literals_len -= seq.literal_length; - - streams->dst += written; - streams->dst_len -= written; - - total_output += seq.literal_length; + if (streams->dst_len < seq.literal_length + seq.match_length) { + OUT_SIZE(); } + // Copy literals to output + memcpy(streams->dst, literals, seq.literal_length); + + literals += 
seq.literal_length; + literals_len -= seq.literal_length; + + streams->dst += seq.literal_length; + streams->dst_len -= seq.literal_length; + + total_output += seq.literal_length; size_t offset; @@ -1324,36 +1264,56 @@ static size_t execute_sequences(io_streams_t *streams, frame_context_t *ctx, offset_hist[1] = offset; } - if (offset > total_output) { - CORRUPTION(); + size_t match_length = seq.match_length; + if (total_output <= ctx->header.window_size) { + // In this case offset might go back into the dictionary + if (offset > total_output + ctx->dict_content_len) { + // The offset goes beyond even the dictionary + CORRUPTION(); + } + + if (offset > total_output) { + const size_t dict_copy = + MIN(offset - total_output, match_length); + const size_t dict_offset = + ctx->dict_content_len - (offset - total_output); + for (size_t i = 0; i < dict_copy; i++) { + *streams->dst++ = ctx->dict_content[dict_offset + i]; + } + match_length -= dict_copy; + } + } else if (offset > ctx->header.window_size) { + // Past the first window_size bytes the dictionary is out of reach, + // so an offset larger than the window is invalid + CORRUPTION(); } - { - // Do the offset copy operation - size_t written = - cbuf_copy_offset_full(&ctx->window, offset, seq.match_length, - streams->dst, streams->dst_len); - - streams->dst += written; - streams->dst_len -= written; - total_output += seq.match_length; + // We must copy byte by byte because the match length might be larger + // than the offset + // ex: if the output so far was "abc", a command with offset=3 and + // match_length=6 would produce "abcabcabc" as the new output + for (size_t i = 0; i < match_length; i++) { + *streams->dst = *(streams->dst - offset); + streams->dst++; } + + streams->dst_len -= seq.match_length; + total_output += seq.match_length; } - { - // Copy any leftover literal bytes - size_t written = - cbuf_write_data_full(&ctx->window, literals, literals_len, - streams->dst, streams->dst_len); - streams->dst += written; - streams->dst_len -= written; - - total_output += literals_len; + if (streams->dst_len < literals_len) { + OUT_SIZE(); } + // Copy any leftover literals + memcpy(streams->dst,
literals, literals_len); + streams->dst += literals_len; + streams->dst_len -= literals_len; + + total_output += literals_len; ctx->current_total_output = total_output; return total_output; } /******* END SEQUENCE EXECUTION ***********************************************/