Allow splitPoint==1.0 (using all samples for both training and testing)

This commit is contained in:
Jennifer Liu 2018-07-05 10:38:45 -07:00
parent 0881184c89
commit a085d1aae1
2 changed files with 14 additions and 10 deletions

View File

@ -543,10 +543,10 @@ static int COVER_ctx_init(COVER_ctx_t *ctx, const void *samplesBuffer,
const unsigned kFirst = 0;
const size_t totalSamplesSize = COVER_sum(samplesSizes, kFirst, nbSamples);
/* Split samples into testing and training sets */
const unsigned nbTrainSamples = (unsigned)((double)nbSamples * splitPoint);
const unsigned nbTestSamples = nbSamples - nbTrainSamples;
const size_t trainingSamplesSize = COVER_sum(samplesSizes, kFirst, nbTrainSamples);
const size_t testSamplesSize = COVER_sum(samplesSizes, nbTrainSamples, nbSamples);
const unsigned nbTrainSamples = splitPoint < 1.0 ? (unsigned)((double)nbSamples * splitPoint) : nbSamples;
const unsigned nbTestSamples = splitPoint < 1.0 ? nbSamples - nbTrainSamples : nbSamples;
const size_t trainingSamplesSize = splitPoint < 1.0 ? COVER_sum(samplesSizes, kFirst, nbTrainSamples) : totalSamplesSize;
const size_t testSamplesSize = splitPoint < 1.0 ? COVER_sum(samplesSizes, nbTrainSamples, nbSamples) : totalSamplesSize;
/* Checks */
if (totalSamplesSize < MAX(d, sizeof(U64)) ||
totalSamplesSize >= (size_t)COVER_MAX_SAMPLES_SIZE) {
@ -559,12 +559,13 @@ static int COVER_ctx_init(COVER_ctx_t *ctx, const void *samplesBuffer,
DISPLAYLEVEL(1, "Total number of training samples is %u and is invalid.", nbTrainSamples);
return 0;
}
/* Check if there's testing sample when splitPoint is not 1.0 */
if (nbTestSamples < 1 && splitPoint < 1.0) {
/* Check if there's testing sample */
if (nbTestSamples < 1) {
DISPLAYLEVEL(1, "Total number of testing samples is %u and is invalid.", nbTestSamples);
return 0;
}
if (nbTrainSamples + nbTestSamples != nbSamples) {
/* Check if nbTrainSamples plus nbTestSamples add up to nbSamples when splitPoint is less than 1*/
if (nbTrainSamples + nbTestSamples != nbSamples && splitPoint < 1.0) {
DISPLAYLEVEL(1, "nbTrainSamples plus nbTestSamples don't add up to nbSamples");
return 0;
}
@ -920,7 +921,8 @@ static void COVER_tryParameters(void *opaque) {
/* Allocate dst with enough space to compress the maximum sized sample */
{
size_t maxSampleSize = 0;
for (i = ctx->nbTrainSamples; i < ctx->nbSamples; ++i) {
i = parameters.splitPoint < 1.0 ? ctx->nbTrainSamples : 0;
for (; i < ctx->nbSamples; ++i) {
maxSampleSize = MAX(ctx->samplesSizes[i], maxSampleSize);
}
dstCapacity = ZSTD_compressBound(maxSampleSize);
@ -973,7 +975,7 @@ ZDICTLIB_API size_t ZDICT_optimizeTrainFromBuffer_cover(
/* constants */
const unsigned nbThreads = parameters->nbThreads;
const double splitPoint =
parameters->splitPoint <= 0.0 ? DEFAULT_SPLITPOINT : parameters->splitPoint;
(parameters->splitPoint <= 0.0 || parameters->splitPoint > 1.0) ? DEFAULT_SPLITPOINT : parameters->splitPoint;
const unsigned kMinD = parameters->d == 0 ? 6 : parameters->d;
const unsigned kMaxD = parameters->d == 0 ? 8 : parameters->d;
const unsigned kMinK = parameters->k == 0 ? 50 : parameters->k;
@ -991,7 +993,7 @@ ZDICTLIB_API size_t ZDICT_optimizeTrainFromBuffer_cover(
POOL_ctx *pool = NULL;
/* Checks */
if (splitPoint <= 0 || splitPoint >= 1) {
if (splitPoint <= 0 || splitPoint > 1) {
LOCALDISPLAYLEVEL(displayLevel, 1, "Incorrect parameters\n");
return ERROR(GENERIC);
}

View File

@ -425,6 +425,8 @@ rm tmp*
$ECHO "- Compare size of dictionary from 90% training samples with 80% training samples"
$ZSTD --train-cover=split=90 -r *.c ../programs/*.c
$ZSTD --train-cover=split=80 -r *.c ../programs/*.c
$ECHO "- Create dictionary using all samples for both training and testing"
$ZSTD --train-cover=split=100 -r *.c ../programs/*.c
$ECHO "\n===> legacy dictionary builder "