From 634add9d10a9460fd1c33c8b0e49a96bc99f76fa Mon Sep 17 00:00:00 2001 From: Kai Zhao Date: Tue, 10 Dec 2024 08:49:38 -0500 Subject: [PATCH] update logic for dispatcher --- include/SZ3/api/impl/SZDispatcher.hpp | 34 ++++++++++++--------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/include/SZ3/api/impl/SZDispatcher.hpp b/include/SZ3/api/impl/SZDispatcher.hpp index 87b4d1ff..4b52fbc4 100644 --- a/include/SZ3/api/impl/SZDispatcher.hpp +++ b/include/SZ3/api/impl/SZDispatcher.hpp @@ -51,26 +51,22 @@ size_t SZ_compress_dispatcher(Config &conf, const T *data, uchar *cmpData, size_ if (conf.cmprAlgo == ALGO_LOSSLESS || !isCmpCapSufficient) { conf.cmprAlgo = ALGO_LOSSLESS; auto zstd = Lossless_zstd(); - cmpSize = zstd.compress(reinterpret_cast(data), conf.num * sizeof(T), cmpData, cmpCap); - } else { - // if lossy compression ratio < 3, test if lossless only mode has a better ratio than lossy - if (conf.num * sizeof(T) / 1.0 / cmpSize < 3) { - auto zstd = Lossless_zstd(); - auto zstdCmpCap = ZSTD_compressBound(conf.num * sizeof(T)); - auto zstdCmpData = static_cast(malloc(cmpCap)); - size_t zstdCmpSize = - zstd.compress(reinterpret_cast(data), conf.num * sizeof(T), zstdCmpData, zstdCmpCap); - if (zstdCmpSize < cmpSize) { - conf.cmprAlgo = ALGO_LOSSLESS; - if (zstdCmpSize > cmpCap) { - fprintf(stderr, "%s\n", SZ_ERROR_COMP_BUFFER_NOT_LARGE_ENOUGH); - throw std::length_error(SZ_ERROR_COMP_BUFFER_NOT_LARGE_ENOUGH); - } - memcpy(cmpData, zstdCmpData, zstdCmpSize); - cmpSize = zstdCmpSize; - } - free(zstdCmpData); + return zstd.compress(reinterpret_cast(data), conf.num * sizeof(T), cmpData, cmpCap); + } + + // if lossy compression ratio < 3, test if lossless only mode has a better ratio than lossy + if (conf.num * sizeof(T) / 1.0 / cmpSize < 3) { + auto zstd = Lossless_zstd(); + auto zstdCmpCap = ZSTD_compressBound(conf.num * sizeof(T)); + auto zstdCmpData = static_cast(malloc(zstdCmpCap)); + size_t zstdCmpSize = + zstd.compress(reinterpret_cast(data), conf.num * sizeof(T), zstdCmpData, zstdCmpCap); + if (zstdCmpSize < cmpSize && zstdCmpSize <= cmpCap) { + conf.cmprAlgo = ALGO_LOSSLESS; + memcpy(cmpData, zstdCmpData, zstdCmpSize); + cmpSize = zstdCmpSize; } + free(zstdCmpData); } return cmpSize; }