/*** Copyright (C) 2022 J Reece Wilson (a/k/a "Reece"). All rights reserved. File: ZSTDCompressor.hpp Date: 2022-2-15 Author: Reece ***/ #pragma once #include "zstd.h" namespace Aurora::Compression { struct ZSTDDeflate : public BaseStream { CompressionInfo meta; ZSTDDeflate(const CompressionInfo &meta) : meta(meta), BaseStream(meta.internalStreamSize) {} ~ZSTDDeflate() { if (auto dctx = AuExchange(cctx_, {})) { ZSTD_freeCCtx(dctx); } } bool Init(const AuSPtr &reader) override { size_t ret; this->reader_ = reader; this->cctx_ = ZSTD_createCCtx(); if (!this->cctx_) { SysPushErrorGen("Couldn't create compressor"); } ret = ZSTD_CCtx_setParameter(this->cctx_, ZSTD_c_compressionLevel, meta.compressionLevel); if (ZSTD_isError(ret)) { SysPushErrorGen("Invalid compression level"); return false; } ret = ZSTD_CCtx_setParameter(this->cctx_, ZSTD_c_checksumFlag, 1); if (ZSTD_isError(ret)) { SysPushErrorGen("Invalid option"); return false; } ret = ZSTD_CCtx_setParameter(this->cctx_, ZSTD_c_nbWorkers, AuMax(meta.threads, AuUInt8(1u))); if (ZSTD_isError(ret)) { SysPushErrorGen(); return false; } if (!this->_outbuffer) { SysPushErrorMem(); return false; } this->curPtr_ = this->din_; this->availIn_ = 0; SetArray(this->din_); return true; } AuStreamReadWrittenPair_t Ingest_s(AuUInt32 input) override { AuUInt32 length = AuUInt32(ZSTD_DStreamInSize()); AuUInt32 outFrameLength = AuUInt32(ZSTD_DStreamOutSize()); AuUInt32 done {}, read {}; while (read < input) { read += IngestForInPointer(this->reader_, this->curPtr_, this->availIn_, input - read); if (!this->availIn_) { return {read, done}; } this->input_ = ZSTD_inBuffer {this->curPtr_, this->availIn_, 0}; while (this->input_.pos < this->input_.size) { ZSTD_outBuffer output = {this->dout_, outFrameLength, 0}; auto ret = ZSTD_compressStream(this->cctx_, &output, &this->input_); if (ZSTD_isError(ret)) { SysPushErrorIO("Compression error: {}", ZSTD_getErrorName(ret)); this->availIn_ -= AuUInt32(output.pos); return AuMakePair(read, 0); } if (!output.pos) { continue; } done += AuUInt32(output.pos); if (!Write(reinterpret_cast(this->dout_), AuUInt32(output.pos))) { return AuMakePair(read, 0); } } this->availIn_ -= AuUInt32(this->input_.pos); } return {read, done}; } bool Flush() override { return RunFlush(ZSTD_e_continue); } bool Finish() override { return RunFlush(ZSTD_e_end); } bool RunFlush(ZSTD_EndDirective type) { AuUInt32 length = AuUInt32(ZSTD_DStreamInSize()); AuUInt32 outFrameLength = AuUInt32(ZSTD_DStreamOutSize()); if (!this->availIn_) { ZSTD_outBuffer output = { this->dout_, outFrameLength, 0 }; ZSTD_inBuffer input = { NULL, 0, 0 }; auto ret = ZSTD_compressStream2(this->cctx_, &output, &input, type); if (ZSTD_isError(ret)) { SysPushErrorIO("Compression error: {}", ZSTD_getErrorName(ret)); this->availIn_ -= AuUInt32(output.pos); return {}; } if (!Write(reinterpret_cast(output.dst), AuUInt32(output.pos))) { return false; } return true; } this->input_ = ZSTD_inBuffer {this->curPtr_, this->availIn_, 0}; while (this->input_.pos < this->input_.size) { ZSTD_outBuffer output = {this->dout_, outFrameLength, 0}; auto ret = ZSTD_compressStream(this->cctx_, &output, &this->input_); if (ZSTD_isError(ret)) { SysPushErrorIO("Compression error: {}", ZSTD_getErrorName(ret)); this->availIn_ -= AuUInt32(output.pos); return {}; } if (!Write(reinterpret_cast(this->dout_), AuUInt32(output.pos))) { return false; } } this->availIn_ -= AuUInt32(this->input_.pos); return true; } private: AuSPtr reader_; ZSTD_CCtx *cctx_ {}; char din_[ZSTD_BLOCKSIZE_MAX]; char dout_[ZSTD_BLOCKSIZE_MAX + 3 /*ZSTD_BLOCKHEADERSIZE*/]; char *curPtr_ {}; AuUInt32 availIn_ {}; ZSTD_inBuffer input_; }; }