diff --git a/doc/zstd_manual.html b/doc/zstd_manual.html
index 360a6e77..2c66409c 100644
--- a/doc/zstd_manual.html
+++ b/doc/zstd_manual.html
@@ -338,6 +338,8 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB

Custom memory allocation functions

typedef void* (*ZSTD_allocFunction) (void* opaque, size_t size);
 typedef void  (*ZSTD_freeFunction) (void* opaque, void* address);
 typedef struct { ZSTD_allocFunction customAlloc; ZSTD_freeFunction customFree; void* opaque; } ZSTD_customMem;
+/* use this constant to defer to stdlib's functions */
+static const ZSTD_customMem ZSTD_defaultCMem = { NULL, NULL, NULL};
 

Frame size functions


 
diff --git a/lib/common/zstd_common.c b/lib/common/zstd_common.c
index 5e207e7d..9b46df1d 100644
--- a/lib/common/zstd_common.c
+++ b/lib/common/zstd_common.c
@@ -12,7 +12,8 @@
 /*-*************************************
 *  Dependencies
 ***************************************/
-#include <stdlib.h>         /* malloc */
+#include <stdlib.h>      /* malloc, calloc, free */
+#include <string.h>      /* memset */
 #include "error_private.h"
 #define ZSTD_STATIC_LINKING_ONLY
 #include "zstd.h"
@@ -65,11 +66,29 @@ void ZSTD_defaultFreeFunction(void* opaque, void* address)
 
 void* ZSTD_malloc(size_t size, ZSTD_customMem customMem)
 {
-    return customMem.customAlloc(customMem.opaque, size);
+    if (customMem.customAlloc)   /* caller supplied its own allocator */
+        return customMem.customAlloc(customMem.opaque, size);
+    return malloc(size);   /* customAlloc==NULL : defer to stdlib malloc() */
+}
+
+void* ZSTD_calloc(size_t size, ZSTD_customMem customMem)
+{
+    if (customMem.customAlloc) {
+        /* calloc implemented as malloc+memset;
+         * not as efficient, but our best guess for custom malloc */
+        void* const ptr = customMem.customAlloc(customMem.opaque, size);
+        if (ptr!=NULL) memset(ptr, 0, size);   /* custom allocator may fail : never memset(NULL), just propagate NULL */
+        return ptr;
+    }
+    return calloc(1, size);   /* customAlloc==NULL : defer to stdlib calloc() */
 }
 
 void ZSTD_free(void* ptr, ZSTD_customMem customMem)
 {
-    if (ptr!=NULL)
-        customMem.customFree(customMem.opaque, ptr);
+    if (ptr!=NULL) {   /* NULL is accepted and ignored */
+        if (customMem.customFree)
+            customMem.customFree(customMem.opaque, ptr);
+        else   /* no custom free function : defer to stdlib free() */
+            free(ptr);
+    }
 }
diff --git a/lib/common/zstd_internal.h b/lib/common/zstd_internal.h
index 2533333b..5ab4aaef 100644
--- a/lib/common/zstd_internal.h
+++ b/lib/common/zstd_internal.h
@@ -244,6 +244,7 @@ void ZSTD_defaultFreeFunction(void* opaque, void* address);
 static const ZSTD_customMem defaultCustomMem = { ZSTD_defaultAllocFunction, ZSTD_defaultFreeFunction, NULL };
 #endif
 void* ZSTD_malloc(size_t size, ZSTD_customMem customMem);
+void* ZSTD_calloc(size_t size, ZSTD_customMem customMem);
 void ZSTD_free(void* ptr, ZSTD_customMem customMem);
 
 
diff --git a/lib/compress/zstdmt_compress.c b/lib/compress/zstdmt_compress.c
index 0c6bc724..bd3de458 100644
--- a/lib/compress/zstdmt_compress.c
+++ b/lib/compress/zstdmt_compress.c
@@ -19,29 +19,33 @@
 
 
 /* ======   Dependencies   ====== */
-#include <stdlib.h>   /* malloc */
-#include <string.h>   /* memcpy */
-#include "pool.h"     /* threadpool */
+#include <string.h>     /* memcpy, memset */
+#include "pool.h"       /* threadpool */
 #include "threading.h"  /* mutex */
-#include "zstd_internal.h"   /* MIN, ERROR, ZSTD_*, ZSTD_highbit32 */
+#include "zstd_internal.h"  /* MIN, ERROR, ZSTD_*, ZSTD_highbit32 */
 #include "zstdmt_compress.h"
 
 
 /* ======   Debug   ====== */
