Merge pull request #1555 from terrelln/load-dict

[lib] Allow ZSTD_CCtx_loadDictionary() to be called before parameters are set
This commit is contained in:
Nick Terrell 2019-03-21 17:52:57 -07:00 committed by GitHub
commit 0c7668cd06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 198 additions and 21 deletions

View File

@ -103,12 +103,31 @@ ZSTD_CCtx* ZSTD_initStaticCCtx(void *workspace, size_t workspaceSize)
return cctx;
}
/**
* Clears and frees all of the dictionaries in the CCtx.
*/
static void ZSTD_clearAllDicts(ZSTD_CCtx* cctx)
{
ZSTD_free(cctx->localDict.dictBuffer, cctx->customMem);
ZSTD_freeCDict(cctx->localDict.cdict);
memset(&cctx->localDict, 0, sizeof(cctx->localDict));
memset(&cctx->prefixDict, 0, sizeof(cctx->prefixDict));
cctx->cdict = NULL;
}
static size_t ZSTD_sizeof_localDict(ZSTD_localDict dict)
{
size_t const bufferSize = dict.dictBuffer != NULL ? dict.dictSize : 0;
size_t const cdictSize = ZSTD_sizeof_CDict(dict.cdict);
return bufferSize + cdictSize;
}
static void ZSTD_freeCCtxContent(ZSTD_CCtx* cctx)
{
assert(cctx != NULL);
assert(cctx->staticSize == 0);
ZSTD_free(cctx->workSpace, cctx->customMem); cctx->workSpace = NULL;
ZSTD_freeCDict(cctx->cdictLocal); cctx->cdictLocal = NULL;
ZSTD_clearAllDicts(cctx);
#ifdef ZSTD_MULTITHREAD
ZSTDMT_freeCCtx(cctx->mtctx); cctx->mtctx = NULL;
#endif
@ -140,7 +159,7 @@ size_t ZSTD_sizeof_CCtx(const ZSTD_CCtx* cctx)
{
if (cctx==NULL) return 0; /* support sizeof on NULL */
return sizeof(*cctx) + cctx->workSpaceSize
+ ZSTD_sizeof_CDict(cctx->cdictLocal)
+ ZSTD_sizeof_localDict(cctx->localDict)
+ ZSTD_sizeof_mtctx(cctx);
}
@ -785,6 +804,44 @@ ZSTDLIB_API size_t ZSTD_CCtx_setPledgedSrcSize(ZSTD_CCtx* cctx, unsigned long lo
return 0;
}
/**
* Initializes the local dict using the requested parameters.
* NOTE: This does not use the pledged src size, because it may be used for more
* than one compression.
*/
static size_t ZSTD_initLocalDict(ZSTD_CCtx* cctx)
{
ZSTD_localDict* const dl = &cctx->localDict;
ZSTD_compressionParameters const cParams = ZSTD_getCParamsFromCCtxParams(
&cctx->requestedParams, 0, dl->dictSize);
if (dl->dict == NULL) {
/* No local dictionary. */
assert(dl->dictBuffer == NULL);
assert(dl->cdict == NULL);
assert(dl->dictSize == 0);
return 0;
}
if (dl->cdict != NULL) {
assert(cctx->cdict == dl->cdict);
/* Local dictionary already initialized. */
return 0;
}
assert(dl->dictSize > 0);
assert(cctx->cdict == NULL);
assert(cctx->prefixDict.dict == NULL);
dl->cdict = ZSTD_createCDict_advanced(
dl->dict,
dl->dictSize,
ZSTD_dlm_byRef,
dl->dictContentType,
cParams,
cctx->customMem);
RETURN_ERROR_IF(!dl->cdict, memory_allocation);
cctx->cdict = dl->cdict;
return 0;
}
size_t ZSTD_CCtx_loadDictionary_advanced(
ZSTD_CCtx* cctx, const void* dict, size_t dictSize,
ZSTD_dictLoadMethod_e dictLoadMethod, ZSTD_dictContentType_e dictContentType)
@ -793,20 +850,20 @@ size_t ZSTD_CCtx_loadDictionary_advanced(
RETURN_ERROR_IF(cctx->staticSize, memory_allocation,
"no malloc for static CCtx");
DEBUGLOG(4, "ZSTD_CCtx_loadDictionary_advanced (size: %u)", (U32)dictSize);
ZSTD_freeCDict(cctx->cdictLocal); /* in case one already exists */
if (dict==NULL || dictSize==0) { /* no dictionary mode */
cctx->cdictLocal = NULL;
cctx->cdict = NULL;
ZSTD_clearAllDicts(cctx); /* in case one already exists */
if (dict == NULL || dictSize == 0) /* no dictionary mode */
return 0;
if (dictLoadMethod == ZSTD_dlm_byRef) {
cctx->localDict.dict = dict;
} else {
ZSTD_compressionParameters const cParams =
ZSTD_getCParamsFromCCtxParams(&cctx->requestedParams, cctx->pledgedSrcSizePlusOne-1, dictSize);
cctx->cdictLocal = ZSTD_createCDict_advanced(
dict, dictSize,
dictLoadMethod, dictContentType,
cParams, cctx->customMem);
cctx->cdict = cctx->cdictLocal;
RETURN_ERROR_IF(cctx->cdictLocal == NULL, memory_allocation);
void* dictBuffer = ZSTD_malloc(dictSize, cctx->customMem);
RETURN_ERROR_IF(!dictBuffer, memory_allocation);
memcpy(dictBuffer, dict, dictSize);
cctx->localDict.dictBuffer = dictBuffer;
cctx->localDict.dict = dictBuffer;
}
cctx->localDict.dictSize = dictSize;
cctx->localDict.dictContentType = dictContentType;
return 0;
}
@ -828,10 +885,8 @@ size_t ZSTD_CCtx_refCDict(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict)
{
RETURN_ERROR_IF(cctx->streamStage != zcss_init, stage_wrong);
/* Free the existing local cdict (if any) to save memory. */
FORWARD_IF_ERROR( ZSTD_freeCDict(cctx->cdictLocal) );
cctx->cdictLocal = NULL;
ZSTD_clearAllDicts(cctx);
cctx->cdict = cdict;
memset(&cctx->prefixDict, 0, sizeof(cctx->prefixDict)); /* exclusive */
return 0;
}
@ -844,7 +899,7 @@ size_t ZSTD_CCtx_refPrefix_advanced(
ZSTD_CCtx* cctx, const void* prefix, size_t prefixSize, ZSTD_dictContentType_e dictContentType)
{
RETURN_ERROR_IF(cctx->streamStage != zcss_init, stage_wrong);
cctx->cdict = NULL; /* prefix discards any prior cdict */
ZSTD_clearAllDicts(cctx);
cctx->prefixDict.dict = prefix;
cctx->prefixDict.dictSize = prefixSize;
cctx->prefixDict.dictContentType = dictContentType;
@ -863,7 +918,7 @@ size_t ZSTD_CCtx_reset(ZSTD_CCtx* cctx, ZSTD_ResetDirective reset)
if ( (reset == ZSTD_reset_parameters)
|| (reset == ZSTD_reset_session_and_parameters) ) {
RETURN_ERROR_IF(cctx->streamStage != zcss_init, stage_wrong);
cctx->cdict = NULL;
ZSTD_clearAllDicts(cctx);
return ZSTD_CCtxParams_reset(&cctx->requestedParams);
}
return 0;
@ -4079,6 +4134,7 @@ size_t ZSTD_compressStream2( ZSTD_CCtx* cctx,
if (cctx->streamStage == zcss_init) {
ZSTD_CCtx_params params = cctx->requestedParams;
ZSTD_prefixDict const prefixDict = cctx->prefixDict;
FORWARD_IF_ERROR( ZSTD_initLocalDict(cctx) ); /* Init the local dict if present. */
memset(&cctx->prefixDict, 0, sizeof(cctx->prefixDict)); /* single usage */
assert(prefixDict.dict==NULL || cctx->cdict==NULL); /* only one can be set */
DEBUGLOG(4, "ZSTD_compressStream2 : transparent init stage");

View File

@ -54,6 +54,14 @@ typedef struct ZSTD_prefixDict_s {
ZSTD_dictContentType_e dictContentType;
} ZSTD_prefixDict;
typedef struct {
void* dictBuffer;
void const* dict;
size_t dictSize;
ZSTD_dictContentType_e dictContentType;
ZSTD_CDict* cdict;
} ZSTD_localDict;
typedef struct {
U32 CTable[HUF_CTABLE_SIZE_U32(255)];
HUF_repeat repeatMode;
@ -245,7 +253,7 @@ struct ZSTD_CCtx_s {
U32 frameEnded;
/* Dictionary */
ZSTD_CDict* cdictLocal;
ZSTD_localDict localDict;
const ZSTD_CDict* cdict;
ZSTD_prefixDict prefixDict; /* single-usage dictionary */

View File

@ -1285,9 +1285,13 @@ static int basicUnitTests(U32 seed, double compressibility)
{
size_t ret;
MEM_writeLE32((char*)dictBuffer+2, ZSTD_MAGIC_DICTIONARY);
/* Either operation is allowed to fail, but one must fail. */
ret = ZSTD_CCtx_loadDictionary_advanced(
cctx, (const char*)dictBuffer+2, dictSize-2, ZSTD_dlm_byRef, ZSTD_dct_auto);
if (!ZSTD_isError(ret)) goto _output_error;
if (!ZSTD_isError(ret)) {
ret = ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100));
if (!ZSTD_isError(ret)) goto _output_error;
}
}
DISPLAYLEVEL(3, "OK \n");
@ -1298,6 +1302,8 @@ static int basicUnitTests(U32 seed, double compressibility)
ret = ZSTD_CCtx_loadDictionary_advanced(
cctx, (const char*)dictBuffer+2, dictSize-2, ZSTD_dlm_byRef, ZSTD_dct_rawContent);
if (ZSTD_isError(ret)) goto _output_error;
ret = ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100));
if (ZSTD_isError(ret)) goto _output_error;
}
DISPLAYLEVEL(3, "OK \n");
@ -1312,6 +1318,113 @@ static int basicUnitTests(U32 seed, double compressibility)
}
DISPLAYLEVEL(3, "OK \n");
DISPLAYLEVEL(3, "test%3i : Loading dictionary before setting parameters is the same as loading after : ", testNb++);
{
size_t size1, size2;
ZSTD_CCtx_reset(cctx, ZSTD_reset_session_and_parameters);
CHECK_Z( ZSTD_CCtx_setParameter(cctx, ZSTD_c_compressionLevel, 7) );
CHECK_Z( ZSTD_CCtx_loadDictionary(cctx, CNBuffer, MIN(CNBuffSize, 10 KB)) );
size1 = ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100 KB));
if (ZSTD_isError(size1)) goto _output_error;
ZSTD_CCtx_reset(cctx, ZSTD_reset_session_and_parameters);
CHECK_Z( ZSTD_CCtx_loadDictionary(cctx, CNBuffer, MIN(CNBuffSize, 10 KB)) );
CHECK_Z( ZSTD_CCtx_setParameter(cctx, ZSTD_c_compressionLevel, 7) );
size2 = ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100 KB));
if (ZSTD_isError(size2)) goto _output_error;
if (size1 != size2) goto _output_error;
}
DISPLAYLEVEL(3, "OK \n");
DISPLAYLEVEL(3, "test%3i : Loading a dictionary clears the prefix : ", testNb++);
{
CHECK_Z( ZSTD_CCtx_refPrefix(cctx, (const char*)dictBuffer, dictSize) );
CHECK_Z( ZSTD_CCtx_loadDictionary(cctx, (const char*)dictBuffer, dictSize) );
CHECK_Z( ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100)) );
}
DISPLAYLEVEL(3, "OK \n");
DISPLAYLEVEL(3, "test%3i : Loading a dictionary clears the cdict : ", testNb++);
{
ZSTD_CDict* const cdict = ZSTD_createCDict(dictBuffer, dictSize, 1);
CHECK_Z( ZSTD_CCtx_refCDict(cctx, cdict) );
CHECK_Z( ZSTD_CCtx_loadDictionary(cctx, (const char*)dictBuffer, dictSize) );
CHECK_Z( ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100)) );
ZSTD_freeCDict(cdict);
}
DISPLAYLEVEL(3, "OK \n");
DISPLAYLEVEL(3, "test%3i : Loading a cdict clears the prefix : ", testNb++);
{
ZSTD_CDict* const cdict = ZSTD_createCDict(dictBuffer, dictSize, 1);
CHECK_Z( ZSTD_CCtx_refPrefix(cctx, (const char*)dictBuffer, dictSize) );
CHECK_Z( ZSTD_CCtx_refCDict(cctx, cdict) );
CHECK_Z( ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100)) );
ZSTD_freeCDict(cdict);
}
DISPLAYLEVEL(3, "OK \n");
DISPLAYLEVEL(3, "test%3i : Loading a cdict clears the dictionary : ", testNb++);
{
ZSTD_CDict* const cdict = ZSTD_createCDict(dictBuffer, dictSize, 1);
CHECK_Z( ZSTD_CCtx_loadDictionary(cctx, (const char*)dictBuffer, dictSize) );
CHECK_Z( ZSTD_CCtx_refCDict(cctx, cdict) );
CHECK_Z( ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100)) );
ZSTD_freeCDict(cdict);
}
DISPLAYLEVEL(3, "OK \n");
DISPLAYLEVEL(3, "test%3i : Loading a prefix clears the dictionary : ", testNb++);
{
CHECK_Z( ZSTD_CCtx_loadDictionary(cctx, (const char*)dictBuffer, dictSize) );
CHECK_Z( ZSTD_CCtx_refPrefix(cctx, (const char*)dictBuffer, dictSize) );
CHECK_Z( ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100)) );
}
DISPLAYLEVEL(3, "OK \n");
DISPLAYLEVEL(3, "test%3i : Loading a prefix clears the cdict : ", testNb++);
{
ZSTD_CDict* const cdict = ZSTD_createCDict(dictBuffer, dictSize, 1);
CHECK_Z( ZSTD_CCtx_refCDict(cctx, cdict) );
CHECK_Z( ZSTD_CCtx_refPrefix(cctx, (const char*)dictBuffer, dictSize) );
CHECK_Z( ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100)) );
ZSTD_freeCDict(cdict);
}
DISPLAYLEVEL(3, "OK \n");
DISPLAYLEVEL(3, "test%3i : Loaded dictionary persists across reset session : ", testNb++);
{
size_t size1, size2;
ZSTD_CCtx_reset(cctx, ZSTD_reset_session_and_parameters);
CHECK_Z( ZSTD_CCtx_loadDictionary(cctx, CNBuffer, MIN(CNBuffSize, 10 KB)) );
size1 = ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100 KB));
if (ZSTD_isError(size1)) goto _output_error;
ZSTD_CCtx_reset(cctx, ZSTD_reset_session_only);
size2 = ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100 KB));
if (ZSTD_isError(size2)) goto _output_error;
if (size1 != size2) goto _output_error;
}
DISPLAYLEVEL(3, "OK \n");
DISPLAYLEVEL(3, "test%3i : Loaded dictionary is cleared after resetting parameters : ", testNb++);
{
size_t size1, size2;
ZSTD_CCtx_reset(cctx, ZSTD_reset_session_and_parameters);
CHECK_Z( ZSTD_CCtx_loadDictionary(cctx, CNBuffer, MIN(CNBuffSize, 10 KB)) );
size1 = ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100 KB));
if (ZSTD_isError(size1)) goto _output_error;
ZSTD_CCtx_reset(cctx, ZSTD_reset_session_and_parameters);
size2 = ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100 KB));
if (ZSTD_isError(size2)) goto _output_error;
if (size1 == size2) goto _output_error;
}
DISPLAYLEVEL(3, "OK \n");
DISPLAYLEVEL(3, "test%3i : Dictionary with non-default repcodes : ", testNb++);
{ U32 u; for (u=0; u<nbSamples; u++) samplesSizes[u] = sampleUnitSize; }
dictSize = ZDICT_trainFromBuffer(dictBuffer, dictSize,