use ZSTD_decodingBufferSize_min() inside ZSTD_decompressStream()

Use same definition as public one
minor : reduce allocated buffer size in some cases
(when frameContentSize is known and == windowSize)
This commit is contained in:
Yann Collet 2017-09-09 14:37:28 -07:00
parent 058ed2ad33
commit b3f33ccfb3
2 changed files with 42 additions and 21 deletions

View File

@ -102,7 +102,8 @@ struct ZSTD_DCtx_s
const void* dictEnd; /* end of previous segment */ const void* dictEnd; /* end of previous segment */
size_t expected; size_t expected;
ZSTD_frameHeader fParams; ZSTD_frameHeader fParams;
blockType_e bType; /* used in ZSTD_decompressContinue(), to transfer blockType between header decoding and block decoding stages */ U64 decodedSize;
blockType_e bType; /* used in ZSTD_decompressContinue(), store blockType between block header decoding and block decompression stages */
ZSTD_dStage stage; ZSTD_dStage stage;
U32 litEntropy; U32 litEntropy;
U32 fseEntropy; U32 fseEntropy;
@ -127,7 +128,6 @@ struct ZSTD_DCtx_s
size_t outBuffSize; size_t outBuffSize;
size_t outStart; size_t outStart;
size_t outEnd; size_t outEnd;
size_t blockSize;
size_t lhSize; size_t lhSize;
void* legacyContext; void* legacyContext;
U32 previousLegacyVersion; U32 previousLegacyVersion;
@ -153,6 +153,7 @@ size_t ZSTD_decompressBegin(ZSTD_DCtx* dctx)
{ {
dctx->expected = ZSTD_frameHeaderSize_prefix; dctx->expected = ZSTD_frameHeaderSize_prefix;
dctx->stage = ZSTDds_getFrameHeaderSize; dctx->stage = ZSTDds_getFrameHeaderSize;
dctx->decodedSize = 0;
dctx->previousDstEnd = NULL; dctx->previousDstEnd = NULL;
dctx->base = NULL; dctx->base = NULL;
dctx->vBase = NULL; dctx->vBase = NULL;
@ -178,7 +179,7 @@ static void ZSTD_initDCtx_internal(ZSTD_DCtx* dctx)
dctx->ddictLocal = NULL; dctx->ddictLocal = NULL;
dctx->inBuff = NULL; dctx->inBuff = NULL;
dctx->inBuffSize = 0; dctx->inBuffSize = 0;
dctx->outBuffSize= 0; dctx->outBuffSize = 0;
dctx->streamStage = zdss_init; dctx->streamStage = zdss_init;
} }
@ -1771,9 +1772,16 @@ size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, c
return ERROR(corruption_detected); return ERROR(corruption_detected);
} }
if (ZSTD_isError(rSize)) return rSize; if (ZSTD_isError(rSize)) return rSize;
DEBUGLOG(5, "decoded size from block : %u", (U32)rSize);
dctx->decodedSize += rSize;
if (dctx->fParams.checksumFlag) XXH64_update(&dctx->xxhState, dst, rSize); if (dctx->fParams.checksumFlag) XXH64_update(&dctx->xxhState, dst, rSize);
if (dctx->stage == ZSTDds_decompressLastBlock) { /* end of frame */ if (dctx->stage == ZSTDds_decompressLastBlock) { /* end of frame */
DEBUGLOG(4, "decoded size from frame : %u", (U32)dctx->decodedSize);
if (dctx->fParams.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN) {
if (dctx->decodedSize != dctx->fParams.frameContentSize) {
return ERROR(corruption_detected);
} }
if (dctx->fParams.checksumFlag) { /* another round for frame checksum */ if (dctx->fParams.checksumFlag) { /* another round for frame checksum */
dctx->expected = 4; dctx->expected = 4;
dctx->stage = ZSTDds_checkChecksum; dctx->stage = ZSTDds_checkChecksum;
@ -1789,8 +1797,11 @@ size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, c
return rSize; return rSize;
} }
case ZSTDds_checkChecksum: case ZSTDds_checkChecksum:
DEBUGLOG(4, "case ZSTDds_checkChecksum");
assert(srcSize == 4); /* guaranteed by dctx->expected */
{ U32 const h32 = (U32)XXH64_digest(&dctx->xxhState); { U32 const h32 = (U32)XXH64_digest(&dctx->xxhState);
U32 const check32 = MEM_readLE32(src); /* srcSize == 4, guaranteed by dctx->expected */ U32 const check32 = MEM_readLE32(src);
DEBUGLOG(4, "calculated %08X :: %08X read", h32, check32);
if (check32 != h32) return ERROR(checksum_wrong); if (check32 != h32) return ERROR(checksum_wrong);
dctx->expected = 0; dctx->expected = 0;
dctx->stage = ZSTDds_getFrameHeaderSize; dctx->stage = ZSTDds_getFrameHeaderSize;
@ -2361,15 +2372,14 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
if (zds->fParams.windowSize > zds->maxWindowSize) return ERROR(frameParameter_windowTooLarge); if (zds->fParams.windowSize > zds->maxWindowSize) return ERROR(frameParameter_windowTooLarge);
/* Adapt buffer sizes to frame header instructions */ /* Adapt buffer sizes to frame header instructions */
{ size_t const blockSize = zds->fParams.blockSizeMax; { size_t const neededInBuffSize = MAX(zds->fParams.blockSizeMax, 4 /* frame checksum */);
size_t const neededOutSize = (size_t)(zds->fParams.windowSize + blockSize + WILDCOPY_OVERLENGTH * 2); size_t const neededOutBuffSize = ZSTD_decodingBufferSize_min(zds->fParams.windowSize, zds->fParams.frameContentSize);
zds->blockSize = blockSize; if ((zds->inBuffSize < neededInBuffSize) || (zds->outBuffSize < neededOutBuffSize)) {
if ((zds->inBuffSize < blockSize) || (zds->outBuffSize < neededOutSize)) { size_t const bufferSize = neededInBuffSize + neededOutBuffSize;
size_t const bufferSize = blockSize + neededOutSize;
DEBUGLOG(4, "inBuff : from %u to %u", DEBUGLOG(4, "inBuff : from %u to %u",
(U32)zds->inBuffSize, (U32)blockSize); (U32)zds->inBuffSize, (U32)neededInBuffSize);
DEBUGLOG(4, "outBuff : from %u to %u", DEBUGLOG(4, "outBuff : from %u to %u",
(U32)zds->outBuffSize, (U32)neededOutSize); (U32)zds->outBuffSize, (U32)neededOutBuffSize);
if (zds->staticSize) { /* static DCtx */ if (zds->staticSize) { /* static DCtx */
DEBUGLOG(4, "staticSize : %u", (U32)zds->staticSize); DEBUGLOG(4, "staticSize : %u", (U32)zds->staticSize);
assert(zds->staticSize >= sizeof(ZSTD_DCtx)); /* controlled at init */ assert(zds->staticSize >= sizeof(ZSTD_DCtx)); /* controlled at init */
@ -2382,9 +2392,9 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
zds->inBuff = (char*)ZSTD_malloc(bufferSize, zds->customMem); zds->inBuff = (char*)ZSTD_malloc(bufferSize, zds->customMem);
if (zds->inBuff == NULL) return ERROR(memory_allocation); if (zds->inBuff == NULL) return ERROR(memory_allocation);
} }
zds->inBuffSize = blockSize; zds->inBuffSize = neededInBuffSize;
zds->outBuff = zds->inBuff + zds->inBuffSize; zds->outBuff = zds->inBuff + zds->inBuffSize;
zds->outBuffSize = neededOutSize; zds->outBuffSize = neededOutBuffSize;
} } } }
zds->streamStage = zdss_read; zds->streamStage = zdss_read;
/* fall-through */ /* fall-through */
@ -2442,8 +2452,13 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
zds->outStart += flushedSize; zds->outStart += flushedSize;
if (flushedSize == toFlushSize) { /* flush completed */ if (flushedSize == toFlushSize) { /* flush completed */
zds->streamStage = zdss_read; zds->streamStage = zdss_read;
if (zds->outStart + zds->blockSize > zds->outBuffSize) if ( (zds->outBuffSize < zds->fParams.frameContentSize)
&& (zds->outStart + zds->fParams.blockSizeMax > zds->outBuffSize) ) {
DEBUGLOG(5, "restart filling outBuff from beginning (left:%i, needed:%u)",
(int)(zds->outBuffSize - zds->outStart),
(U32)zds->fParams.blockSizeMax);
zds->outStart = zds->outEnd = 0; zds->outStart = zds->outEnd = 0;
}
break; break;
} } } }
/* cannot complete flush */ /* cannot complete flush */

View File

@ -909,10 +909,16 @@ static int fuzzerTests(U32 seed, U32 nbTests, unsigned startTest, double compres
inBuff.size = inBuff.pos + readCSrcSize; inBuff.size = inBuff.pos + readCSrcSize;
outBuff.size = inBuff.pos + dstBuffSize; outBuff.size = inBuff.pos + dstBuffSize;
decompressionResult = ZSTD_decompressStream(zd, &outBuff, &inBuff); decompressionResult = ZSTD_decompressStream(zd, &outBuff, &inBuff);
CHECK (ZSTD_isError(decompressionResult), "decompression error : %s", ZSTD_getErrorName(decompressionResult)); if (ZSTD_getErrorCode(decompressionResult) == ZSTD_error_checksum_wrong) {
DISPLAY("checksum error : \n");
findDiff(copyBuffer, dstBuffer, totalTestSize);
}
CHECK( ZSTD_isError(decompressionResult), "decompression error : %s",
ZSTD_getErrorName(decompressionResult) );
} }
CHECK (decompressionResult != 0, "frame not fully decoded"); CHECK (decompressionResult != 0, "frame not fully decoded");
CHECK (outBuff.pos != totalTestSize, "decompressed data : wrong size") CHECK (outBuff.pos != totalTestSize, "decompressed data : wrong size (%u != %u)",
(U32)outBuff.pos, (U32)totalTestSize);
CHECK (inBuff.pos != cSize, "compressed data should be fully read") CHECK (inBuff.pos != cSize, "compressed data should be fully read")
{ U64 const crcDest = XXH64(dstBuffer, totalTestSize, 0); { U64 const crcDest = XXH64(dstBuffer, totalTestSize, 0);
if (crcDest!=crcOrig) findDiff(copyBuffer, dstBuffer, totalTestSize); if (crcDest!=crcOrig) findDiff(copyBuffer, dstBuffer, totalTestSize);