From 2253d01b271ed7a0d2ebe359b55aa5b3c0f1f904 Mon Sep 17 00:00:00 2001 From: Nick Terrell Date: Wed, 28 Feb 2018 20:10:44 -0800 Subject: [PATCH] Move XXH64_update() into worker threads * Computes the XXH hash in the worker threads. * Workers get a sequence number and wait until their number shows up. On error, ensures that its sequence is finished, so future threads don't get blocked. * Sets up for ldm integration, which will go in the same spot. --- lib/compress/zstdmt_compress.c | 117 ++++++++++++++++++++++++++------- 1 file changed, 94 insertions(+), 23 deletions(-) diff --git a/lib/compress/zstdmt_compress.c b/lib/compress/zstdmt_compress.c index 4b236900..e5bad421 100644 --- a/lib/compress/zstdmt_compress.c +++ b/lib/compress/zstdmt_compress.c @@ -304,16 +304,81 @@ static void ZSTDMT_releaseCCtx(ZSTDMT_CCtxPool* pool, ZSTD_CCtx* cctx) ZSTD_pthread_mutex_unlock(&pool->poolMutex); } - -/* ------------------------------------------ */ -/* ===== Worker thread ===== */ -/* ------------------------------------------ */ +/* ==== Serial State ==== */ typedef struct { void const* start; size_t size; } range_t; +typedef struct { + ZSTD_pthread_mutex_t mutex; + ZSTD_pthread_cond_t cond; + ZSTD_CCtx_params params; + XXH64_state_t xxhState; + unsigned nextJobID; +} serialState_t; + +static void ZSTDMT_serialState_reset(serialState_t* serialState, ZSTD_CCtx_params params) +{ + serialState->nextJobID = 0; + if (params.fParams.checksumFlag) + XXH64_reset(&serialState->xxhState, 0); + serialState->params = params; +} + +static int ZSTDMT_serialState_init(serialState_t* serialState) +{ + int initError = 0; + initError |= ZSTD_pthread_mutex_init(&serialState->mutex, NULL); + initError |= ZSTD_pthread_cond_init(&serialState->cond, NULL); + return initError; +} + +static void ZSTDMT_serialState_free(serialState_t* serialState) +{ + ZSTD_pthread_mutex_destroy(&serialState->mutex); + ZSTD_pthread_cond_destroy(&serialState->cond); +} + +static void 
ZSTDMT_serialState_update(serialState_t* serialState, range_t src, unsigned jobID) +{ + /* Wait for our turn */ + ZSTD_PTHREAD_MUTEX_LOCK(&serialState->mutex); + while (serialState->nextJobID < jobID) { + ZSTD_pthread_cond_wait(&serialState->cond, &serialState->mutex); + } + /* A future job may error and skip our job */ + if (serialState->nextJobID == jobID) { + /* It is now our turn, do any processing necessary */ + if (serialState->params.fParams.checksumFlag && src.size > 0) + XXH64_update(&serialState->xxhState, src.start, src.size); + } + /* Now it is the next jobs turn */ + serialState->nextJobID++; + ZSTD_pthread_cond_broadcast(&serialState->cond); + ZSTD_pthread_mutex_unlock(&serialState->mutex); +} + +static void ZSTDMT_serialState_ensureFinished(serialState_t* serialState, + unsigned jobID, size_t cSize) +{ + ZSTD_PTHREAD_MUTEX_LOCK(&serialState->mutex); + if (serialState->nextJobID <= jobID) { + assert(ZSTD_isError(cSize)); (void)cSize; + DEBUGLOG(5, "Skipping past job %u because of error", jobID); + serialState->nextJobID = jobID + 1; + ZSTD_pthread_cond_broadcast(&serialState->cond); + } + ZSTD_pthread_mutex_unlock(&serialState->mutex); + +} + + +/* ------------------------------------------ */ +/* ===== Worker thread ===== */ +/* ------------------------------------------ */ + static const range_t kNullRange = { NULL, 0 }; typedef struct { @@ -323,9 +388,11 @@ typedef struct { ZSTD_pthread_cond_t job_cond; /* Thread-safe - used by mtctx and worker */ ZSTDMT_CCtxPool* cctxPool; /* Thread-safe - used by mtctx and (all) workers */ ZSTDMT_bufferPool* bufPool; /* Thread-safe - used by mtctx and (all) workers */ + serialState_t* serial; /* Thread-safe - used by mtctx and (all) workers */ buffer_t dstBuff; /* set by worker (or mtctx), then read by worker & mtctx, then modified by mtctx => no barrier */ range_t prefix; /* set by mtctx, then read by worker & mtctx => no barrier */ range_t src; /* set by mtctx, then read by worker & mtctx => no barrier */ + 
unsigned jobID; /* set by mtctx, then read by worker => no barrier */ unsigned firstJob; /* set by mtctx, then read by worker => no barrier */ unsigned lastJob; /* set by mtctx, then read by worker => no barrier */ ZSTD_CCtx_params params; /* set by mtctx, then read by worker => no barrier */ @@ -339,9 +406,13 @@ typedef struct { void ZSTDMT_compressionJob(void* jobDescription) { ZSTDMT_jobDescription* const job = (ZSTDMT_jobDescription*)jobDescription; + ZSTD_CCtx_params jobParams = job->params; /* do not modify job->params ! copy it, modify the copy */ ZSTD_CCtx* const cctx = ZSTDMT_getCCtx(job->cctxPool); buffer_t dstBuff = job->dstBuff; + /* Don't compute the checksum for chunks, but write it in the header */ + if (job->jobID != 0) jobParams.fParams.checksumFlag = 0; + /* ressources */ if (cctx==NULL) { job->cSize = ERROR(memory_allocation); @@ -358,12 +429,11 @@ void ZSTDMT_compressionJob(void* jobDescription) /* init */ if (job->cdict) { - size_t const initError = ZSTD_compressBegin_advanced_internal(cctx, NULL, 0, ZSTD_dm_auto, job->cdict, job->params, job->fullFrameSize); + size_t const initError = ZSTD_compressBegin_advanced_internal(cctx, NULL, 0, ZSTD_dm_auto, job->cdict, jobParams, job->fullFrameSize); assert(job->firstJob); /* only allowed for first job */ if (ZSTD_isError(initError)) { job->cSize = initError; goto _endJob; } } else { /* srcStart points at reloaded section */ U64 const pledgedSrcSize = job->firstJob ? job->fullFrameSize : job->src.size; - ZSTD_CCtx_params jobParams = job->params; /* do not modify job->params ! 
copy it, modify the copy */ { size_t const forceWindowError = ZSTD_CCtxParam_setParameter(&jobParams, ZSTD_p_forceMaxWindow, !job->firstJob); if (ZSTD_isError(forceWindowError)) { job->cSize = forceWindowError; @@ -377,6 +447,10 @@ void ZSTDMT_compressionJob(void* jobDescription) job->cSize = initError; goto _endJob; } } } + + /* Perform serial step as early as possible */ + ZSTDMT_serialState_update(job->serial, job->src, job->jobID); + if (!job->firstJob) { /* flush and overwrite frame header when it's not first job */ size_t const hSize = ZSTD_compressContinue(cctx, dstBuff.start, dstBuff.capacity, job->src.start, 0); if (ZSTD_isError(hSize)) { job->cSize = hSize; /* save error code */ goto _endJob; } @@ -425,6 +499,7 @@ void ZSTDMT_compressionJob(void* jobDescription) } } _endJob: + ZSTDMT_serialState_ensureFinished(job->serial, job->jobID, job->cSize); if (job->prefix.size > 0) DEBUGLOG(5, "Finished with prefix: %zx", (size_t)job->prefix.start); DEBUGLOG(5, "Finished with source: %zx", (size_t)job->src.start); @@ -475,7 +550,7 @@ struct ZSTDMT_CCtx_s { roundBuff_t roundBuff; inBuff_t inBuff; int jobReady; /* 1 => one job is already prepared, but pool has shortage of workers. Don't create another one. 
*/ - XXH64_state_t xxhState; + serialState_t serial; unsigned singleBlockingThread; unsigned jobIDMask; unsigned doneJobID; @@ -540,6 +615,7 @@ ZSTDMT_CCtx* ZSTDMT_createCCtx_advanced(unsigned nbWorkers, ZSTD_customMem cMem) { ZSTDMT_CCtx* mtctx; U32 nbJobs = nbWorkers + 2; + int initError; DEBUGLOG(3, "ZSTDMT_createCCtx_advanced (nbWorkers = %u)", nbWorkers); if (nbWorkers < 1) return NULL; @@ -559,8 +635,9 @@ ZSTDMT_CCtx* ZSTDMT_createCCtx_advanced(unsigned nbWorkers, ZSTD_customMem cMem) mtctx->jobIDMask = nbJobs - 1; mtctx->bufPool = ZSTDMT_createBufferPool(nbWorkers, cMem); mtctx->cctxPool = ZSTDMT_createCCtxPool(nbWorkers, cMem); + initError = ZSTDMT_serialState_init(&mtctx->serial); mtctx->roundBuff = kNullRoundBuff; - if (!mtctx->factory | !mtctx->jobs | !mtctx->bufPool | !mtctx->cctxPool) { + if (!mtctx->factory | !mtctx->jobs | !mtctx->bufPool | !mtctx->cctxPool | initError) { ZSTDMT_freeCCtx(mtctx); return NULL; } @@ -615,6 +692,7 @@ size_t ZSTDMT_freeCCtx(ZSTDMT_CCtx* mtctx) ZSTDMT_freeJobsTable(mtctx->jobs, mtctx->jobIDMask+1, mtctx->cMem); ZSTDMT_freeBufferPool(mtctx->bufPool); ZSTDMT_freeCCtxPool(mtctx->cctxPool); + ZSTDMT_serialState_free(&mtctx->serial); ZSTD_freeCDict(mtctx->cdictLocal); if (mtctx->roundBuff.buffer) ZSTD_free(mtctx->roundBuff.buffer, mtctx->cMem); @@ -779,7 +857,6 @@ static size_t ZSTDMT_compress_advanced_internal( size_t remainingSrcSize = srcSize; unsigned const compressWithinDst = (dstCapacity >= ZSTD_compressBound(srcSize)) ? 
nbJobs : (unsigned)(dstCapacity / ZSTD_compressBound(avgJobSize)); /* presumes avgJobSize >= 256 KB, which should be the case */ size_t frameStartPos = 0, dstBufferPos = 0; - XXH64_state_t xxh64; assert(jobParams.nbWorkers == 0); assert(mtctx->cctxPool->totalCCtx == params.nbWorkers); @@ -795,7 +872,7 @@ static size_t ZSTDMT_compress_advanced_internal( assert(avgJobSize >= 256 KB); /* condition for ZSTD_compressBound(A) + ZSTD_compressBound(B) <= ZSTD_compressBound(A+B), required to compress directly into Dst (no additional buffer) */ ZSTDMT_setBufferSize(mtctx->bufPool, ZSTD_compressBound(avgJobSize) ); - XXH64_reset(&xxh64, 0); + ZSTDMT_serialState_reset(&mtctx->serial, params); if (nbJobs > mtctx->jobIDMask+1) { /* enlarge job table */ U32 jobsTableSize = nbJobs; @@ -825,17 +902,14 @@ static size_t ZSTDMT_compress_advanced_internal( mtctx->jobs[u].fullFrameSize = srcSize; mtctx->jobs[u].params = jobParams; /* do not calculate checksum within sections, but write it in header for first section */ - if (u!=0) mtctx->jobs[u].params.fParams.checksumFlag = 0; mtctx->jobs[u].dstBuff = dstBuffer; mtctx->jobs[u].cctxPool = mtctx->cctxPool; mtctx->jobs[u].bufPool = mtctx->bufPool; + mtctx->jobs[u].serial = &mtctx->serial; + mtctx->jobs[u].jobID = u; mtctx->jobs[u].firstJob = (u==0); mtctx->jobs[u].lastJob = (u==nbJobs-1); - if (params.fParams.checksumFlag) { - XXH64_update(&xxh64, srcStart + frameStartPos, jobSize); - } - DEBUGLOG(5, "ZSTDMT_compress_advanced_internal: posting job %u (%u bytes)", u, (U32)jobSize); DEBUG_PRINTHEX(6, mtctx->jobs[u].prefix.start, 12); POOL_add(mtctx->factory, ZSTDMT_compressionJob, &mtctx->jobs[u]); @@ -876,7 +950,7 @@ static size_t ZSTDMT_compress_advanced_internal( DEBUGLOG(4, "checksumFlag : %u ", params.fParams.checksumFlag); if (params.fParams.checksumFlag) { - U32 const checksum = (U32)XXH64_digest(&xxh64); + U32 const checksum = (U32)XXH64_digest(&mtctx->serial.xxhState); if (dstPos + 4 > dstCapacity) { error = 
ERROR(dstSize_tooSmall); } else { @@ -1016,7 +1090,7 @@ size_t ZSTDMT_initCStream_internal( mtctx->allJobsCompleted = 0; mtctx->consumed = 0; mtctx->produced = 0; - if (params.fParams.checksumFlag) XXH64_reset(&mtctx->xxhState, 0); + ZSTDMT_serialState_reset(&mtctx->serial, params); return 0; } @@ -1113,21 +1187,18 @@ static size_t ZSTDMT_createCompressionJob(ZSTDMT_CCtx* mtctx, size_t srcSize, ZS mtctx->jobs[jobID].consumed = 0; mtctx->jobs[jobID].cSize = 0; mtctx->jobs[jobID].params = mtctx->params; - /* do not calculate checksum within sections, but write it in header for first section */ - if (mtctx->nextJobID) mtctx->jobs[jobID].params.fParams.checksumFlag = 0; mtctx->jobs[jobID].cdict = mtctx->nextJobID==0 ? mtctx->cdict : NULL; mtctx->jobs[jobID].fullFrameSize = mtctx->frameContentSize; mtctx->jobs[jobID].dstBuff = g_nullBuffer; mtctx->jobs[jobID].cctxPool = mtctx->cctxPool; mtctx->jobs[jobID].bufPool = mtctx->bufPool; + mtctx->jobs[jobID].serial = &mtctx->serial; + mtctx->jobs[jobID].jobID = mtctx->nextJobID; mtctx->jobs[jobID].firstJob = (mtctx->nextJobID==0); mtctx->jobs[jobID].lastJob = endFrame; mtctx->jobs[jobID].frameChecksumNeeded = endFrame && (mtctx->nextJobID>0) && mtctx->params.fParams.checksumFlag; mtctx->jobs[jobID].dstFlushed = 0; - if (mtctx->params.fParams.checksumFlag && srcSize > 0) - XXH64_update(&mtctx->xxhState, src, srcSize); - /* Update the round buffer pos and clear the input buffer to be reset */ mtctx->roundBuff.pos += srcSize; mtctx->inBuff.buffer = g_nullBuffer; @@ -1214,7 +1285,7 @@ static size_t ZSTDMT_flushProduced(ZSTDMT_CCtx* mtctx, ZSTD_outBuffer* output, u assert(srcConsumed <= srcSize); if ( (srcConsumed == srcSize) /* job completed -> worker no longer active */ && mtctx->jobs[wJobID].frameChecksumNeeded ) { - U32 const checksum = (U32)XXH64_digest(&mtctx->xxhState); + U32 const checksum = (U32)XXH64_digest(&mtctx->serial.xxhState); DEBUGLOG(4, "ZSTDMT_flushProduced: writing checksum : %08X \n", checksum); 
MEM_writeLE32((char*)mtctx->jobs[wJobID].dstBuff.start + mtctx->jobs[wJobID].cSize, checksum); cSize += 4;