/*** Copyright (C) 2022 J Reece Wilson (a/k/a "Reece"). All rights reserved. File: ZSTDCompressor.hpp Date: 2022-2-15 Author: Reece ***/ #pragma once #include "zstd.h" namespace Aurora::Compression { struct ZSTDDeflate : BaseStream { CompressInfo meta; ZSTDDeflate(const CompressInfo &meta) : meta(meta), BaseStream(meta.uInternalStreamSize) {} ~ZSTDDeflate() { if (auto cctx = AuExchange(this->cctx_, {})) { ZSTD_freeCCtx(cctx); } } bool Init(const AuSPtr &pReader) override { AuUInt uRet; if (!this->IsValid()) { SysPushErrorMem(); return false; } this->pReader_ = pReader; this->cctx_ = ZSTD_createCCtx(); if (!this->cctx_) { SysPushErrorGen("Couldn't create compressor"); return false; } uRet = ZSTD_CCtx_setParameter(this->cctx_, ZSTD_c_compressionLevel, meta.uCompressionLevel); if (ZSTD_isError(uRet)) { SysPushErrorArg("Invalid compression level"); this->SetLastError(uRet, ZSTD_getErrorName(uRet)); return false; } uRet = ZSTD_CCtx_setParameter(this->cctx_, ZSTD_c_checksumFlag, meta.bErrorCheck ? 1 : 0); if (ZSTD_isError(uRet)) { SysPushErrorArg("Invalid option"); this->SetLastError(uRet, ZSTD_getErrorName(uRet)); return false; } if (meta.uOptQuality) { uRet = ZSTD_CCtx_setParameter(this->cctx_, ZSTD_c_strategy, meta.uOptQuality.value()); if (ZSTD_isError(uRet)) { SysPushErrorArg("ZSTD_c_strategy"); this->SetLastError(uRet, ZSTD_getErrorName(uRet)); return false; } } if (auto uThreadsMin = AuHwInfo::GetCPUInfo().uThreads) { meta.uThreads = AuMin(uThreadsMin, meta.uThreads); } if (meta.uThreads > 1) { uRet = ZSTD_CCtx_setParameter(this->cctx_, ZSTD_c_nbWorkers, AuMax(meta.uThreads, AuUInt8(1u))); if (ZSTD_isError(uRet)) { this->SetLastError(uRet, ZSTD_getErrorName(uRet)); return false; } } for (const auto &[a, b] : meta.options) { uRet = ZSTD_CCtx_setParameter(this->cctx_, (ZSTD_cParameter)a, b); if (ZSTD_isError(uRet)) { SysPushErrorArg("Compressor argument assignment {} = {}", a, b); this->SetLastError(uRet, ZSTD_getErrorName(uRet)); return false; } } this->pIterator_ = this->din_; this->uAvailableIn_ = 0; this->SetArray(this->din_); this->SetOutArray(this->dout_); return true; } AuStreamReadWrittenPair_t Ingest_s(AuUInt32 input) override { AuUInt32 uLength = AuUInt32(ZSTD_DStreamInSize()); AuUInt32 uOutFrameLength = AuUInt32(ZSTD_DStreamOutSize()); AuUInt32 done {}, read {}; if (!this->pReader_) { return {}; } while (read < input) { read += IngestForInPointer(this->pReader_, this->pIterator_, this->uAvailableIn_, input - read, this); if (!this->uAvailableIn_) { Flush(); return {read, done}; } this->input_ = ZSTD_inBuffer {this->pIterator_, this->uAvailableIn_, 0}; size_t uRet {}; bool bLastFrame = input - read == 0; while ((this->input_.pos != this->input_.size) || (/*bLastFrame && uRet*/ false)) { auto [pMainDOut, uMainDOutLength] = this->GetDOutPair(); ZSTD_outBuffer output = { pMainDOut, uMainDOutLength, 0 }; ZSTD_EndDirective mode = ZSTD_e_continue;// bLastFrame ? ZSTD_e_flush : ZSTD_e_continue; uRet = ZSTD_compressStream2(this->cctx_, &output, &this->input_, mode); if (ZSTD_isError(uRet)) { this->SetLastError(uRet, ZSTD_getErrorName(uRet)); this->uAvailableIn_ -= AuUInt32(output.pos); this->pReader_.reset(); return AuMakePair(read, 0); } this->bContd = true; this->bLastFrameHasNotFinished_ = bLastFrame && uRet; if (!output.pos) { continue; } done += AuUInt32(output.pos); if (!Write2(reinterpret_cast(pMainDOut), AuUInt32(output.pos))) { this->pReader_.reset(); this->SetLastError(0x69, "OOM"); return AuMakePair(read, 0); } } this->pIterator_ += this->input_.pos; this->uAvailableIn_ -= AuUInt32(this->input_.pos); } return {read, done}; } bool Flush() override { AuUInt32 uLength = AuUInt32(ZSTD_DStreamInSize()); AuUInt32 uOutFrameLength = AuUInt32(ZSTD_DStreamOutSize()); this->input_ = ZSTD_inBuffer { this->pIterator_, this->uAvailableIn_, 0 }; AuUInt uRet {}; while ((this->input_.pos < this->input_.size) || (this->bLastFrameHasNotFinished_) || (uRet)) { auto [pMainDOut, uMainDOutLength] = this->GetDOutPair(); ZSTD_outBuffer output = { pMainDOut, uMainDOutLength, 0 }; this->bLastFrameHasNotFinished_ = false; uRet = ZSTD_compressStream2(this->cctx_, &output, &this->input_, ZSTD_e_flush); if (ZSTD_isError(uRet)) { this->SetLastError(uRet, ZSTD_getErrorName(uRet)); this->uAvailableIn_ -= AuUInt32(output.pos); return {}; } if (!output.pos) { continue; } if (!Write2(reinterpret_cast(pMainDOut), AuUInt32(output.pos))) { SysPushErrorIO("Compression Out of Overhead"); this->pReader_.reset(); return false; } } this->pIterator_ += this->input_.pos; this->uAvailableIn_ -= this->input_.pos; this->bContd = false; return true;// RunFlush(ZSTD_e_continue); } bool Finish() override { AuUInt32 uOutFrameLength = AuUInt32(ZSTD_DStreamOutSize()); ZSTD_outBuffer output = { this->dout_, uOutFrameLength, 0 }; if (meta.uThreads == 1) { Flush(); } else { if (!this->bContd) { return true; } } size_t uRet; do { uRet = ZSTD_endStream(this->cctx_, &output); if (ZSTD_isError(uRet)) { this->SetLastError(uRet, ZSTD_getErrorName(uRet)); return {}; } if (!Write(reinterpret_cast(this->dout_), AuUInt32(output.pos))) { SysPushErrorIO("Compression Out of Overhead"); return false; } } while (uRet); this->bContd = false; return true; } private: AuSPtr pReader_; ZSTD_CCtx *cctx_ {}; char din_[ZSTD_BLOCKSIZE_MAX]; char dout_[ZSTD_COMPRESSBOUND(ZSTD_BLOCKSIZE_MAX) + 3 +/*ZSTD_BLOCKHEADERSIZE*/ + 4 /*32bit hash*/]; char *pIterator_ {}; AuUInt32 uAvailableIn_ {}; ZSTD_inBuffer input_ {}; bool bLastFrameHasNotFinished_ {}; bool bContd {}; }; }