-#if 0
+#if defined(ZSTDMT_DEBUG)
 
 #  include <stdio.h>
 #  include <unistd.h>
 #  include <sys/times.h>
-   static unsigned g_debugLevel = 5;
-#  define DEBUGLOGRAW(l, ...) if (l<=g_debugLevel) { fprintf(stderr, __VA_ARGS__); }
-#  define DEBUGLOG(l, ...) if (l<=g_debugLevel) { fprintf(stderr, __FILE__ ": "); fprintf(stderr, __VA_ARGS__); fprintf(stderr, " \n"); }
+#  define DEBUGLOGRAW(l, ...) if (l<=ZSTDMT_DEBUG) { fprintf(stderr, __VA_ARGS__); }
+#  define DEBUGLOG(l, ...) {            \
+    if (l<=ZSTDMT_DEBUG) {              \
+        fprintf(stderr, __FILE__ ": "); \
+        fprintf(stderr, __VA_ARGS__);   \
+        fprintf(stderr, " \n");         \
+    }                                   \
+}
 
-#  define DEBUG_PRINTHEX(l,p,n) { \
-    unsigned debug_u;                   \
-    for (debug_u=0; debug_u<(n); debug_u++)           \
+#  define DEBUG_PRINTHEX(l,p,n) {            \
+    unsigned debug_u;                        \
+    for (debug_u=0; debug_u<(n); debug_u++)  \
         DEBUGLOGRAW(l, "%02X ", ((const unsigned char*)(p))[debug_u]); \
-    DEBUGLOGRAW(l, " \n");       \
+    DEBUGLOGRAW(l, " \n");                   \
 }
 
 static unsigned long long GetCurrentClockTimeMicroseconds(void)
@@ -54,17 +58,18 @@ static unsigned long long GetCurrentClockTimeMicroseconds(void)
 }
 
 #define MUTEX_WAIT_TIME_DLEVEL 5
-#define PTHREAD_MUTEX_LOCK(mutex) \
-if (g_debugLevel>=MUTEX_WAIT_TIME_DLEVEL) { \
-    unsigned long long const beforeTime = GetCurrentClockTimeMicroseconds(); \
-    pthread_mutex_lock(mutex); \
-    {   unsigned long long const afterTime = GetCurrentClockTimeMicroseconds(); \
-        unsigned long long const elapsedTime = (afterTime-beforeTime); \
-        if (elapsedTime > 1000) {  /* or whatever threshold you like; I'm using 1 millisecond here */ \
-            DEBUGLOG(MUTEX_WAIT_TIME_DLEVEL, "Thread took %llu microseconds to acquire mutex %s \n", \
-               elapsedTime, #mutex); \
-    }   } \
-} else pthread_mutex_lock(mutex);
+#define PTHREAD_MUTEX_LOCK(mutex) {               \
+    if (ZSTDMT_DEBUG>=MUTEX_WAIT_TIME_DLEVEL) {   \
+        unsigned long long const beforeTime = GetCurrentClockTimeMicroseconds(); \
+        pthread_mutex_lock(mutex);                \
+        {   unsigned long long const afterTime = GetCurrentClockTimeMicroseconds(); \
+            unsigned long long const elapsedTime = (afterTime-beforeTime); \
+            if (elapsedTime > 1000) {  /* or whatever threshold you like; I'm using 1 millisecond here */ \
+                DEBUGLOG(MUTEX_WAIT_TIME_DLEVEL, "Thread took %llu microseconds to acquire mutex %s \n", \
+                   elapsedTime, #mutex);          \
+        }   }                                     \
+    } else pthread_mutex_lock(mutex);             \
+}
 
 #else
 
@@ -87,16 +92,19 @@ static const buffer_t g_nullBuffer = { NULL, 0 };
 typedef struct ZSTDMT_bufferPool_s {
     unsigned totalBuffers;
     unsigned nbBuffers;
+    ZSTD_customMem cMem;   /* allocator used for the pool itself and for every buffer it hands out */
     buffer_t bTable[1];   /* variable size */
 } ZSTDMT_bufferPool;
 
-static ZSTDMT_bufferPool* ZSTDMT_createBufferPool(unsigned nbThreads)
+static ZSTDMT_bufferPool* ZSTDMT_createBufferPool(unsigned nbThreads, ZSTD_customMem cMem)
 {
     unsigned const maxNbBuffers = 2*nbThreads + 2;
-    ZSTDMT_bufferPool* const bufPool = (ZSTDMT_bufferPool*)calloc(1, sizeof(ZSTDMT_bufferPool) + (maxNbBuffers-1) * sizeof(buffer_t));
+    ZSTDMT_bufferPool* const bufPool = (ZSTDMT_bufferPool*)ZSTD_calloc(
+        sizeof(ZSTDMT_bufferPool) + (maxNbBuffers-1) * sizeof(buffer_t), cMem);   /* struct already holds 1 bTable slot, hence -1 */
     if (bufPool==NULL) return NULL;
     bufPool->totalBuffers = maxNbBuffers;
     bufPool->nbBuffers = 0;
+    bufPool->cMem = cMem;   /* kept so the pool and its buffers are later freed through the same allocator */
     return bufPool;
 }
 
@@ -105,8 +113,8 @@ static void ZSTDMT_freeBufferPool(ZSTDMT_bufferPool* bufPool)
     unsigned u;
     if (!bufPool) return;   /* compatibility with free on NULL */
     for (u=0; u<bufPool->totalBuffers; u++)
-        free(bufPool->bTable[u].start);
-    free(bufPool);
+        ZSTD_free(bufPool->bTable[u].start, bufPool->cMem);
+    ZSTD_free(bufPool, bufPool->cMem);
 }
 
 /* assumption : invocation from main thread only ! */
