diff --git a/lib/compress/zstd_compress.c b/lib/compress/zstd_compress.c index c071c573..d268cef0 100644 --- a/lib/compress/zstd_compress.c +++ b/lib/compress/zstd_compress.c @@ -3645,13 +3645,9 @@ size_t ZSTD_compress_generic (ZSTD_CCtx* cctx, } if (params.nbWorkers > 0) { /* mt context creation */ - if (cctx->mtctx == NULL || (params.nbWorkers != ZSTDMT_getNbWorkers(cctx->mtctx))) { + if (cctx->mtctx == NULL) { DEBUGLOG(4, "ZSTD_compress_generic: creating new mtctx for nbWorkers=%u", params.nbWorkers); - if (cctx->mtctx != NULL) - DEBUGLOG(4, "ZSTD_compress_generic: previous nbWorkers was %u", - ZSTDMT_getNbWorkers(cctx->mtctx)); - ZSTDMT_freeCCtx(cctx->mtctx); cctx->mtctx = ZSTDMT_createCCtx_advanced(params.nbWorkers, cctx->customMem); if (cctx->mtctx == NULL) return ERROR(memory_allocation); } diff --git a/lib/compress/zstdmt_compress.c b/lib/compress/zstdmt_compress.c index 0c59c6a7..4a2a4cdd 100644 --- a/lib/compress/zstdmt_compress.c +++ b/lib/compress/zstdmt_compress.c @@ -159,6 +159,25 @@ static void ZSTDMT_setBufferSize(ZSTDMT_bufferPool* const bufPool, size_t const ZSTD_pthread_mutex_unlock(&bufPool->poolMutex); } + +static ZSTDMT_bufferPool* ZSTDMT_expandBufferPool(ZSTDMT_bufferPool* srcBufPool, U32 nbWorkers) +{ + unsigned const maxNbBuffers = 2*nbWorkers + 3; + if (srcBufPool==NULL) return NULL; + if (srcBufPool->totalBuffers >= maxNbBuffers) /* good enough */ + return srcBufPool; + /* need a larger buffer pool */ + { ZSTD_customMem const cMem = srcBufPool->cMem; + size_t const bSize = srcBufPool->bufferSize; /* forward parameters */ + ZSTDMT_bufferPool* newBufPool; + ZSTDMT_freeBufferPool(srcBufPool); + newBufPool = ZSTDMT_createBufferPool(nbWorkers, cMem); + if (newBufPool==NULL) return newBufPool; + ZSTDMT_setBufferSize(newBufPool, bSize); + return newBufPool; + } +} + /** ZSTDMT_getBuffer() : * assumption : bufPool must be valid * @return : a buffer, with start pointer and size @@ -309,6 +328,10 @@ static void ZSTDMT_freeSeqPool(ZSTDMT_seqPool* seqPool) ZSTDMT_freeBufferPool(seqPool); } +static ZSTDMT_seqPool* ZSTDMT_expandSeqPool(ZSTDMT_seqPool* pool, U32 nbWorkers) +{ + return ZSTDMT_expandBufferPool(pool, nbWorkers); +} /* ===== CCtx Pool ===== */ @@ -354,6 +377,18 @@ static ZSTDMT_CCtxPool* ZSTDMT_createCCtxPool(unsigned nbWorkers, return cctxPool; } +static ZSTDMT_CCtxPool* ZSTDMT_expandCCtxPool(ZSTDMT_CCtxPool* srcPool, + unsigned nbWorkers) +{ + if (srcPool==NULL) return NULL; + if (nbWorkers <= srcPool->totalCCtx) return srcPool; /* good enough */ + /* need a larger cctx pool */ + { ZSTD_customMem const cMem = srcPool->cMem; + ZSTDMT_freeCCtxPool(srcPool); + return ZSTDMT_createCCtxPool(nbWorkers, cMem); + } +} + /* only works during initialization phase, not during compression */ static size_t ZSTDMT_sizeof_CCtxPool(ZSTDMT_CCtxPool* cctxPool) { @@ -745,9 +780,9 @@ struct ZSTDMT_CCtx_s { ZSTD_CCtx_params params; size_t targetSectionSize; size_t targetPrefixSize; - roundBuff_t roundBuff; + int jobReady; /* 1 => one job is already prepared, but pool has shortage of workers. Don't create a new job. */ inBuff_t inBuff; - int jobReady; /* 1 => one job is already prepared, but pool has shortage of workers. Don't create another one. */ + roundBuff_t roundBuff; serialState_t serial; unsigned singleBlockingThread; unsigned jobIDMask; @@ -798,6 +833,20 @@ static ZSTDMT_jobDescription* ZSTDMT_createJobsTable(U32* nbJobsPtr, ZSTD_custom return jobTable; } +static size_t ZSTDMT_expandJobsTable (ZSTDMT_CCtx* mtctx, U32 nbWorkers) { + U32 nbJobs = nbWorkers + 2; + if (nbJobs > mtctx->jobIDMask+1) { /* need more job capacity */ + ZSTDMT_freeJobsTable(mtctx->jobs, mtctx->jobIDMask+1, mtctx->cMem); + mtctx->jobIDMask = 0; + mtctx->jobs = ZSTDMT_createJobsTable(&nbJobs, mtctx->cMem); + if (mtctx->jobs==NULL) return ERROR(memory_allocation); + assert((nbJobs != 0) && ((nbJobs & (nbJobs - 1)) == 0)); /* ensure nbJobs is a power of 2 */ + mtctx->jobIDMask = nbJobs - 1; + } + return 0; +} + + /* ZSTDMT_CCtxParam_setNbWorkers(): * Internal use only */ size_t ZSTDMT_CCtxParam_setNbWorkers(ZSTD_CCtx_params* params, unsigned nbWorkers) @@ -964,6 +1013,25 @@ static ZSTD_CCtx_params ZSTDMT_initJobCCtxParams(ZSTD_CCtx_params const params) return jobParams; } + +/* ZSTDMT_resize() : + * @return : error code if fails, 0 on success */ +static size_t ZSTDMT_resize(ZSTDMT_CCtx* mtctx, unsigned nbWorkers) +{ + mtctx->factory = POOL_resize(mtctx->factory, nbWorkers); + if (mtctx->factory == NULL) return ERROR(memory_allocation); + CHECK_F( ZSTDMT_expandJobsTable(mtctx, nbWorkers) ); + mtctx->bufPool = ZSTDMT_expandBufferPool(mtctx->bufPool, nbWorkers); + if (mtctx->bufPool == NULL) return ERROR(memory_allocation); + mtctx->cctxPool = ZSTDMT_expandCCtxPool(mtctx->cctxPool, nbWorkers); + if (mtctx->cctxPool == NULL) return ERROR(memory_allocation); + mtctx->seqPool = ZSTDMT_expandSeqPool(mtctx->seqPool, nbWorkers); + if (mtctx->seqPool == NULL) return ERROR(memory_allocation); + ZSTDMT_CCtxParam_setNbWorkers(&mtctx->params, nbWorkers); + return 0; +} + + /*! ZSTDMT_updateCParams_whileCompressing() : * Updates only a selected set of compression parameters, to remain compatible with current frame. * New parameters will be applied to next compression job. */ @@ -980,15 +1048,6 @@ void ZSTDMT_updateCParams_whileCompressing(ZSTDMT_CCtx* mtctx, const ZSTD_CCtx_p } } -/* ZSTDMT_getNbWorkers(): - * @return nb threads currently active in mtctx. - * mtctx must be valid */ -unsigned ZSTDMT_getNbWorkers(const ZSTDMT_CCtx* mtctx) -{ - assert(mtctx != NULL); - return mtctx->params.nbWorkers; -} - /* ZSTDMT_getFrameProgression(): * tells how much data has been consumed (input) and produced (output) for current frame. * able to count progression inside worker threads. @@ -1089,15 +1148,7 @@ static size_t ZSTDMT_compress_advanced_internal( if (ZSTDMT_serialState_reset(&mtctx->serial, mtctx->seqPool, params)) return ERROR(memory_allocation); - if (nbJobs > mtctx->jobIDMask+1) { /* enlarge job table */ - U32 jobsTableSize = nbJobs; - ZSTDMT_freeJobsTable(mtctx->jobs, mtctx->jobIDMask+1, mtctx->cMem); - mtctx->jobIDMask = 0; - mtctx->jobs = ZSTDMT_createJobsTable(&jobsTableSize, mtctx->cMem); - if (mtctx->jobs==NULL) return ERROR(memory_allocation); - assert((jobsTableSize != 0) && ((jobsTableSize & (jobsTableSize - 1)) == 0)); /* ensure jobsTableSize is a power of 2 */ - mtctx->jobIDMask = jobsTableSize - 1; - } + CHECK_F( ZSTDMT_expandJobsTable(mtctx, nbJobs) ); /* only expands if necessary */ { unsigned u; for (u=0; ucctxPool->totalCCtx); - /* params are supposed to be fully validated at this point */ + + /* params supposed partially fully validated at this point */ assert(!ZSTD_isError(ZSTD_checkCParams(params.cParams))); assert(!((dict) && (cdict))); /* either dict or cdict, not both */ - assert(mtctx->cctxPool->totalCCtx == params.nbWorkers); /* init */ + if (params.nbWorkers != mtctx->params.nbWorkers) + ZSTDMT_resize(mtctx, params.nbWorkers); + if (params.jobSize == 0) { params.jobSize = 1U << ZSTDMT_computeTargetJobLog(params); } diff --git a/lib/compress/zstdmt_compress.h b/lib/compress/zstdmt_compress.h index f79e3b44..4249a82d 100644 --- a/lib/compress/zstdmt_compress.h +++ b/lib/compress/zstdmt_compress.h @@ -126,11 +126,6 @@ size_t ZSTDMT_CCtxParam_setNbWorkers(ZSTD_CCtx_params* params, unsigned nbWorker * New parameters will be applied to next compression job. */ void ZSTDMT_updateCParams_whileCompressing(ZSTDMT_CCtx* mtctx, const ZSTD_CCtx_params* cctxParams); -/* ZSTDMT_getNbWorkers(): - * @return nb threads currently active in mtctx. - * mtctx must be valid */ -unsigned ZSTDMT_getNbWorkers(const ZSTDMT_CCtx* mtctx); - /* ZSTDMT_getFrameProgression(): * tells how much data has been consumed (input) and produced (output) for current frame. * able to count progression inside worker threads.