diff --git a/NEWS b/NEWS index ab6b26b4..1b65b508 100644 --- a/NEWS +++ b/NEWS @@ -1,4 +1,5 @@ v1.1.1 +New : command -M, to limit allowed memory consumption v1.1.0 New : contrib/pzstd, parallel version of zstd, by Nick Terrell diff --git a/lib/decompress/zstd_decompress.c b/lib/decompress/zstd_decompress.c index d157e005..e6c40999 100644 --- a/lib/decompress/zstd_decompress.c +++ b/lib/decompress/zstd_decompress.c @@ -1573,7 +1573,7 @@ size_t ZSTD_setDStreamParameter(ZSTD_DStream* zds, switch(paramType) { default : return ERROR(parameter_unknown); - case ZSTDdsp_maxWindowSize : zds->maxWindowSize = paramValue; break; + case ZSTDdsp_maxWindowSize : zds->maxWindowSize = paramValue ? paramValue : (U32)(-1); break; } return 0; } @@ -1654,7 +1654,7 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB } } zds->fParams.windowSize = MAX(zds->fParams.windowSize, 1U << ZSTD_WINDOWLOG_ABSOLUTEMIN); - if (zds->fParams.windowSize > zds->maxWindowSize) return ERROR(frameParameter_unsupported); + if (zds->fParams.windowSize > zds->maxWindowSize) return ERROR(frameParameter_windowTooLarge); /* Adapt buffer sizes to frame header instructions */ { size_t const blockSize = MIN(zds->fParams.windowSize, ZSTD_BLOCKSIZE_ABSOLUTEMAX); diff --git a/programs/fileio.c b/programs/fileio.c index d57ffedd..c4c308e0 100644 --- a/programs/fileio.c +++ b/programs/fileio.c @@ -120,6 +120,9 @@ static U32 g_checksumFlag = 1; void FIO_setChecksumFlag(unsigned checksumFlag) { g_checksumFlag = checksumFlag; } static U32 g_removeSrcFile = 0; void FIO_setRemoveSrcFile(unsigned flag) { g_removeSrcFile = (flag>0); } +static U32 g_memLimit = 0; +void FIO_setMemLimit(unsigned memLimit) { g_memLimit = memLimit; } + /*-************************************* @@ -480,6 +483,7 @@ static dRess_t FIO_createDResources(const char* dictFileName) /* Allocation */ ress.dctx = ZSTD_createDStream(); if (ress.dctx==NULL) EXM_THROW(60, "Can't create ZSTD_DStream"); + ZSTD_setDStreamParameter(ress.dctx, ZSTDdsp_maxWindowSize, g_memLimit); ress.srcBufferSize = ZSTD_DStreamInSize(); ress.srcBuffer = malloc(ress.srcBufferSize); ress.dstBufferSize = ZSTD_DStreamOutSize(); diff --git a/programs/fileio.h b/programs/fileio.h index 1e89aec2..60a7e0de 100644 --- a/programs/fileio.h +++ b/programs/fileio.h @@ -15,7 +15,6 @@ extern "C" { #endif - /* ************************************* * Special i/o constants **************************************/ @@ -37,6 +36,7 @@ void FIO_setSparseWrite(unsigned sparse); /**< 0: no sparse; 1: disable on stdo void FIO_setDictIDFlag(unsigned dictIDFlag); void FIO_setChecksumFlag(unsigned checksumFlag); void FIO_setRemoveSrcFile(unsigned flag); +void FIO_setMemLimit(unsigned memLimit); /*-************************************* diff --git a/programs/zstdcli.c b/programs/zstdcli.c index f64909d9..3a284c31 100644 --- a/programs/zstdcli.c +++ b/programs/zstdcli.c @@ -135,6 +135,7 @@ static int usage_advanced(const char* programName) DISPLAY( "--test : test compressed file integrity \n"); DISPLAY( "--[no-]sparse : sparse mode (default:enabled on file, disabled on stdout)\n"); #endif + DISPLAY( " -M# : Set a memory usage limit for decompression \n"); DISPLAY( "-- : All arguments after \"--\" are treated as files \n"); #ifndef ZSTD_NODICT DISPLAY( "\n"); @@ -172,14 +173,19 @@ static void waitEnter(void) } /*! readU32FromChar() : - @return : unsigned integer value reach from input in `char` format + @return : unsigned integer value read from input in `char` format + allows and interprets K, KB, KiB, M, MB and MiB suffix. Will also modify `*stringPtr`, advancing it to position where it stopped reading. - Note : this function can overflow if digit string > MAX_UINT */ + Note : function result can overflow if digit string > MAX_UINT */ static unsigned readU32FromChar(const char** stringPtr) { unsigned result = 0; while ((**stringPtr >='0') && (**stringPtr <='9')) result *= 10, result += **stringPtr - '0', (*stringPtr)++ ; + if (toupper(**stringPtr)=='K') result <<= 10, (*stringPtr)++ ; + else if (toupper(**stringPtr)=='M') result <<= 20, (*stringPtr)++ ; + if (toupper(**stringPtr)=='i') (*stringPtr)++; + if (toupper(**stringPtr)=='B') (*stringPtr)++; return result; } @@ -206,6 +212,7 @@ int main(int argCount, const char* argv[]) int cLevel = ZSTDCLI_CLEVEL_DEFAULT; int cLevelLast = 1; unsigned recursive = 0; + unsigned memLimit = 0; const char** filenameTable = (const char**)malloc(argCount * sizeof(const char*)); /* argCount >= 1 */ unsigned filenameIdx = 0; const char* programName = argv[0]; @@ -224,8 +231,8 @@ int main(int argCount, const char* argv[]) /* init */ (void)recursive; (void)cLevelLast; /* not used when ZSTD_NOBENCH set */ (void)dictCLevel; (void)dictSelect; (void)dictID; /* not used when ZSTD_NODICT set */ - (void)decode; (void)cLevel; (void)testmode;/* not used when ZSTD_NOCOMPRESS set */ - (void)ultra; /* not used when ZSTD_NODECOMPRESS set */ + (void)decode; (void)cLevel; (void)testmode; /* not used when ZSTD_NOCOMPRESS set */ + (void)ultra; (void)memLimit; /* not used when ZSTD_NODECOMPRESS set */ if (filenameTable==NULL) { DISPLAY("zstd: %s \n", strerror(errno)); exit(1); } filenameTable[0] = stdinmark; displayOut = stderr; @@ -331,6 +338,12 @@ int main(int argCount, const char* argv[]) /* destination file name */ case 'o': nextArgumentIsOutFileName=1; lastCommand=1; argument++; break; + /* limit decompression memory */ + case 'M': + argument++; + memLimit = readU32FromChar(&argument); + break; + #ifdef UTIL_HAS_CREATEFILELIST /* recursive */ case 'r': recursive=1; argument++; break; @@ -359,10 +372,7 @@ int main(int argCount, const char* argv[]) /* cut input into blocks (benchmark only) */ case 'B': argument++; - { size_t bSize = readU32FromChar(&argument); - if (toupper(*argument)=='K') bSize<<=10, argument++; /* allows using KB notation */ - if (toupper(*argument)=='M') bSize<<=20, argument++; - if (toupper(*argument)=='B') argument++; + { size_t const bSize = readU32FromChar(&argument); BMK_setNotificationLevel(displayLevel); BMK_SetBlockSize(bSize); } @@ -395,8 +405,6 @@ int main(int argCount, const char* argv[]) nextArgumentIsMaxDict = 0; lastCommand = 0; maxDictSize = readU32FromChar(&argument); - if (toupper(*argument)=='K') maxDictSize <<= 10; - if (toupper(*argument)=='M') maxDictSize <<= 20; continue; } @@ -511,6 +519,7 @@ int main(int argCount, const char* argv[]) } else { /* decompression */ #ifndef ZSTD_NODECOMPRESS if (testmode) { outFileName=nulmark; FIO_setRemoveSrcFile(0); } /* test mode */ + FIO_setMemLimit(memLimit); if (filenameIdx==1 && outFileName) operationResult = FIO_decompressFilename(outFileName, filenameTable[0], dictFileName); else diff --git a/tests/playTests.sh b/tests/playTests.sh index d94d8fab..efb0f0cb 100755 --- a/tests/playTests.sh +++ b/tests/playTests.sh @@ -81,6 +81,8 @@ $ZSTD -dc < tmp.zst > $INTOVOID # combine decompression, stdin & stdout $ZSTD -dc - < tmp.zst > $INTOVOID $ZSTD -d < tmp.zst > $INTOVOID # implicit stdout when stdin is used $ZSTD -d - < tmp.zst > $INTOVOID +$ECHO "test : impose memory limitation (must fail)" +$ZSTD -d -f tmp.zst -M2K -c > $INTOVOID && die "decompression needs more memory than allowed" $ECHO "test : overwrite protection" $ZSTD -q tmp && die "overwrite check failed!" $ECHO "test : force overwrite"