Added code to generate the dictionary using ZDICT_finalizeDictionary

This commit is contained in:
Paul Cruz 2017-06-13 11:54:43 -07:00
parent 11c3987baf
commit f35f252e36

View File

@ -18,6 +18,7 @@
#include "zstd.h" #include "zstd.h"
#include "zstd_internal.h" #include "zstd_internal.h"
#include "mem.h" #include "mem.h"
#include "zdict.h"
// Direct access to internal compression functions is required // Direct access to internal compression functions is required
#include "zstd_compress.c" #include "zstd_compress.c"
@ -316,7 +317,8 @@ static void writeFrameHeader(U32* seed, frame_t* frame, int genDict, size_t dict
op[pos++] = windowByte; op[pos++] = windowByte;
} }
if(genDict) { if(genDict) {
MEM_writeLE32(op + pos, (U32) dictSize); MEM_writeLE32(op + pos, (U32) dictID);
pos += 4;
} }
if (contentSizeFlag) { if (contentSizeFlag) {
switch (fcsCode) { switch (fcsCode) {
@ -608,7 +610,7 @@ static inline void initSeqStore(seqStore_t *seqStore) {
/* Randomly generate sequence commands */ /* Randomly generate sequence commands */
static U32 generateSequences(U32* seed, frame_t* frame, seqStore_t* seqStore, static U32 generateSequences(U32* seed, frame_t* frame, seqStore_t* seqStore,
size_t contentSize, size_t literalsSize, int genDict, size_t dictSize) size_t contentSize, size_t literalsSize, int genDict, size_t dictSize, BYTE* dictContent)
{ {
/* The total length of all the matches */ /* The total length of all the matches */
size_t const remainingMatch = contentSize - literalsSize; size_t const remainingMatch = contentSize - literalsSize;
@ -686,11 +688,17 @@ static U32 generateSequences(U32* seed, frame_t* frame, seqStore_t* seqStore,
repIndex = MIN(2, offsetCode + 1); repIndex = MIN(2, offsetCode + 1);
} }
} }
} while (offset > (size_t)((BYTE*)srcPtr - (BYTE*)frame->srcStart) || offset == 0); } while (((!genDict) && (offset > (size_t)((BYTE*)srcPtr - (BYTE*)frame->srcStart))) || offset == 0);
{ size_t j; { size_t j;
for (j = 0; j < matchLen; j++) { for (j = 0; j < matchLen; j++) {
*srcPtr = *(srcPtr-offset); if(srcPtr-offset < frame->srcStart){
/* copy from dictionary instead of literals */
*srcPtr = *(dictContent + dictSize - (offset-(srcPtr-frame->srcStart)));
}
else{
*srcPtr = *(srcPtr-offset);
}
srcPtr++; srcPtr++;
} }
} }
@ -940,7 +948,7 @@ static size_t writeSequences(U32* seed, frame_t* frame, seqStore_t* seqStorePtr,
} }
static size_t writeSequencesBlock(U32* seed, frame_t* frame, size_t contentSize, static size_t writeSequencesBlock(U32* seed, frame_t* frame, size_t contentSize,
size_t literalsSize, int genDict, size_t dictSize) size_t literalsSize, int genDict, size_t dictSize, BYTE* dictContent)
{ {
seqStore_t seqStore; seqStore_t seqStore;
size_t numSequences; size_t numSequences;
@ -949,14 +957,14 @@ static size_t writeSequencesBlock(U32* seed, frame_t* frame, size_t contentSize,
initSeqStore(&seqStore); initSeqStore(&seqStore);
/* randomly generate sequences */ /* randomly generate sequences */
numSequences = generateSequences(seed, frame, &seqStore, contentSize, literalsSize, genDict, dictSize); numSequences = generateSequences(seed, frame, &seqStore, contentSize, literalsSize, genDict, dictSize, dictContent);
/* write them out to the frame data */ /* write them out to the frame data */
CHECKERR(writeSequences(seed, frame, &seqStore, numSequences)); CHECKERR(writeSequences(seed, frame, &seqStore, numSequences));
return numSequences; return numSequences;
} }
static size_t writeCompressedBlock(U32* seed, frame_t* frame, size_t contentSize, int genDict, size_t dictSize) static size_t writeCompressedBlock(U32* seed, frame_t* frame, size_t contentSize, int genDict, size_t dictSize, BYTE* dictContent)
{ {
BYTE* const blockStart = (BYTE*)frame->data; BYTE* const blockStart = (BYTE*)frame->data;
size_t literalsSize; size_t literalsSize;
@ -968,7 +976,7 @@ static size_t writeCompressedBlock(U32* seed, frame_t* frame, size_t contentSize
DISPLAYLEVEL(4, " literals size: %u\n", (U32)literalsSize); DISPLAYLEVEL(4, " literals size: %u\n", (U32)literalsSize);
nbSeq = writeSequencesBlock(seed, frame, contentSize, literalsSize, genDict, dictSize); nbSeq = writeSequencesBlock(seed, frame, contentSize, literalsSize, genDict, dictSize, dictContent);
DISPLAYLEVEL(4, " number of sequences: %u\n", (U32)nbSeq); DISPLAYLEVEL(4, " number of sequences: %u\n", (U32)nbSeq);
@ -976,7 +984,7 @@ static size_t writeCompressedBlock(U32* seed, frame_t* frame, size_t contentSize
} }
static void writeBlock(U32* seed, frame_t* frame, size_t contentSize, static void writeBlock(U32* seed, frame_t* frame, size_t contentSize,
int lastBlock, int genDict, size_t dictSize) int lastBlock, int genDict, size_t dictSize, BYTE* dictContent)
{ {
int const blockTypeDesc = RAND(seed) % 8; int const blockTypeDesc = RAND(seed) % 8;
size_t blockSize; size_t blockSize;
@ -1016,7 +1024,7 @@ static void writeBlock(U32* seed, frame_t* frame, size_t contentSize,
frame->oldStats = frame->stats; frame->oldStats = frame->stats;
frame->data = op; frame->data = op;
compressedSize = writeCompressedBlock(seed, frame, contentSize, genDict, dictSize); compressedSize = writeCompressedBlock(seed, frame, contentSize, genDict, dictSize, dictContent);
if (compressedSize > contentSize) { if (compressedSize > contentSize) {
blockType = 0; blockType = 0;
memcpy(op, frame->src, contentSize); memcpy(op, frame->src, contentSize);
@ -1042,7 +1050,7 @@ static void writeBlock(U32* seed, frame_t* frame, size_t contentSize,
frame->data = op; frame->data = op;
} }
static void writeBlocks(U32* seed, frame_t* frame, int genDict, size_t dictSize) static void writeBlocks(U32* seed, frame_t* frame, int genDict, size_t dictSize, BYTE* dictContent)
{ {
size_t contentLeft = frame->header.contentSize; size_t contentLeft = frame->header.contentSize;
size_t const maxBlockSize = MIN(MAX_BLOCK_SIZE, frame->header.windowSize); size_t const maxBlockSize = MIN(MAX_BLOCK_SIZE, frame->header.windowSize);
@ -1065,7 +1073,7 @@ static void writeBlocks(U32* seed, frame_t* frame, int genDict, size_t dictSize)
} }
} }
writeBlock(seed, frame, blockContentSize, lastBlock, genDict, dictSize); writeBlock(seed, frame, blockContentSize, lastBlock, genDict, dictSize, dictContent);
contentLeft -= blockContentSize; contentLeft -= blockContentSize;
if (lastBlock) break; if (lastBlock) break;
@ -1130,14 +1138,14 @@ static void initFrame(frame_t* fr)
} }
/* Return the final seed */ /* Return the final seed */
static U32 generateFrame(U32 seed, frame_t* fr, int genDict, size_t dictSize) static U32 generateFrame(U32 seed, frame_t* fr, int genDict, size_t dictSize, BYTE* dictContent)
{ {
/* generate a complete frame */ /* generate a complete frame */
DISPLAYLEVEL(1, "frame seed: %u\n", seed); DISPLAYLEVEL(1, "frame seed: %u\n", seed);
initFrame(fr); initFrame(fr);
writeFrameHeader(&seed, fr, genDict, dictSize); writeFrameHeader(&seed, fr, genDict, dictSize);
writeBlocks(&seed, fr, genDict, dictSize); writeBlocks(&seed, fr, genDict, dictSize, dictContent);
writeChecksum(fr); writeChecksum(fr);
return seed; return seed;
@ -1224,7 +1232,7 @@ static int runTestMode(U32 seed, unsigned numFiles, unsigned const testDurationS
else else
DISPLAYUPDATE("\r%u ", fnum); DISPLAYUPDATE("\r%u ", fnum);
seed = generateFrame(seed, &fr, 0, 0); seed = generateFrame(seed, &fr, 0, 0, NULL);
{ size_t const r = testDecodeSimple(&fr); { size_t const r = testDecodeSimple(&fr);
if (ZSTD_isError(r)) { if (ZSTD_isError(r)) {
@ -1259,7 +1267,7 @@ static int generateFile(U32 seed, const char* const path,
DISPLAY("seed: %u\n", seed); DISPLAY("seed: %u\n", seed);
generateFrame(seed, &fr, 0, 0); generateFrame(seed, &fr, 0, 0, NULL);
outputBuffer(fr.dataStart, (BYTE*)fr.data - (BYTE*)fr.dataStart, path); outputBuffer(fr.dataStart, (BYTE*)fr.data - (BYTE*)fr.dataStart, path);
if (origPath) { if (origPath) {
@ -1281,7 +1289,7 @@ static int generateCorpus(U32 seed, unsigned numFiles, const char* const path,
DISPLAYUPDATE("\r%u/%u ", fnum, numFiles); DISPLAYUPDATE("\r%u/%u ", fnum, numFiles);
seed = generateFrame(seed, &fr, 0, 0); seed = generateFrame(seed, &fr, 0, 0, NULL);
if (snprintf(outPath, MAX_PATH, "%s/z%06u.zst", path, fnum) + 1 > MAX_PATH) { if (snprintf(outPath, MAX_PATH, "%s/z%06u.zst", path, fnum) + 1 > MAX_PATH) {
DISPLAY("Error: path too long\n"); DISPLAY("Error: path too long\n");
@ -1308,9 +1316,11 @@ static int generateCorpusWithDict(U32 seed, unsigned numFiles, const char* const
{ {
const size_t minDictSize = 8; const size_t minDictSize = 8;
char outPath[MAX_PATH]; char outPath[MAX_PATH];
BYTE* dictContent;
BYTE* fullDict;
U32 dictID; U32 dictID;
BYTE* dictStart;
unsigned fnum; unsigned fnum;
BYTE* decompressedPtr;
ZSTD_DCtx* dctx = ZSTD_createDCtx(); ZSTD_DCtx* dctx = ZSTD_createDCtx();
if(snprintf(outPath, MAX_PATH, "%s/dictionary", path) + 1 > MAX_PATH) { if(snprintf(outPath, MAX_PATH, "%s/dictionary", path) + 1 > MAX_PATH) {
DISPLAY("Error: path too long\n"); DISPLAY("Error: path too long\n");
@ -1318,37 +1328,50 @@ static int generateCorpusWithDict(U32 seed, unsigned numFiles, const char* const
} }
/* Generate the dictionary randomly first */ /* Generate the dictionary randomly first */
if(dictSize < minDictSize){ dictContent = malloc(dictSize-400);
DISPLAY("Error: dictionary size (%zu) is too small\n", dictSize); dictID = RAND(&seed);
} fullDict = malloc(dictSize);
else{ RAND_buffer(&seed, dictContent, dictSize-40);
/* variable declaration */ {
dictStart = malloc(dictSize); /* create random samples */
size_t pos = 0; unsigned numSamples = RAND(&seed);
dictID = RAND(&seed) + 1; unsigned i = 0;
size_t* sampleSizes = malloc(numSamples*sizeof(size_t));
size_t* curr = sampleSizes;
size_t totalSize = 0;
while(i < numSamples){
*curr = RAND(&seed) % (4 << 20);
totalSize += *curr;
curr++;
}
ZDICT_params_t zdictParams;
BYTE* samples = malloc(totalSize);
RAND_buffer(&seed, samples, totalSize);
/* write dictionary magic number */ /* set dictionary params */
MEM_writeLE32(dictStart + pos, ZSTD_DICT_MAGIC); memset(&zdictParams, 0, sizeof(zdictParams));
pos += 4; zdictParams.notificationLevel = 1;
zdictParams.dictID = dictID;
zdictParams.compressionLevel = 5;
/* write random dictionary ID */ /* finalize dictionary with random samples */
MEM_writeLE32(dictStart + pos, dictID); ZDICT_finalizeDictionary(fullDict, dictSize,
pos += 4; dictContent, dictSize-400,
samples, sampleSizes, numSamples,
/* randomly generate the rest of the dictionary */ zdictParams);
RAND_buffer(&seed, dictStart + pos, dictSize-8);
outputBuffer(dictStart, dictSize, outPath);
} }
decompressedPtr = malloc(MAX_DECOMPRESSED_SIZE);
/* generate random compressed/decompressed files */ /* generate random compressed/decompressed files */
for (fnum = 0; fnum < numFiles; fnum++) { for (fnum = 0; fnum < numFiles; fnum++) {
frame_t fr; frame_t fr;
size_t returnValue; size_t returnValue;
BYTE* decompressedPtr = malloc(MAX_DECOMPRESSED_SIZE);
DISPLAYUPDATE("\r%u/%u ", fnum, numFiles); DISPLAYUPDATE("\r%u/%u ", fnum, numFiles);
seed = generateFrame(seed, &fr, 1, dictSize); seed = generateFrame(seed, &fr, 1, dictSize, dictContent);
if (snprintf(outPath, MAX_PATH, "%s/z%06u.zst", path, fnum) + 1 > MAX_PATH) { if (snprintf(outPath, MAX_PATH, "%s/z%06u.zst", path, fnum) + 1 > MAX_PATH) {
DISPLAY("Error: path too long\n"); DISPLAY("Error: path too long\n");
@ -1368,13 +1391,10 @@ static int generateCorpusWithDict(U32 seed, unsigned numFiles, const char* const
returnValue = ZSTD_decompress_usingDict(dctx, decompressedPtr, MAX_DECOMPRESSED_SIZE, returnValue = ZSTD_decompress_usingDict(dctx, decompressedPtr, MAX_DECOMPRESSED_SIZE,
fr.srcStart, (BYTE*)fr.src - (BYTE*)fr.srcStart, fr.srcStart, (BYTE*)fr.src - (BYTE*)fr.srcStart,
dictStart,dictSize); fullDict, dictSize);
} }
/* write uncompressed versions of files */
DISPLAY("This is origPath: %s\nAnd this is numFiles: %d\n", origPath, numFiles);
return 0; return 0;
} }