@@ -115,13 +123,15 @@ static buffer_t ZSTDMT_getBuffer(ZSTDMT_bufferPool* pool, size_t bSize)
     if (pool->nbBuffers) {   /* try to use an existing buffer */
         buffer_t const buf = pool->bTable[--(pool->nbBuffers)];
         size_t const availBufferSize = buf.size;
-        if ((availBufferSize >= bSize) & (availBufferSize <= 10*bSize))   /* large enough, but not too much */
+        if ((availBufferSize >= bSize) & (availBufferSize <= 10*bSize))
+            /* large enough, but not too much */
             return buf;
-        free(buf.start);   /* size conditions not respected : scratch this buffer and create a new one */
+        /* size conditions not respected : scratch this buffer, create new one */
+        ZSTD_free(buf.start, pool->cMem);
     }
     /* create new buffer */
     {   buffer_t buffer;
-        void* const start = malloc(bSize);
+        void* const start = ZSTD_malloc(bSize, pool->cMem);
         if (start==NULL) bSize = 0;
         buffer.start = start;   /* note : start can be NULL if malloc fails ! */
         buffer.size = bSize;
@@ -138,7 +148,7 @@ static void ZSTDMT_releaseBuffer(ZSTDMT_bufferPool* pool, buffer_t buf)
         return;
     }
     /* Reached bufferPool capacity (should not happen) */
-    free(buf.start);
+    ZSTD_free(buf.start, pool->cMem);
 }
 
 
@@ -147,6 +157,7 @@ static void ZSTDMT_releaseBuffer(ZSTDMT_bufferPool* pool, buffer_t buf)
 typedef struct {
     unsigned totalCCtx;
     unsigned availCCtx;
+    ZSTD_customMem cMem;   /* allocator used for the pool itself and passed to ZSTD_createCCtx_advanced() */
     ZSTD_CCtx* cctx[1];   /* variable size */
 } ZSTDMT_CCtxPool;
 
@@ -158,18 +169,21 @@ static void ZSTDMT_freeCCtxPool(ZSTDMT_CCtxPool* pool)
     unsigned u;
     for (u=0; u<pool->totalCCtx; u++)
         ZSTD_freeCCtx(pool->cctx[u]);  /* note : compatible with free on NULL */
-    free(pool);
+    ZSTD_free(pool, pool->cMem);
 }
 
 /* ZSTDMT_createCCtxPool() :
  * implies nbThreads >= 1 , checked by caller ZSTDMT_createCCtx() */
