From b3f33ccfb3b4fcc73df82126fd5ecfb751268fc6 Mon Sep 17 00:00:00 2001 From: Yann Collet Date: Sat, 9 Sep 2017 14:37:28 -0700 Subject: [PATCH] use ZSTD_decodingBufferSize_min() inside ZSTD_decompressStream() Use same definition as public one minor : reduce allocated buffer size in some cases (when frameContentSize is known and == windowSize) --- lib/decompress/zstd_decompress.c | 53 ++++++++++++++++++++------------ tests/zstreamtest.c | 10 ++++-- 2 files changed, 42 insertions(+), 21 deletions(-) diff --git a/lib/decompress/zstd_decompress.c b/lib/decompress/zstd_decompress.c index aeac17d1..5158d3f3 100644 --- a/lib/decompress/zstd_decompress.c +++ b/lib/decompress/zstd_decompress.c @@ -102,7 +102,8 @@ struct ZSTD_DCtx_s const void* dictEnd; /* end of previous segment */ size_t expected; ZSTD_frameHeader fParams; - blockType_e bType; /* used in ZSTD_decompressContinue(), to transfer blockType between header decoding and block decoding stages */ + U64 decodedSize; + blockType_e bType; /* used in ZSTD_decompressContinue(), store blockType between block header decoding and block decompression stages */ ZSTD_dStage stage; U32 litEntropy; U32 fseEntropy; @@ -127,7 +128,6 @@ struct ZSTD_DCtx_s size_t outBuffSize; size_t outStart; size_t outEnd; - size_t blockSize; size_t lhSize; void* legacyContext; U32 previousLegacyVersion; @@ -153,6 +153,7 @@ size_t ZSTD_decompressBegin(ZSTD_DCtx* dctx) { dctx->expected = ZSTD_frameHeaderSize_prefix; dctx->stage = ZSTDds_getFrameHeaderSize; + dctx->decodedSize = 0; dctx->previousDstEnd = NULL; dctx->base = NULL; dctx->vBase = NULL; @@ -172,13 +173,13 @@ size_t ZSTD_decompressBegin(ZSTD_DCtx* dctx) static void ZSTD_initDCtx_internal(ZSTD_DCtx* dctx) { ZSTD_decompressBegin(dctx); /* cannot fail */ - dctx->staticSize = 0; + dctx->staticSize = 0; dctx->maxWindowSize = ZSTD_MAXWINDOWSIZE_DEFAULT; - dctx->ddict = NULL; - dctx->ddictLocal = NULL; - dctx->inBuff = NULL; - dctx->inBuffSize = 0; - dctx->outBuffSize= 0; + dctx->ddict = NULL; + dctx->ddictLocal = NULL; + dctx->inBuff = NULL; + dctx->inBuffSize = 0; + dctx->outBuffSize = 0; dctx->streamStage = zdss_init; } @@ -1771,9 +1772,16 @@ size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, c return ERROR(corruption_detected); } if (ZSTD_isError(rSize)) return rSize; + DEBUGLOG(5, "decoded size from block : %u", (U32)rSize); + dctx->decodedSize += rSize; if (dctx->fParams.checksumFlag) XXH64_update(&dctx->xxhState, dst, rSize); if (dctx->stage == ZSTDds_decompressLastBlock) { /* end of frame */ + DEBUGLOG(4, "decoded size from frame : %u", (U32)dctx->decodedSize); + if (dctx->fParams.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN) { + if (dctx->decodedSize != dctx->fParams.frameContentSize) { + return ERROR(corruption_detected); + } } if (dctx->fParams.checksumFlag) { /* another round for frame checksum */ dctx->expected = 4; dctx->stage = ZSTDds_checkChecksum; @@ -1789,8 +1797,11 @@ size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, c return rSize; } case ZSTDds_checkChecksum: + DEBUGLOG(4, "case ZSTDds_checkChecksum"); + assert(srcSize == 4); /* guaranteed by dctx->expected */ { U32 const h32 = (U32)XXH64_digest(&dctx->xxhState); - U32 const check32 = MEM_readLE32(src); /* srcSize == 4, guaranteed by dctx->expected */ + U32 const check32 = MEM_readLE32(src); + DEBUGLOG(4, "calculated %08X :: %08X read", h32, check32); if (check32 != h32) return ERROR(checksum_wrong); dctx->expected = 0; dctx->stage = ZSTDds_getFrameHeaderSize; @@ -2361,15 +2372,14 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB if (zds->fParams.windowSize > zds->maxWindowSize) return ERROR(frameParameter_windowTooLarge); /* Adapt buffer sizes to frame header instructions */ - { size_t const blockSize = zds->fParams.blockSizeMax; - size_t const neededOutSize = (size_t)(zds->fParams.windowSize + blockSize + WILDCOPY_OVERLENGTH * 2); - zds->blockSize = blockSize; - if ((zds->inBuffSize < blockSize) || (zds->outBuffSize < neededOutSize)) { - size_t const bufferSize = blockSize + neededOutSize; + { size_t const neededInBuffSize = MAX(zds->fParams.blockSizeMax, 4 /* frame checksum */); + size_t const neededOutBuffSize = ZSTD_decodingBufferSize_min(zds->fParams.windowSize, zds->fParams.frameContentSize); + if ((zds->inBuffSize < neededInBuffSize) || (zds->outBuffSize < neededOutBuffSize)) { + size_t const bufferSize = neededInBuffSize + neededOutBuffSize; DEBUGLOG(4, "inBuff : from %u to %u", - (U32)zds->inBuffSize, (U32)blockSize); + (U32)zds->inBuffSize, (U32)neededInBuffSize); DEBUGLOG(4, "outBuff : from %u to %u", - (U32)zds->outBuffSize, (U32)neededOutSize); + (U32)zds->outBuffSize, (U32)neededOutBuffSize); if (zds->staticSize) { /* static DCtx */ DEBUGLOG(4, "staticSize : %u", (U32)zds->staticSize); assert(zds->staticSize >= sizeof(ZSTD_DCtx)); /* controlled at init */ @@ -2382,9 +2392,9 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB zds->inBuff = (char*)ZSTD_malloc(bufferSize, zds->customMem); if (zds->inBuff == NULL) return ERROR(memory_allocation); } - zds->inBuffSize = blockSize; + zds->inBuffSize = neededInBuffSize; zds->outBuff = zds->inBuff + zds->inBuffSize; - zds->outBuffSize = neededOutSize; + zds->outBuffSize = neededOutBuffSize; } } zds->streamStage = zdss_read; /* fall-through */ @@ -2442,8 +2452,13 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB zds->outStart += flushedSize; if (flushedSize == toFlushSize) { /* flush completed */ zds->streamStage = zdss_read; - if (zds->outStart + zds->blockSize > zds->outBuffSize) + if ( (zds->outBuffSize < zds->fParams.frameContentSize) + && (zds->outStart + zds->fParams.blockSizeMax > zds->outBuffSize) ) { + DEBUGLOG(5, "restart filling outBuff from beginning (left:%i, needed:%u)", + (int)(zds->outBuffSize - zds->outStart), + (U32)zds->fParams.blockSizeMax); zds->outStart = zds->outEnd = 0; + } break; } } /* cannot complete flush */ diff --git a/tests/zstreamtest.c b/tests/zstreamtest.c index d7b2e197..8c8adc62 100644 --- a/tests/zstreamtest.c +++ b/tests/zstreamtest.c @@ -909,10 +909,16 @@ static int fuzzerTests(U32 seed, U32 nbTests, unsigned startTest, double compres inBuff.size = inBuff.pos + readCSrcSize; outBuff.size = inBuff.pos + dstBuffSize; decompressionResult = ZSTD_decompressStream(zd, &outBuff, &inBuff); - CHECK (ZSTD_isError(decompressionResult), "decompression error : %s", ZSTD_getErrorName(decompressionResult)); + if (ZSTD_getErrorCode(decompressionResult) == ZSTD_error_checksum_wrong) { + DISPLAY("checksum error : \n"); + findDiff(copyBuffer, dstBuffer, totalTestSize); + } + CHECK( ZSTD_isError(decompressionResult), "decompression error : %s", + ZSTD_getErrorName(decompressionResult) ); } CHECK (decompressionResult != 0, "frame not fully decoded"); - CHECK (outBuff.pos != totalTestSize, "decompressed data : wrong size") + CHECK (outBuff.pos != totalTestSize, "decompressed data : wrong size (%u != %u)", + (U32)outBuff.pos, (U32)totalTestSize); CHECK (inBuff.pos != cSize, "compressed data should be fully read") { U64 const crcDest = XXH64(dstBuffer, totalTestSize, 0); if (crcDest!=crcOrig) findDiff(copyBuffer, dstBuffer, totalTestSize);