diff --git a/lib/decompress/zstd_decompress.c b/lib/decompress/zstd_decompress.c index 28f88cde..254e2923 100644 --- a/lib/decompress/zstd_decompress.c +++ b/lib/decompress/zstd_decompress.c @@ -444,17 +444,22 @@ typedef struct { unsigned long long decompressedBound; } ZSTD_frameSizeInfo; +static ZSTD_frameSizeInfo ZSTD_errorFrameSizeInfo(size_t ret) +{ + ZSTD_frameSizeInfo frameSizeInfo; + frameSizeInfo.compressedSize = ret; + frameSizeInfo.decompressedBound = ZSTD_CONTENTSIZE_UNKNOWN; + return frameSizeInfo; +} + static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize) { ZSTD_frameSizeInfo frameSizeInfo; memset(&frameSizeInfo, 0, sizeof(ZSTD_frameSizeInfo)); #if defined(ZSTD_LEGACY_SUPPORT) && (ZSTD_LEGACY_SUPPORT >= 1) - if (ZSTD_isLegacy(src, srcSize)) { - frameSizeInfo.compressedSize = ZSTD_findFrameCompressedSizeLegacy(src, srcSize); - frameSizeInfo.decompressedBound = ZSTD_CONTENTSIZE_ERROR; - return frameSizeInfo; - } + if (ZSTD_isLegacy(src, srcSize)) + return ZSTD_errorFrameSizeInfo(ZSTD_findFrameCompressedSizeLegacy(src, srcSize)); #endif if ((srcSize >= ZSTD_SKIPPABLEHEADERSIZE) @@ -471,16 +476,10 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize /* Extract Frame Header */ { size_t const ret = ZSTD_getFrameHeader(&zfh, src, srcSize); - if (ZSTD_isError(ret)) { - frameSizeInfo.compressedSize = ret; - frameSizeInfo.decompressedBound = ZSTD_CONTENTSIZE_ERROR; - return frameSizeInfo; - } - if (ret > 0) { - frameSizeInfo.compressedSize = ERROR(srcSize_wrong); - frameSizeInfo.decompressedBound = ZSTD_CONTENTSIZE_ERROR; - return frameSizeInfo; - } + if (ZSTD_isError(ret)) + return ZSTD_errorFrameSizeInfo(ret); + if (ret > 0) + return ZSTD_errorFrameSizeInfo(ERROR(srcSize_wrong)); } ip += zfh.headerSize; @@ -490,17 +489,11 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize while (1) { blockProperties_t blockProperties; size_t const cBlockSize = ZSTD_getcBlockSize(ip, remainingSize, &blockProperties); - if (ZSTD_isError(cBlockSize)) { - frameSizeInfo.compressedSize = cBlockSize; - frameSizeInfo.decompressedBound = ZSTD_CONTENTSIZE_ERROR; - return frameSizeInfo; - } + if (ZSTD_isError(cBlockSize)) + return ZSTD_errorFrameSizeInfo(cBlockSize); - if (ZSTD_blockHeaderSize + cBlockSize > remainingSize) { - frameSizeInfo.compressedSize = ERROR(srcSize_wrong); - frameSizeInfo.decompressedBound = ZSTD_CONTENTSIZE_ERROR; - return frameSizeInfo; - } + if (ZSTD_blockHeaderSize + cBlockSize > remainingSize) + return ZSTD_errorFrameSizeInfo(ERROR(srcSize_wrong)); ip += ZSTD_blockHeaderSize + cBlockSize; remainingSize -= ZSTD_blockHeaderSize + cBlockSize; @@ -511,11 +504,8 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize /* Final frame content checksum */ if (zfh.checksumFlag) { - if (remainingSize < 4) { - frameSizeInfo.compressedSize = ERROR(srcSize_wrong); - frameSizeInfo.decompressedBound = ZSTD_CONTENTSIZE_ERROR; - return frameSizeInfo; - } + if (remainingSize < 4) + return ZSTD_errorFrameSizeInfo(ERROR(srcSize_wrong)); ip += 4; } @@ -534,7 +524,7 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize * @return : the compressed size of the frame starting at `src` */ size_t ZSTD_findFrameCompressedSize(const void *src, size_t srcSize) { - ZSTD_frameSizeInfo frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize); + ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize); return frameSizeInfo.compressedSize; } @@ -550,9 +540,9 @@ unsigned long long ZSTD_decompressBound(const void* src, size_t srcSize) unsigned long long bound = 0; /* Iterate over each frame */ while (srcSize > 0) { - ZSTD_frameSizeInfo frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize); - size_t compressedSize = frameSizeInfo.compressedSize; - unsigned long long decompressedBound = frameSizeInfo.decompressedBound; + ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize); + size_t const compressedSize = frameSizeInfo.compressedSize; + unsigned long long const decompressedBound = frameSizeInfo.decompressedBound; if (ZSTD_isError(compressedSize) || decompressedBound == ZSTD_CONTENTSIZE_ERROR) { return ZSTD_CONTENTSIZE_ERROR; } diff --git a/lib/zstd.h b/lib/zstd.h index e49a4048..182ff376 100644 --- a/lib/zstd.h +++ b/lib/zstd.h @@ -1104,16 +1104,16 @@ typedef enum { ZSTDLIB_API unsigned long long ZSTD_findDecompressedSize(const void* src, size_t srcSize); /** ZSTD_decompressBound() : - * currently incompatible with legacy mode - * `src` must point to the start of a ZSTD frame or a skippeable frame - * `srcSize` must be at least as large as the frame contained + * `src` should point the start of a series of ZSTD encoded and/or skippable frames + * `srcSize` must be the _exact_ size of this series + * (i.e. there should be a frame boundary exactly at `srcSize` bytes after `src`) * @return : - the maximum decompressed size of the compressed source * - if an error occured: ZSTD_CONTENTSIZE_ERROR * * note 1 : an error can occur if `src` points to a legacy frame or an invalid/incorrectly formatted frame. - * note 2 : the bound is exact when Frame_Content_Size field is available in EVERY frame of `src`. + * note 2 : the bound is exact when Frame_Content_Size field is available in _every_ frame of `src`. * note 3 : when Frame_Content_Size isn't provided, the upper-bound for that frame is calculated by: - * upper-bound = min(128 KB, Window_Size) + * upper-bound = # blocks * min(128 KB, Window_Size) * note 4 : we always use Frame_Content_Size to bound the decompressed frame size if it's present. */ ZSTDLIB_API unsigned long long ZSTD_decompressBound(const void* src, size_t srcSice);