Allow splitPoint==1.0 (using all samples for both training and testing)
This commit is contained in:
parent
0881184c89
commit
a085d1aae1
@ -543,10 +543,10 @@ static int COVER_ctx_init(COVER_ctx_t *ctx, const void *samplesBuffer,
|
||||
const unsigned kFirst = 0;
|
||||
const size_t totalSamplesSize = COVER_sum(samplesSizes, kFirst, nbSamples);
|
||||
/* Split samples into testing and training sets */
|
||||
const unsigned nbTrainSamples = (unsigned)((double)nbSamples * splitPoint);
|
||||
const unsigned nbTestSamples = nbSamples - nbTrainSamples;
|
||||
const size_t trainingSamplesSize = COVER_sum(samplesSizes, kFirst, nbTrainSamples);
|
||||
const size_t testSamplesSize = COVER_sum(samplesSizes, nbTrainSamples, nbSamples);
|
||||
const unsigned nbTrainSamples = splitPoint < 1.0 ? (unsigned)((double)nbSamples * splitPoint) : nbSamples;
|
||||
const unsigned nbTestSamples = splitPoint < 1.0 ? nbSamples - nbTrainSamples : nbSamples;
|
||||
const size_t trainingSamplesSize = splitPoint < 1.0 ? COVER_sum(samplesSizes, kFirst, nbTrainSamples) : totalSamplesSize;
|
||||
const size_t testSamplesSize = splitPoint < 1.0 ? COVER_sum(samplesSizes, nbTrainSamples, nbSamples) : totalSamplesSize;
|
||||
/* Checks */
|
||||
if (totalSamplesSize < MAX(d, sizeof(U64)) ||
|
||||
totalSamplesSize >= (size_t)COVER_MAX_SAMPLES_SIZE) {
|
||||
@ -559,12 +559,13 @@ static int COVER_ctx_init(COVER_ctx_t *ctx, const void *samplesBuffer,
|
||||
DISPLAYLEVEL(1, "Total number of training samples is %u and is invalid.", nbTrainSamples);
|
||||
return 0;
|
||||
}
|
||||
/* Check if there's testing sample when splitPoint is not 1.0 */
|
||||
if (nbTestSamples < 1 && splitPoint < 1.0) {
|
||||
/* Check if there's testing sample */
|
||||
if (nbTestSamples < 1) {
|
||||
DISPLAYLEVEL(1, "Total number of testing samples is %u and is invalid.", nbTestSamples);
|
||||
return 0;
|
||||
}
|
||||
if (nbTrainSamples + nbTestSamples != nbSamples) {
|
||||
/* Check if nbTrainSamples plus nbTestSamples add up to nbSamples when splitPoint is less than 1*/
|
||||
if (nbTrainSamples + nbTestSamples != nbSamples && splitPoint < 1.0) {
|
||||
DISPLAYLEVEL(1, "nbTrainSamples plus nbTestSamples don't add up to nbSamples");
|
||||
return 0;
|
||||
}
|
||||
@ -920,7 +921,8 @@ static void COVER_tryParameters(void *opaque) {
|
||||
/* Allocate dst with enough space to compress the maximum sized sample */
|
||||
{
|
||||
size_t maxSampleSize = 0;
|
||||
for (i = ctx->nbTrainSamples; i < ctx->nbSamples; ++i) {
|
||||
i = parameters.splitPoint < 1.0 ? ctx->nbTrainSamples : 0;
|
||||
for (; i < ctx->nbSamples; ++i) {
|
||||
maxSampleSize = MAX(ctx->samplesSizes[i], maxSampleSize);
|
||||
}
|
||||
dstCapacity = ZSTD_compressBound(maxSampleSize);
|
||||
@ -973,7 +975,7 @@ ZDICTLIB_API size_t ZDICT_optimizeTrainFromBuffer_cover(
|
||||
/* constants */
|
||||
const unsigned nbThreads = parameters->nbThreads;
|
||||
const double splitPoint =
|
||||
parameters->splitPoint <= 0.0 ? DEFAULT_SPLITPOINT : parameters->splitPoint;
|
||||
(parameters->splitPoint <= 0.0 || parameters->splitPoint > 1.0) ? DEFAULT_SPLITPOINT : parameters->splitPoint;
|
||||
const unsigned kMinD = parameters->d == 0 ? 6 : parameters->d;
|
||||
const unsigned kMaxD = parameters->d == 0 ? 8 : parameters->d;
|
||||
const unsigned kMinK = parameters->k == 0 ? 50 : parameters->k;
|
||||
@ -991,7 +993,7 @@ ZDICTLIB_API size_t ZDICT_optimizeTrainFromBuffer_cover(
|
||||
POOL_ctx *pool = NULL;
|
||||
|
||||
/* Checks */
|
||||
if (splitPoint <= 0 || splitPoint >= 1) {
|
||||
if (splitPoint <= 0 || splitPoint > 1) {
|
||||
LOCALDISPLAYLEVEL(displayLevel, 1, "Incorrect parameters\n");
|
||||
return ERROR(GENERIC);
|
||||
}
|
||||
|
@ -425,6 +425,8 @@ rm tmp*
|
||||
$ECHO "- Compare size of dictionary from 90% training samples with 80% training samples"
|
||||
$ZSTD --train-cover=split=90 -r *.c ../programs/*.c
|
||||
$ZSTD --train-cover=split=80 -r *.c ../programs/*.c
|
||||
$ECHO "- Create dictionary using all samples for both training and testing"
|
||||
$ZSTD --train-cover=split=100 -r *.c ../programs/*.c
|
||||
|
||||
$ECHO "\n===> legacy dictionary builder "
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user