-static ZSTDMT_CCtxPool* ZSTDMT_createCCtxPool(unsigned nbThreads)
+static ZSTDMT_CCtxPool* ZSTDMT_createCCtxPool(unsigned nbThreads,
+                                            ZSTD_customMem cMem)
 {
-    ZSTDMT_CCtxPool* const cctxPool = (ZSTDMT_CCtxPool*) calloc(1, sizeof(ZSTDMT_CCtxPool) + (nbThreads-1)*sizeof(ZSTD_CCtx*));
+    ZSTDMT_CCtxPool* const cctxPool = (ZSTDMT_CCtxPool*) ZSTD_calloc(
+        sizeof(ZSTDMT_CCtxPool) + (nbThreads-1)*sizeof(ZSTD_CCtx*), cMem);
     if (!cctxPool) return NULL;
+    cctxPool->cMem = cMem;
     cctxPool->totalCCtx = nbThreads;
     cctxPool->availCCtx = 1;   /* at least one cctx for single-thread mode */
-    cctxPool->cctx[0] = ZSTD_createCCtx();
+    cctxPool->cctx[0] = ZSTD_createCCtx_advanced(cMem);
     if (!cctxPool->cctx[0]) { ZSTDMT_freeCCtxPool(cctxPool); return NULL; }
     DEBUGLOG(1, "cctxPool created, with %u threads", nbThreads);
     return cctxPool;
@@ -232,7 +246,7 @@ void ZSTDMT_compressChunk(void* jobDescription)
                  job->firstChunk, job->lastChunk, (U32)job->dictSize, (U32)job->srcSize);
     if (job->cdict) {  /* should only happen for first segment */
         size_t const initError = ZSTD_compressBegin_usingCDict_advanced(job->cctx, job->cdict, job->params.fParams, job->fullFrameSize);
-        if (job->cdict) DEBUGLOG(3, "using CDict ");
+        if (job->cdict) DEBUGLOG(3, "using CDict");
         if (ZSTD_isError(initError)) { job->cSize = initError; goto _endJob; }
     } else {  /* srcStart points at reloaded section */
         if (!job->firstChunk) job->params.fParams.contentSizeFlag = 0;  /* ensure no srcSize control */
@@ -293,43 +307,56 @@ struct ZSTDMT_CCtx_s {
     unsigned overlapRLog;
     unsigned long long frameContentSize;
     size_t sectionSize;
+    ZSTD_customMem cMem;
     ZSTD_CDict* cdict;
     ZSTD_CStream* cstream;
 };
 
-ZSTDMT_CCtx* ZSTDMT_createCCtx(unsigned nbThreads)
+ZSTDMT_CCtx* ZSTDMT_createCCtx_advanced(unsigned nbThreads, ZSTD_customMem cMem)
 {
-    ZSTDMT_CCtx* cctx;
+    ZSTDMT_CCtx* mtctx;
     U32 const minNbJobs = nbThreads + 2;
     U32 const nbJobsLog2 = ZSTD_highbit32(minNbJobs) + 1;
     U32 const nbJobs = 1 << nbJobsLog2;
-    DEBUGLOG(5, "nbThreads : %u  ; minNbJobs : %u ;  nbJobsLog2 : %u ;  nbJobs : %u  \n",
-            nbThreads, minNbJobs, nbJobsLog2, nbJobs);
+    DEBUGLOG(5, "nbThreads: %u ; minNbJobs: %u ; nbJobsLog2: %u ; nbJobs: %u",
+                nbThreads, minNbJobs, nbJobsLog2, nbJobs);
+
     if ((nbThreads < 1) | (nbThreads > ZSTDMT_NBTHREADS_MAX)) return NULL;
-    cctx = (ZSTDMT_CCtx*) calloc(1, sizeof(ZSTDMT_CCtx));
-    if (!cctx) return NULL;
-    cctx->nbThreads = nbThreads;
-    cctx->jobIDMask = nbJobs - 1;
-    cctx->allJobsCompleted = 1;
-    cctx->sectionSize = 0;
-    cctx->overlapRLog = 3;
-    cctx->factory = POOL_create(nbThreads, 1);
-    cctx->jobs = (ZSTDMT_jobDescription*) malloc(nbJobs * sizeof(*cctx->jobs));
-    cctx->buffPool = ZSTDMT_createBufferPool(nbThreads);
-    cctx->cctxPool = ZSTDMT_createCCtxPool(nbThreads);
-    if (!cctx->factory | !cctx->jobs | !cctx->buffPool | !cctx->cctxPool) {
-        ZSTDMT_freeCCtx(cctx);
+    if ((cMem.customAlloc!=NULL) ^ (cMem.customFree!=NULL))
+        /* invalid custom allocator : alloc and free must be both set or both NULL */
+        return NULL;
+
+    mtctx = (ZSTDMT_CCtx*) ZSTD_calloc(sizeof(ZSTDMT_CCtx), cMem);
+    if (!mtctx) return NULL;
+    mtctx->cMem = cMem;
+    mtctx->nbThreads = nbThreads;
+    mtctx->jobIDMask = nbJobs - 1;
+    mtctx->allJobsCompleted = 1;
+    mtctx->sectionSize = 0;
+    mtctx->overlapRLog = 3;
+    pthread_mutex_init(&mtctx->jobCompleted_mutex, NULL);   /* initialized before any failure path, since ZSTDMT_freeCCtx() destroys them. Todo : check init function return */
+    pthread_cond_init(&mtctx->jobCompleted_cond, NULL);
+    mtctx->factory = POOL_create(nbThreads, 1);
+    mtctx->jobs = (ZSTDMT_jobDescription*)ZSTD_calloc(
+        nbJobs * sizeof(*mtctx->jobs), cMem);
+    mtctx->buffPool = ZSTDMT_createBufferPool(nbThreads, cMem);
+    mtctx->cctxPool = ZSTDMT_createCCtxPool(nbThreads, cMem);
+    if (!mtctx->factory | !mtctx->jobs | !mtctx->buffPool | !mtctx->cctxPool) {
+        ZSTDMT_freeCCtx(mtctx);   /* releases whichever members were successfully created */
         return NULL;
     }
     if (nbThreads==1) {
-        cctx->cstream = ZSTD_createCStream();
-        if (!cctx->cstream) {
-            ZSTDMT_freeCCtx(cctx); return NULL;
+        mtctx->cstream = ZSTD_createCStream_advanced(cMem);
+        if (!mtctx->cstream) {
+            ZSTDMT_freeCCtx(mtctx); return NULL;
     }   }
-    pthread_mutex_init(&cctx->jobCompleted_mutex, NULL);   /* Todo : check init function return */
-    pthread_cond_init(&cctx->jobCompleted_cond, NULL);
-    DEBUGLOG(4, "mt_cctx created, for %u threads \n", nbThreads);
-    return cctx;
+    DEBUGLOG(4, "mt_cctx created, for %u threads", nbThreads);
+    return mtctx;
+}
+
+ZSTDMT_CCtx* ZSTDMT_createCCtx(unsigned nbThreads)
+{
+    return ZSTDMT_createCCtx_advanced(nbThreads, ZSTD_defaultCMem);   /* all-NULL cMem : allocation defers to stdlib */
 }
 
 /* ZSTDMT_releaseAllJobResources() :
@@ -357,13 +384,13 @@ size_t ZSTDMT_freeCCtx(ZSTDMT_CCtx* mtctx)
     POOL_free(mtctx->factory);
     if (!mtctx->allJobsCompleted) ZSTDMT_releaseAllJobResources(mtctx); /* stop workers first */
     ZSTDMT_freeBufferPool(mtctx->buffPool);  /* release job resources into pools first */
-    free(mtctx->jobs);
+    ZSTD_free(mtctx->jobs, mtctx->cMem);
     ZSTDMT_freeCCtxPool(mtctx->cctxPool);
     ZSTD_freeCDict(mtctx->cdict);
     ZSTD_freeCStream(mtctx->cstream);
     pthread_mutex_destroy(&mtctx->jobCompleted_mutex);
     pthread_cond_destroy(&mtctx->jobCompleted_cond);
-    free(mtctx);
+    ZSTD_free(mtctx, mtctx->cMem);
     return 0;
 }
 
diff --git a/lib/compress/zstdmt_compress.h b/lib/compress/zstdmt_compress.h
index 27f78ee0..92f7d8d0 100644
--- a/lib/compress/zstdmt_compress.h
+++ b/lib/compress/zstdmt_compress.h
@@ -28,6 +28,7 @@
 
 typedef struct ZSTDMT_CCtx_s ZSTDMT_CCtx;
 ZSTDLIB_API ZSTDMT_CCtx* ZSTDMT_createCCtx(unsigned nbThreads);
+ZSTDLIB_API ZSTDMT_CCtx* ZSTDMT_createCCtx_advanced(unsigned nbThreads, ZSTD_customMem cMem);
 ZSTDLIB_API size_t ZSTDMT_freeCCtx(ZSTDMT_CCtx* cctx);
 
 ZSTDLIB_API size_t ZSTDMT_compressCCtx(ZSTDMT_CCtx* cctx,
diff --git a/lib/zstd.h b/lib/zstd.h
index d3a4d178..d4c24e7b 100644
--- a/lib/zstd.h
+++ b/lib/zstd.h
@@ -421,6 +421,8 @@ typedef struct {
 typedef void* (*ZSTD_allocFunction) (void* opaque, size_t size);
 typedef void  (*ZSTD_freeFunction) (void* opaque, void* address);
 typedef struct { ZSTD_allocFunction customAlloc; ZSTD_freeFunction customFree; void* opaque; } ZSTD_customMem;
+/* use this constant to defer to stdlib's functions */
+static const ZSTD_customMem ZSTD_defaultCMem = { NULL, NULL, NULL};
 
 /***************************************
 *  Frame size functions