diff --git a/.github/workflows/static-pages.yml b/.github/workflows/static-pages.yml index 07671393..04f2f249 100644 --- a/.github/workflows/static-pages.yml +++ b/.github/workflows/static-pages.yml @@ -20,7 +20,7 @@ jobs: uses: actions/checkout@v4 - name: Install dependencies run: | - pip install -U torch fvcore einops psutil torchvision sphinx shibuya sphinx-autodoc-typehints sphinx-copybutton + pip install -U torch fvcore einops psutil torchvision sphinx shibuya sphinx-autodoc-typehints sphinx-copybutton sphinx-design - name: Sphinx build run: | PYTORCH_JIT=0 sphinx-build docs/source/ docs/build diff --git a/README.md b/README.md index 420bf31f..9903e42f 100644 --- a/README.md +++ b/README.md @@ -27,44 +27,48 @@ Explore the docs ยป

- Decode provided bitstreams + What's new in 3.4? ยท - Compression performance + Decode some bitstreams + ยท + Coding performance

Cool-chic (pronounced /kul สƒik/ as in French ๐Ÿฅ–๐Ÿง€๐Ÿท) is -is a low-complexity neural image codec based on overfitting. It offers image coding -performance competitive with *H.266/VVC for 2000 multiplications* per decoded +a low-complexity neural image codec based on overfitting. It offers image coding +performance competitive with **H.266/VVC for 1000 multiplications** per decoded pixel. + + +
+ +#### ๐Ÿ† **Coding performance**: Cool-chic compresses images as well as H.266/VVC ๐Ÿ† +#### ๐Ÿš€ **Fast CPU-only decoder**: Decode a 1280x720 image in 100 ms on CPU with our decoder written in C ๐Ÿš€ +#### ๐Ÿ”ฅ **Fixed-point decoder**: Fixed-point arithmetic at the decoder for bit-exact results on different hardwares ๐Ÿ”ฅ +#### ๐Ÿ–ผ๏ธ **I/O format**: Encode PNG, PPM and YUV file with a bitdepth of 8 to 16 bits ๐Ÿ–ผ๏ธ + +
+ # -### Current & future features - -- Coding performance - - โœ… On par with VVC for image coding - - โŒ Upcoming improved Cool-chic video -- I/O format - - โœ… PPM for 8-bit RGB images, yuv420 8-bit and 10-bit - - โŒ yuv444 - - โŒ Additional output precisions (12, 14 and 16-bit) - - โŒ Output PNG instead of PPM for the decoded images -- Decoder - - โœ… Fast C implementation - - โœ… Integer computation for the ARM - - โœ… Complete integerization - - โœ… Decrease memory footprint & faster decoding - -### Latest release: ๐Ÿš€ __Cool-chic 3.3: An even faster decoder!__ ๐Ÿš€ - -- Make the **CPU-only decoder** even faster. - - Decode a 720p image in **100 ms**, **2x faster** than Cool-chic 3.2 - - Full **integerization** of the decoder for replicability - - Reduce decoder **memory footprint** - - **Optimized** implementation of 3x3 convolutions & fusion of successive 1x1 convolutions + +
+ +### Latest release: ๐ŸŽ‰ __Cool-chic 3.4: 30% less complex!__ ๐ŸŽ‰ + +
+ +- New and improved latent **upsampling module** + - Leverage symmetric and separable convolution kernels to reduce complexity & parameters count + - Learn two filters per upsampling step instead of one for all upsampling steps +- 1% to 5% **rate reduction** for the same image quality +- **30% complexity reduction** using a smaller Auto-Regressive Module + - From 2000 MAC / decoded pixel to 1300 MAC / decoded pixel + - **10% faster** decoding speed Check-out the [release history](https://github.com/Orange-OpenSource/Cool-Chic/releases) to see previous versions of Cool-chic. @@ -97,42 +101,88 @@ You're good to go! The Cool-chic page provides [comprehensive rate-distortion results and compressed bitstreams](https://orange-opensource.github.io/Cool-Chic/getting_started/results.html) allowing to reproduce the results inside the ```results/``` directory. -| Dataset | Vs. Cool-chic 3.1 | Vs. [_C3_, Kim et al.](https://arxiv.org/abs/2312.02753) | Vs. HEVC (HM 16.20) | Vs. VVC (VTM 19.1) | Avg decoder MAC / pixel | Avg decoding time [ms] | -|------------------|----------------------------------------------|----------------------------------------------------------|----------------------------------------------|----------------------------------------------|----------------------------------|----------------------------------| -| kodak | - 1.9 % | - 3.4 % | - 16.4 % | + 4.5 % | 1880 | 96 | -| clic20-pro-valid | - 4.2 % | - 1.0 % | - 24.8 % | - 1.9 % | 1907 | 364 | -| jvet class B | - 7.2 % | / | - 10.8 % | + 19.5 % | 1803 | 260 | + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
BD-rate of Cool-chic 3.4 vs. [%]Avg. decoder complexity
ChengELICCool-chic 3.3C3HEVC (HM 16)VVC (VTM 19)MAC / pixelCPU Time [ms]
kodak -4.2 % +7.5 % -0.9 % -4.3 % -17.2 % +3.4 % 130374
clic20-pro-valid -13.2 % -0.2 % -0.3 % -1.3 % -25.1 % -2.3 %
1357354
jvet //-0.2 %/-18.3 %+18.6 %1249143
+ +
+ +_Decoding time are obtained on a single CPU core of an an AMD EPYC 7282 16-Core Processor_ + +_PSNR is computed in the RGB domain for kodak and CLIC20, in the YUV420 domain for jvet_ + ### Kodak
- - Kodak rd results - Kodak performance complexity - + Kodak rd results

### CLIC20 Pro Valid
- - CLIC20 rd results - CLIC20 performance complexity - + CLIC20 rd results

### JVET Class B
- - JVET class B rd results - JVET class B performance complexity - + JVET class B rd results

+
+ # Thanks Special thanks go to Hyunjik Kim, Matthias Bauer, Lucas Theis, Jonathan Richard Schwarz and Emilien Dupont for their great work enhancing Cool-chic: [_C3: High-performance and low-complexity neural compression from a single image or video_, Kim et al.](https://arxiv.org/abs/2312.02753) @@ -154,7 +204,6 @@ Special thanks go to Hyunjik Kim, Matthias Bauer, Lucas Theis, Jonathan Richard
-
# diff --git a/cfg/dec/hop.cfg b/cfg/dec/hop.cfg index 7a0731f2..70b39f38 100644 --- a/cfg/dec/hop.cfg +++ b/cfg/dec/hop.cfg @@ -1,6 +1,5 @@ -arm = 24,2 -layers_synthesis = 40-1-linear-relu,3-1-linear-none,3-3-residual-relu,3-3-residual-none +arm = 16,2 +layers_synthesis = 48-1-linear-relu,X-1-linear-none,X-3-residual-relu,X-3-residual-none n_ft_per_res = 1,1,1,1,1,1,1 -upsampling_kernel_size = 8 -static_upsampling_kernel = False - +ups_k_size = 8 +ups_preconcat_k_size = 7 \ No newline at end of file diff --git a/cfg/dec/lop.cfg b/cfg/dec/lop.cfg index 1252eaa5..adca16e9 100644 --- a/cfg/dec/lop.cfg +++ b/cfg/dec/lop.cfg @@ -1,5 +1,5 @@ arm = 8,2 -layers_synthesis = 16-1-linear-relu,3-1-linear-none,3-3-residual-relu,3-3-residual-none +layers_synthesis = 16-1-linear-relu,X-1-linear-none,X-3-residual-relu,X-3-residual-none n_ft_per_res = 1,1,1,1,1,1,1 -upsampling_kernel_size = 4 -static_upsampling_kernel = False +ups_k_size = 8 +ups_preconcat_k_size = 7 diff --git a/cfg/dec/mop.cfg b/cfg/dec/mop.cfg index ff6e61de..a124b0cd 100644 --- a/cfg/dec/mop.cfg +++ b/cfg/dec/mop.cfg @@ -1,5 +1,5 @@ arm = 16,2 -layers_synthesis = 16-1-linear-relu,3-1-linear-none,3-3-residual-relu,3-3-residual-none +layers_synthesis = 16-1-linear-relu,X-1-linear-none,X-3-residual-relu,X-3-residual-none n_ft_per_res = 1,1,1,1,1,1,1 -upsampling_kernel_size = 4 -static_upsampling_kernel = False +ups_k_size = 8 +ups_preconcat_k_size = 7 \ No newline at end of file diff --git a/cfg/dec/vlop.cfg b/cfg/dec/vlop.cfg index 616e5919..b473de55 100644 --- a/cfg/dec/vlop.cfg +++ b/cfg/dec/vlop.cfg @@ -1,5 +1,5 @@ arm = 8,1 -layers_synthesis = 8-1-linear-relu,3-1-linear-none,3-3-residual-none +layers_synthesis = 8-1-linear-relu,X-1-linear-none,X-3-residual-none n_ft_per_res = 1,1,1,1,1,1,1 -upsampling_kernel_size = 4 -static_upsampling_kernel = False +ups_k_size = 8 +ups_preconcat_k_size = 7 \ No newline at end of file diff --git a/coolchic/cpp/CMakeLists.txt b/coolchic/cpp/CMakeLists.txt new file mode 100644 index 00000000..0d4e2e05 --- /dev/null +++ b/coolchic/cpp/CMakeLists.txt @@ -0,0 +1,57 @@ + +add_executable(ccdec) + +set(CMAKE_INSTALL_PREFIX ${PROJECT_SOURCE_DIR}) + +if(WIN32) + + message(STATUS "[ERROR] Cool-chic decoder not yet implemented for Windows...") + +# Check Apple first, then UNIX (Apple + Linux) so that if we enter the UNIX if +# it means that we're on Linux. +elseif(APPLE) + + if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + + # Changes when compiling for arm64 Apple Mac: + # - Remove all *_avx2.cpp and *_avx512.cpp files + # - Remove the -mfa from the compilation options + # - Remove all the target_link_options... what is this for?? + # + # It only compiles using g++/gcc, not clang which defaults to + # an older version apparently? + # cmake -DCMAKE_C_COMPILER=/opt/homebrew/bin/gcc-13 -DCMAKE_CXX_COMPILER=/opt/homebrew/bin/g++-13 .. + + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -g -Wall -Winline") + + target_sources(ccdec PRIVATE ccdecapi.cpp cc-bitstream.cpp cc-contexts.cpp arm_cpu.cpp syn_cpu.cpp BitStream.cpp TDecBinCoderCABAC.cpp Contexts.cpp) + + else() + + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -g -mfma -Winline") + + # For now, we compile *_avx2.cpp and files, but they are + # excluded from ccdec.cpp using quick & dirty #ifdef __APPLE__ + target_sources(ccdec PRIVATE ccdecapi.cpp cc-bitstream.cpp cc-contexts.cpp arm_cpu.cpp arm_avx2.cpp ups_cpu.cpp ups_avx2.cpp syn_cpu.cpp syn_avx2.cpp BitStream.cpp TDecBinCoderCABAC.cpp Contexts.cpp) + + set_source_files_properties(arm_avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx2") + set_source_files_properties(ups_avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx2") + set_source_files_properties(syn_avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx2") + + endif() + +elseif(UNIX) + + message(STATUS "Architecture: Linux") + + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -g -mfma -Wall -Winline -DCCDEC_EXE -DCCDECAPI_AVX2_OPTIONAL") + + target_sources(ccdec PRIVATE ccdecapi.cpp cc-bitstream.cpp cc-contexts.cpp cc-frame-decoder.cpp frame-memory.cpp arm_cpu.cpp arm_avx2.cpp ups_cpu.cpp ups_avx2.cpp syn_cpu.cpp syn_avx2.cpp BitStream.cpp TDecBinCoderCABAC.cpp Contexts.cpp) + set(CMAKE_EXE_LINKER_FLAGS "-static") + + set_source_files_properties(arm_avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx2") + set_source_files_properties(ups_avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx2") + set_source_files_properties(syn_avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx2") + +endif() + diff --git a/coolchic/cpp/cc-bitstream.cpp b/coolchic/cpp/cc-bitstream.cpp index aceae136..108a39e5 100644 --- a/coolchic/cpp/cc-bitstream.cpp +++ b/coolchic/cpp/cc-bitstream.cpp @@ -43,7 +43,7 @@ int read_bitdepth(FILE *bs) return (raw < 16) ? 8 : 10; } -bool cc_bs::open(std::string filename) +bool cc_bs::open(std::string filename, int verbosity) { m_filename = filename; m_f = fopen(filename.c_str(), "rb"); @@ -52,10 +52,10 @@ bool cc_bs::open(std::string filename) printf("Cannot open bitstream file %s for reading.\n", filename.c_str()); return false; } - return read_gop_header(); + return read_gop_header(verbosity); } -bool cc_bs::read_gop_header() +bool cc_bs::read_gop_header(int verbosity) { int raw; @@ -63,18 +63,22 @@ bool cc_bs::read_gop_header() m_gop_header.img_h = read_int_2(m_f); m_gop_header.img_w = read_int_2(m_f); raw = read_int_1(m_f); - m_gop_header.bitdepth = (raw>>4) == 0 ? 8 : 10; + //m_gop_header.bitdepth = (raw>>4) == 0 ? 8 : 10; + m_gop_header.bitdepth = (raw>>4) + 8; // !!! ASSUMING 8, 9..16 as 1st indicies. m_gop_header.frame_data_type = raw&0xF; m_gop_header.intra_period = read_int_1(m_f); m_gop_header.p_period = read_int_1(m_f); - printf("GOP HEADER:\n"); - printf(" n_bytes_header: %d\n", m_gop_header.n_bytes_header); - printf(" img_size: %d,%d\n", m_gop_header.img_h, m_gop_header.img_w); - printf(" frame_data_type: %d\n", m_gop_header.frame_data_type); - printf(" bitdepth: %d\n", m_gop_header.bitdepth); - printf(" intra_period: %d\n", m_gop_header.intra_period); - printf(" p_period: %d\n", m_gop_header.p_period); + if (verbosity >= 2) + { + printf("GOP HEADER:\n"); + printf(" n_bytes_header: %d\n", m_gop_header.n_bytes_header); + printf(" img_size: %d,%d\n", m_gop_header.img_h, m_gop_header.img_w); + printf(" frame_data_type: %d\n", m_gop_header.frame_data_type); + printf(" bitdepth: %d\n", m_gop_header.bitdepth); + printf(" intra_period: %d\n", m_gop_header.intra_period); + printf(" p_period: %d\n", m_gop_header.p_period); + } return true; } @@ -129,11 +133,11 @@ void print_lqi(char const *name, struct cc_bs_layer_quant_info &lqi) void print_syn(struct cc_bs_syn_layer &syn) { - printf("%d-%d-%s-%s", + printf("out%d-ks%d-%s-%s", syn.n_out_ft, syn.ks, syn.residual ? "residual" : "linear", syn.relu ? "relu" : "none"); } -bool read_frame_header(FILE *bs, struct cc_bs_frame_header &frame_header) +bool read_frame_header(FILE *bs, struct cc_bs_frame_header &frame_header, int verbosity) { int raw; @@ -142,10 +146,15 @@ bool read_frame_header(FILE *bs, struct cc_bs_frame_header &frame_header) raw = read_int_1(bs); frame_header.dim_arm = 8*(raw>>4); frame_header.n_hidden_layers_arm = raw&0xF; + + raw = read_int_1(bs); + frame_header.n_ups_kernel = raw>>4; + frame_header.ups_k_size = raw&0xF; raw = read_int_1(bs); - frame_header.upsampling_kern_size = raw>>1; - frame_header.static_upsampling_kernel = raw&1; + frame_header.n_ups_preconcat_kernel = raw>>4; + frame_header.ups_preconcat_k_size = raw&0xF; + frame_header.n_syn_branches = read_int_1(bs); frame_header.n_syn_layers = read_int_1(bs); frame_header.layers_synthesis.resize(frame_header.n_syn_layers); for (int i = 0; i < frame_header.n_syn_layers; i++) @@ -181,39 +190,45 @@ bool read_frame_header(FILE *bs, struct cc_bs_frame_header &frame_header) for (int i = 0; i < frame_header.latent_n_2d_grid; i++) frame_header.n_bytes_per_latent[i] = read_int_3(bs); - printf("FRAME HEADER\n"); - printf(" n_bytes_header: %d\n", frame_header.n_bytes_header); - printf(" latent_n_resolutions: %d\n", frame_header.n_latent_n_resolutions); - printf(" latent_n_2d_grid: %d\n", frame_header.latent_n_2d_grid); - printf(" n_bytes_per_latent:"); - for (int i = 0; i < frame_header.latent_n_2d_grid; i++) - printf(" %d", frame_header.n_bytes_per_latent[i]); - printf("\n"); - printf(" n_ft_per_latent:"); - for (int i = 0; i < frame_header.n_latent_n_resolutions; i++) - printf(" %d", frame_header.n_ft_per_latent[i]); - printf("\n"); - printf(" n_hidden_layers_arm: %d\n", frame_header.n_hidden_layers_arm); - printf(" dim_arm: %d\n", frame_header.dim_arm); - printf(" upsampling_kernel_size: %d\n", frame_header.upsampling_kern_size); - printf(" static_upsampling_kernel: %d\n", frame_header.static_upsampling_kernel); - printf(" flow_gain: %d\n", frame_header.flow_gain); - printf(" layers_synthesis:"); - for (int i = 0; i < frame_header.n_syn_layers; i++) + if (verbosity >= 2) { - printf(" "); - print_syn(frame_header.layers_synthesis[i]); + printf("FRAME HEADER\n"); + printf(" n_bytes_header: %d\n", frame_header.n_bytes_header); + printf(" display_index: %d\n", frame_header.display_index); + printf(" dim_arm: %d\n", frame_header.dim_arm); + printf(" n_hidden_layers_arm: %d\n", frame_header.n_hidden_layers_arm); + printf(" n_ups_kernel=%d, ups_ks=%d\n", frame_header.n_ups_kernel, frame_header.ups_k_size); + printf(" n_ups_preconcat_kernel=%d, ups_preconcat_ks=%d\n", frame_header.n_ups_preconcat_kernel, frame_header.ups_preconcat_k_size); + printf(" n_syn_branches: %d\n", frame_header.n_syn_branches); + printf(" layers_synthesis:"); + for (int i = 0; i < frame_header.n_syn_layers; i++) + { + printf(" "); + print_syn(frame_header.layers_synthesis[i]); + } + printf("\n"); + + printf(" flow_gain: %d\n", frame_header.flow_gain); + printf(" ac_max_val_nn: %d\n", frame_header.ac_max_val_nn); + printf(" ac_max_val_latent: %d\n", frame_header.ac_max_val_latent); + printf(" hls_sig_blksize: %d\n", frame_header.hls_sig_blksize); + + print_lqi("arm", frame_header.arm_lqi); + print_lqi("ups", frame_header.ups_lqi); + print_lqi("syn", frame_header.syn_lqi); + + printf(" latent_n_resolutions: %d\n", frame_header.n_latent_n_resolutions); + printf(" latent_n_2d_grid: %d\n", frame_header.latent_n_2d_grid); + + printf(" n_ft_per_latent:"); + for (int i = 0; i < frame_header.n_latent_n_resolutions; i++) + printf(" %d", frame_header.n_ft_per_latent[i]); + printf("\n"); + printf(" n_bytes_per_latent:"); + for (int i = 0; i < frame_header.latent_n_2d_grid; i++) + printf(" %d", frame_header.n_bytes_per_latent[i]); + printf("\n"); } - printf("\n"); - - print_lqi("arm", frame_header.arm_lqi); - print_lqi("ups", frame_header.ups_lqi); - print_lqi("syn", frame_header.syn_lqi); - printf(" ac_max_val_nn: %d\n", frame_header.ac_max_val_nn); - printf(" ac_max_val_latent: %d\n", frame_header.ac_max_val_latent); - printf(" hls_sig_blksize: %d\n", frame_header.hls_sig_blksize); - printf(" display_index: %d\n", frame_header.display_index); - return true; } @@ -231,10 +246,10 @@ std::vector get_coded(FILE *bs, int n_bytes) return result; } -struct cc_bs_frame *cc_bs::decode_frame() +struct cc_bs_frame *cc_bs::decode_frame(int verbosity) { cc_bs_frame *result = new cc_bs_frame(); - if (!read_frame_header(m_f, result->m_frame_header)) + if (!read_frame_header(m_f, result->m_frame_header, verbosity)) { delete result; return NULL; @@ -244,16 +259,14 @@ struct cc_bs_frame *cc_bs::decode_frame() result->m_arm_weights_hevc = get_coded(m_f, frame_header.arm_lqi.n_bytes_nn_weight); result->m_arm_biases_hevc = get_coded(m_f, frame_header.arm_lqi.n_bytes_nn_bias); result->m_ups_weights_hevc = get_coded(m_f, frame_header.ups_lqi.n_bytes_nn_weight); - - std::vector dummy_output_bias_ups(frame_header.ups_lqi.n_bytes_nn_bias); - dummy_output_bias_ups = get_coded(m_f, frame_header.ups_lqi.n_bytes_nn_bias); - + result->m_ups_biases_hevc = get_coded(m_f, frame_header.ups_lqi.n_bytes_nn_bias); result->m_syn_weights_hevc = get_coded(m_f, frame_header.syn_lqi.n_bytes_nn_weight); result->m_syn_biases_hevc = get_coded(m_f, frame_header.syn_lqi.n_bytes_nn_bias); - //printf("arm coded w%d b%d bytes w:%02x %02x %02x\n", result->m_arm_weights_hevc.size(), result->m_arm_biases_hevc.size(), result->m_arm_weights_hevc[0], result->m_arm_weights_hevc[1], result->m_arm_weights_hevc[2]); - //printf("ups coded w%d bytes w %02x %02x %02x\n", result->m_ups_weights_hevc.size(), result->m_ups_weights_hevc[0], result->m_ups_weights_hevc[1], result->m_ups_weights_hevc[2]); - //printf("syn coded w%d b%d bytes w %02x %02x %02x\n", result->m_syn_weights_hevc.size(), result->m_syn_biases_hevc.size(), result->m_syn_weights_hevc[0], result->m_syn_weights_hevc[1], result->m_syn_weights_hevc[2]); + //printf("arm coded w%ld b%ld bytes w:%02x %02x %02x\n", result->m_arm_weights_hevc.size(), result->m_arm_biases_hevc.size(), result->m_arm_weights_hevc[0], result->m_arm_weights_hevc[1], result->m_arm_weights_hevc[2]); + //printf("ups coded w%ld b%ld bytes w:%02x %02x %02x\n", result->m_ups_weights_hevc.size(), result->m_ups_biases_hevc.size(), result->m_ups_weights_hevc[0], result->m_ups_weights_hevc[1], result->m_ups_weights_hevc[2]); + //printf("syn coded w%ld b%ld bytes w %02x %02x %02x\n", result->m_syn_weights_hevc.size(), result->m_syn_biases_hevc.size(), result->m_syn_weights_hevc[0], result->m_syn_weights_hevc[1], result->m_syn_weights_hevc[2]); + //fflush(stdout); for (int i = 0; i < frame_header.n_latent_n_resolutions; i++) result->m_latents_hevc.emplace_back(get_coded(m_f, frame_header.n_bytes_per_latent[i])); diff --git a/coolchic/cpp/cc-bitstream.h b/coolchic/cpp/cc-bitstream.h index 82de2b86..712a2ac5 100644 --- a/coolchic/cpp/cc-bitstream.h +++ b/coolchic/cpp/cc-bitstream.h @@ -50,9 +50,14 @@ struct cc_bs_frame_header std::vector n_ft_per_latent; int n_hidden_layers_arm; int dim_arm; - int upsampling_kern_size; - int static_upsampling_kernel; + + int n_ups_kernel; + int ups_k_size; + int n_ups_preconcat_kernel; + int ups_preconcat_k_size; + int flow_gain; + int n_syn_branches; int n_syn_layers; std::vector layers_synthesis; struct cc_bs_layer_quant_info arm_lqi; @@ -66,16 +71,14 @@ struct cc_bs_frame_header int hls_sig_blksize; }; -//bool read_gop_header(FILE *bitstream, struct bitstream_gop_header &gop_header); -//bool read_frame_header(FILE *bitstream, struct bitstream_frame_header &frame_header); - struct cc_bs_frame { struct cc_bs_frame_header m_frame_header; std::vector m_arm_weights_hevc; std::vector m_arm_biases_hevc; - std::vector m_ups_weights_hevc; - std::vector m_syn_weights_hevc; - std::vector m_syn_biases_hevc; + std::vector m_ups_weights_hevc; // both lb and hb + std::vector m_ups_biases_hevc; // both lb and hb + std::vector m_syn_weights_hevc; // all branches + std::vector m_syn_biases_hevc; // all branches std::vector> m_latents_hevc; }; @@ -83,10 +86,10 @@ class cc_bs { public: cc_bs() { m_f = NULL; } ~cc_bs() { if (m_f != NULL) fclose(m_f); } - bool open(std::string filename); - struct cc_bs_frame *decode_frame(); + bool open(std::string filename, int verbosity = 0); + struct cc_bs_frame *decode_frame(int verbosity = 0); private: - bool read_gop_header(); + bool read_gop_header(int verbosity = 0); public: std::string m_filename; diff --git a/coolchic/cpp/cc-frame-decoder.cpp b/coolchic/cpp/cc-frame-decoder.cpp index d19bc353..5190d2f5 100644 --- a/coolchic/cpp/cc-frame-decoder.cpp +++ b/coolchic/cpp/cc-frame-decoder.cpp @@ -1,4 +1,8 @@ +#ifdef CCDECAPI_AVX2_OPTIONAL +#define CCDECAPI_AVX2 +#endif + #include "common.h" #include "TDecBinCoderCABAC.h" #include "cc-contexts.h" @@ -13,11 +17,13 @@ #endif #include "arm_cpu.h" +#include "ups_cpu.h" #include "syn_cpu.h" extern float time_arm_seconds; extern float time_ups_seconds; extern float time_syn_seconds; +extern float time_blend_seconds; int const Q_STEP_ARM_WEIGHT_SHIFT[] = { 8, // 1.0/(1<<8), @@ -162,6 +168,11 @@ void decode_weights_qi(int32_t *result, TDecBinCABAC *cabac, int count, int n_we // want val << precision >> q_step_shift if (precision > q_step_shift) tmp <<= (precision-q_step_shift); + else if (precision < q_step_shift) + { + printf("decoding weights: qstepshift %d > precision %d\n", q_step_shift, precision); + exit(1); + } result[i] = tmp; } @@ -173,61 +184,18 @@ void decode_weights_qi(weights_biases &result, TDecBinCABAC *cabac, int count, i decode_weights_qi(result.data, cabac, count, n_weights, q_step_shift, precision); } -// applying 4 small kernels, ks here is upsampling_kernel/2 -//void custom_conv_ups_zxz1_cpu(int ks, int32_t *kw, int h_in, int w_in, int stride_in, int32_t *in, int32_t *out, int h_target, int w_target, int stride_out) -void custom_conv_ups_zxz1_cpu(int ks, int32_t *kw, int h_in, int w_in, int stride_in, int32_t *in, int stride_out, int32_t *out, int ups_mul_precision) -{ - int const kstride = 1; - int offs0 = 0; - - for (int y = 0; y < h_in-ks+1; y += kstride, offs0 += stride_in) - { - for (int vf = 0; vf < 2; vf++, out += stride_out-(w_in-ks+1)*2) // vertical filter choice on this line. - { - int offs = offs0; - for (int x = 0; x < w_in-ks+1; x += kstride, offs += kstride) - { - for (int hf = 0; hf < 2; hf++) // horizontal filter choice at this point - { - int32_t sum = 0; - int32_t *k = kw+(vf*2+hf)*(ks*ks); - int offs2 = offs; - for (int yy = 0; yy < ks; yy++, offs2 += stride_in-ks) - for (int xx = 0; xx < ks; xx++) - { - sum += in[offs2++]*(*k++); - } - *out++ = sum >> ups_mul_precision; - } - } - } - } -} - -// upsamplingx2: applies 4 smaller kernels sized [ks/2][ks/2] -// first_ups implies we are processing ARM output (.8) rather than upsampling output (.12). -void custom_upsample_4(int ks, int32_t *weightsx4xpose, int h_in, int w_in, int stride_in, int32_t *in, int stride_target, int32_t *out, bool first_ups) +// We take a kernel size ks, but only read (ks+1)/2 weights. These weights are mirrored to produce the full kernel. +void decode_upsweights_qi(weights_biases &result, TDecBinCABAC *cabac, int count, int ks, int q_step_shift, int precision) { -#ifdef CCDECAPI_AVX2 - if (ks == 8) - { - if (first_ups) - ups_4x4x4_fromarm_avx2(ks/2, weightsx4xpose, h_in, w_in, stride_in, in, stride_target, out); - else - ups_4x4x4_fromups_avx2(ks/2, weightsx4xpose, h_in, w_in, stride_in, in, stride_target, out); - } - else if (ks == 4) + int nw = (ks+1)/2; // number of weights to read. + result.update_to(ks); // we allocate to allow mirroring. + decode_weights_qi(result.data, cabac, count, nw, q_step_shift, precision); // read the 1st half. + // mirror last half + for (int i = 0; i < nw/2*2; i++) { - if (first_ups) - ups_4x2x2_fromarm_avx2(ks/2, weightsx4xpose, h_in, w_in, stride_in, in, stride_target, out); - else - ups_4x2x2_fromups_avx2(ks/2, weightsx4xpose, h_in, w_in, stride_in, in, stride_target, out); - } - else -#endif - { - custom_conv_ups_zxz1_cpu(ks/2, weightsx4xpose, h_in, w_in, stride_in, in, stride_target, out, first_ups ? ARM_PRECISION : UPS_PRECISION); + result.data[ks-1-i] = result.data[i]; } + fflush(stdout); } void cc_frame_decoder::read_arm(struct cc_bs_frame *frame_symbols) @@ -287,149 +255,58 @@ void cc_frame_decoder::read_ups(struct cc_bs_frame *frame_symbols) { struct cc_bs_frame_header &frame_header = frame_symbols->m_frame_header; - int const ups_ks = frame_header.upsampling_kern_size; - int32_t bucket[ups_ks*ups_ks]; // we transpose from this. - - if (frame_header.static_upsampling_kernel != 0) + int n_ups = frame_header.n_ups_kernel; + int n_ups_preconcat = frame_header.n_ups_preconcat_kernel; + if (n_ups != m_ups_n) { -#if 0 // generate text. - float bilinear_kernel[4 * 4] = { - 0.0625, 0.1875, 0.1875, 0.0625, - 0.1875, 0.5625, 0.5625, 0.1875, - 0.1875, 0.5625, 0.5625, 0.1875, - 0.0625, 0.1875, 0.1875, 0.0625 - }; - - float bicubic_kernel[8 * 8] = { - 0.0012359619 , 0.0037078857 ,-0.0092010498 ,-0.0308990479 ,-0.0308990479 ,-0.0092010498 , 0.0037078857 , 0.0012359619, - 0.0037078857 , 0.0111236572 ,-0.0276031494 ,-0.0926971436 ,-0.0926971436 ,-0.0276031494 , 0.0111236572 , 0.0037078857, - -0.0092010498 ,-0.0276031494 , 0.0684967041 , 0.2300262451 , 0.2300262451 , 0.0684967041 ,-0.0276031494 ,-0.0092010498, - -0.0308990479 ,-0.0926971436 , 0.2300262451 , 0.7724761963 , 0.7724761963 , 0.2300262451 ,-0.0926971436 ,-0.0308990479, - -0.0308990479 ,-0.0926971436 , 0.2300262451 , 0.7724761963 , 0.7724761963 , 0.2300262451 ,-0.0926971436 ,-0.0308990479, - -0.0092010498 ,-0.0276031494 , 0.0684967041 , 0.2300262451 , 0.2300262451 , 0.0684967041 ,-0.0276031494 ,-0.0092010498, - 0.0037078857 , 0.0111236572 ,-0.0276031494 ,-0.0926971436 ,-0.0926971436 ,-0.0276031494 , 0.0111236572 , 0.0037078857, - 0.0012359619 , 0.0037078857 ,-0.0092010498 ,-0.0308990479 ,-0.0308990479 ,-0.0092010498 , 0.0037078857 , 0.0012359619 - }; - // generate text. - printf("int32_t bilinear_kernel[4 * 4] = {\n"); - for (int ky = 0; ky < 4; ky++) - { - printf(" "); - for (int kx = 0; kx < 4; kx++) - printf("%d%s", (int32_t)(bilinear_kernel[ky*4+kx]*(1< &bs_fifo_weights = bs_weights.getFifo(); - - TDecBinCABAC cabac_weights; + delete[] m_upsw_preconcat; + m_upsw_preconcat = new weights_biases[n_ups_preconcat]; + m_ups_n_preconcat = n_ups_preconcat; + } - bs_fifo_weights = frame_symbols->m_ups_weights_hevc; - cabac_weights.init(&bs_weights); - cabac_weights.start(); - struct cc_bs_layer_quant_info &ups_lqi = frame_header.ups_lqi; + InputBitstream bs_weights; + std::vector &bs_fifo_weights = bs_weights.getFifo(); - int q_step_w_shift = Q_STEP_UPS_SHIFT[ups_lqi.q_step_index_nn_weight]; + TDecBinCABAC cabac_weights; - decode_weights_qi(&bucket[0], &cabac_weights, ups_lqi.scale_index_nn_weight, frame_header.upsampling_kern_size*frame_header.upsampling_kern_size, q_step_w_shift, UPS_PRECISION); - } + bs_fifo_weights = frame_symbols->m_ups_weights_hevc; + cabac_weights.init(&bs_weights); + cabac_weights.start(); + struct cc_bs_layer_quant_info &ups_lqi = frame_header.ups_lqi; - // transpose. - int32_t bucket_t[ups_ks*ups_ks]; + int q_step_w_shift = Q_STEP_UPS_SHIFT[ups_lqi.q_step_index_nn_weight]; - int idx = 0; - for (int y = 0; y < ups_ks; y++) + // read ups layers + for (int lidx = 0; lidx < m_ups_n; lidx++) { - for (int x = 0; x < ups_ks; x++, idx++) - { - bucket_t[idx] = bucket[(ups_ks-1-y)*ups_ks+(ups_ks-1-x)]; - } + decode_upsweights_qi(m_upsw[lidx], &cabac_weights, ups_lqi.scale_index_nn_weight, frame_header.ups_k_size, q_step_w_shift, UPS_PRECISION); } - - // extract 4 smaller filters. 8x8 -> 4x 4x4 - // f0 operates at origin, y=0, x=0 - // f1 operates at y=0, x=1 - // f2 operates at y=1, x=0 - // f3 operates at y=1, x=1 - m_upsw_t_4x4.update_to(ups_ks*ups_ks); - int ups_ks2 = ups_ks/2; - - for (int f = 0; f < 4; f++) + // read ups_preconcat layers. + for (int lidx = 0; lidx < m_ups_n_preconcat; lidx++) { - int fbase_y = 1-f/2; - int fbase_x = 1-f%2; - for (int y = 0; y < ups_ks/2; y++) - { - for (int x = 0; x < ups_ks/2; x++) - { - m_upsw_t_4x4.data[f*ups_ks2*ups_ks2 + y*ups_ks2 + x] = bucket_t[(fbase_y+2*y)*ups_ks+fbase_x+2*x]; - } - } + decode_upsweights_qi(m_upsw_preconcat[lidx], &cabac_weights, ups_lqi.scale_index_nn_weight, frame_header.ups_preconcat_k_size, q_step_w_shift, UPS_PRECISION); } - } void cc_frame_decoder::read_syn(struct cc_bs_frame *frame_symbols) { struct cc_bs_frame_header &frame_header = frame_symbols->m_frame_header; - if (frame_header.n_syn_layers != m_syn_n_layers) + if (frame_header.n_syn_layers != m_syn_n_layers || frame_header.n_syn_branches != m_syn_n_branches) { - m_syn_n_layers = frame_header.n_syn_layers; delete [] m_synw; delete [] m_synb; - m_synw = new weights_biases[m_syn_n_layers]; - m_synb = new weights_biases[m_syn_n_layers]; + m_syn_n_branches = frame_header.n_syn_branches; + m_syn_n_layers = frame_header.n_syn_layers; + m_syn_blends.update_to(m_syn_n_layers); + m_synw = new weights_biases[m_syn_n_branches*m_syn_n_layers]; + m_synb = new weights_biases[m_syn_n_branches*m_syn_n_layers]; } InputBitstream bs_weights; @@ -451,38 +328,39 @@ void cc_frame_decoder::read_syn(struct cc_bs_frame *frame_symbols) int q_step_w_shift = Q_STEP_SYN_WEIGHT_SHIFT[syn_lqi.q_step_index_nn_weight]; int q_step_b_shift = Q_STEP_SYN_BIAS_SHIFT[syn_lqi.q_step_index_nn_bias]; - int n_in_ft = frame_header.n_latent_n_resolutions; // !!! features per layer - for (int idx = 0; idx < frame_header.n_syn_layers; idx++) + // blend values. + if (m_syn_n_branches > 1) { - struct cc_bs_syn_layer &syn = frame_header.layers_synthesis[idx]; - int n_weights = n_in_ft * syn.ks*syn.ks * syn.n_out_ft; - int n_biases = syn.n_out_ft; + decode_weights_qi(m_syn_blends, &cabac_weights, syn_lqi.scale_index_nn_weight, m_syn_n_branches, q_step_w_shift, SYN_WEIGHT_PRECISION); + } - decode_weights_qi(m_synw[idx], &cabac_weights, syn_lqi.scale_index_nn_weight, n_weights, q_step_w_shift, SYN_WEIGHT_PRECISION); - decode_weights_qi(m_synb[idx], &cabac_biases, syn_lqi.scale_index_nn_bias, n_biases, q_step_b_shift, SYN_WEIGHT_PRECISION*2); + // layer weights. + for (int bidx = 0; bidx < frame_header.n_syn_branches; bidx++) + { + int n_in_ft = frame_header.n_latent_n_resolutions; // !!! features per layer + for (int lidx = 0; lidx < frame_header.n_syn_layers; lidx++) + { + struct cc_bs_syn_layer &syn = frame_header.layers_synthesis[lidx]; + int n_weights = n_in_ft * syn.ks*syn.ks * syn.n_out_ft; + int n_biases = syn.n_out_ft; - n_in_ft = syn.n_out_ft; + decode_weights_qi(m_synw[get_syn_idx(bidx, lidx)], &cabac_weights, syn_lqi.scale_index_nn_weight, n_weights, q_step_w_shift, SYN_WEIGHT_PRECISION); + decode_weights_qi(m_synb[get_syn_idx(bidx, lidx)], &cabac_biases, syn_lqi.scale_index_nn_bias, n_biases, q_step_b_shift, SYN_WEIGHT_PRECISION*2); + + n_in_ft = syn.n_out_ft; + } } } // check to see if the leading couple of syns are compatible with being fused // together. -// 8, 16, 40 are checked. 8 is marginal (base mem allocation is 7, 8 is not much more). +// Memory increase avoidance is so important we use cpu-mode in avx2 if it's +// not otherwise optimized. bool cc_frame_decoder::can_fuse(struct cc_bs_frame *frame_symbols) { auto &synthesis = frame_symbols->m_frame_header.layers_synthesis; bool fused = synthesis[0].ks == 1 && synthesis[1].ks == 1; - if (!fused) - return false; - - // Check for compatible numbers. - int n_syn_in = frame_symbols->m_frame_header.n_latent_n_resolutions; - int n_hidden = synthesis[0].n_out_ft; - int n_syn_out = synthesis[1].n_out_ft; - fused = (n_syn_in == 7) - && (n_hidden == 8 || n_hidden == 16 || n_hidden == 40) - && (n_syn_out == 3 || n_syn_out == 6 || n_syn_out == 9); return fused; } @@ -490,16 +368,13 @@ void cc_frame_decoder::check_allocations(struct cc_bs_frame *frame_symbols) { struct cc_bs_frame_header &frame_header = frame_symbols->m_frame_header; - // we allocate a max-sized (ignoring padding) buffer, suitable for holding - // upsample final output and synthesis outputs during syn processing. int const latent_n_resolutions = frame_header.n_latent_n_resolutions; int n_max_planes = latent_n_resolutions; m_arm_pad = 4; // by how much the arm frame is padded. - m_ups_pad = 8/2+1; // kernel size 8 upsampling. - m_max_pad = std::max(m_arm_pad, m_ups_pad); // during arm, ups and syn. // !!! we use 5 for max ups (5) (ks=8)/2+1, and arm (4) (cxt_row_col) + m_ups_pad = (std::max(frame_header.ups_k_size, frame_header.ups_preconcat_k_size)+1)/2; - bool need_ups_2 = false; // we only need this if we cannot do in-place convolution. Ie, ks 1 or 3 are in-place, otherwise not. + int syn_max_ks = 1; for (int syn_idx = 0; syn_idx < frame_header.n_syn_layers; syn_idx++) { if (syn_idx == 0 && can_fuse(frame_symbols)) @@ -510,24 +385,27 @@ void cc_frame_decoder::check_allocations(struct cc_bs_frame *frame_symbols) int n_pad = frame_header.layers_synthesis[syn_idx].ks/2; if (n_pad > m_max_pad) m_max_pad = n_pad; - if (frame_header.layers_synthesis[syn_idx].ks != 1 && frame_header.layers_synthesis[syn_idx].ks != 3) - need_ups_2 = true; + syn_max_ks = std::max(syn_max_ks, frame_header.layers_synthesis[syn_idx].ks); } - printf("MAX PLANES SET TO %d; MAX PAD SET TO %d\n", n_max_planes, m_max_pad); - m_ups_1.update_to(m_gop_header.img_h, m_gop_header.img_w, n_max_planes, m_max_pad); - // we do not (normally!) need m_ups_2. - if (need_ups_2) - { - printf("ups_2 for syn allocated\n"); - m_ups_2.update_to(m_gop_header.img_h, m_gop_header.img_w, n_max_planes, m_max_pad); - } + int syn_pad = (syn_max_ks+1)/2; + m_max_pad = std::max(std::max(m_arm_pad, m_ups_pad), syn_pad); // during arm, ups and syn. + + if (m_verbosity >= 3) + printf("MAX PLANES SET TO %d; MAX PAD SET TO %d\n", n_max_planes, m_max_pad); + + // !!! allocate all for the moment for bisyn. + m_syn_1.update_to(m_gop_header.img_h, m_gop_header.img_w, m_max_pad, n_max_planes); + m_syn_2.update_to(m_gop_header.img_h, m_gop_header.img_w, m_max_pad, n_max_planes); + m_syn_3.update_to(m_gop_header.img_h, m_gop_header.img_w, m_max_pad, n_max_planes); + m_syn_tmp.update_to(m_gop_header.img_h, m_gop_header.img_w, m_max_pad, n_max_planes); // pyramid sizes. m_zero_layer.resize(frame_header.n_latent_n_resolutions); m_h_pyramid.resize(frame_header.n_latent_n_resolutions); m_w_pyramid.resize(frame_header.n_latent_n_resolutions); - // pyramid storage -- enough pad for arm and ups. no highest res layer. + + // pyramid storage -- enough pad for arm and ups. includes highest res layer. m_plane_pyramid.resize(frame_header.n_latent_n_resolutions); for (int layer_number = 0, h_grid = m_gop_header.img_h, w_grid = m_gop_header.img_w; @@ -536,15 +414,52 @@ void cc_frame_decoder::check_allocations(struct cc_bs_frame *frame_symbols) { m_h_pyramid[layer_number] = h_grid; m_w_pyramid[layer_number] = w_grid; - - // we do not allocate the full-size. - if (layer_number == 0) - m_plane_pyramid[layer_number].update_to(0, 0, 0, 0); // not used. - else - m_plane_pyramid[layer_number].update_to(h_grid, w_grid, 1, m_max_pad); + m_plane_pyramid[layer_number].update_to(h_grid, w_grid, m_max_pad, 1); } + + // We need a couple of extra buffers for pyramid upsampling and refinement. It's not in-place. + // 1/4 sized for internal refinement target. (Full-sized goes directly to syn-land.) + m_ups_h2w2.update_to(m_h_pyramid[1], m_w_pyramid[1], m_ups_pad, 1); + // full-height, full-width for actual upsample. + m_ups_hw.update_to(m_h_pyramid[0], m_w_pyramid[0], m_ups_pad, 1); } +void ups_refine_cpu(int ks_param, int32_t *kw, frame_memory &in, frame_memory &out, int ups_src_precision, frame_memory &tmp) +{ + if (ks_param == 7) + ups_refine_ks7_cpu(ks_param, kw, in, out, ups_src_precision, tmp); + else + ups_refine_ksX_cpu(ks_param, kw, in, out, ups_src_precision, tmp); +} + +void ups_upsample_cpu(int ksx2, int32_t *kw, frame_memory &in, frame_memory &out, int out_plane, int ups_src_precision, frame_memory &tmp) +{ + if (ksx2 == 8) + ups_upsample_ks8_cpu(ksx2, kw, in, out, out_plane, ups_src_precision, tmp); + else + ups_upsample_ksX_cpu(ksx2, kw, in, out, out_plane, ups_src_precision, tmp); +} + +#ifdef CCDECAPI_AVX2 +void ups_refine_avx2(int ks_param, int32_t *kw, frame_memory &in, frame_memory &out, int ups_src_precision, frame_memory &tmp) +{ + if (ks_param == 7) + ups_refine_ks7_avx2(ks_param, kw, in, out, ups_src_precision, tmp); + else + ups_refine_ksX_avx2(ks_param, kw, in, out, ups_src_precision, tmp); +} + +void ups_upsample_avx2(int ksx2, int32_t *kw, frame_memory &in, frame_memory &out, int out_plane, int ups_src_precision, frame_memory &tmp) +{ + if (ksx2 == 8 && ups_src_precision == ARM_PRECISION) + ups_upsample_ks8_ARMPREC_avx2(ksx2, kw, in, out, out_plane, ups_src_precision, tmp); + else if (ksx2 == 8 && ups_src_precision == UPS_PRECISION) + ups_upsample_ks8_UPSPREC_avx2(ksx2, kw, in, out, out_plane, ups_src_precision, tmp); + else + ups_upsample_ksX_avx2(ksx2, kw, in, out, out_plane, ups_src_precision, tmp); +} +#endif + void cc_frame_decoder::run_arm(struct cc_bs_frame *frame_symbols) { struct cc_bs_frame_header &frame_header = frame_symbols->m_frame_header; @@ -556,7 +471,8 @@ void cc_frame_decoder::run_arm(struct cc_bs_frame *frame_symbols) int const h_grid = m_h_pyramid[layer_number]; int const w_grid = m_w_pyramid[layer_number]; - printf("starting layer %d\n", layer_number); + if (m_verbosity >= 3) + printf("arm:starting layer %d\n", layer_number); fflush(stdout); auto &bytes = frame_symbols->m_latents_hevc[layer_number]; @@ -568,9 +484,9 @@ void cc_frame_decoder::run_arm(struct cc_bs_frame *frame_symbols) } m_zero_layer[layer_number] = false; - frame_memory *dest = layer_number == 0 ? &m_ups_1 : &m_plane_pyramid[layer_number]; + //frame_memory *dest = layer_number == 0 ? &m_syn_1 : &m_plane_pyramid[layer_number]; + frame_memory *dest = &m_plane_pyramid[layer_number]; dest->zero_pad(0, m_arm_pad); - //dest->print_ranges("armin", 1, ARM_PRECISION); // BAC decoding: InputBitstream bsBAC; @@ -587,7 +503,7 @@ void cc_frame_decoder::run_arm(struct cc_bs_frame *frame_symbols) const auto time_arm_start = std::chrono::steady_clock::now(); #ifdef CCDECAPI_AVX2 - if (frame_header.dim_arm == 8) + if (m_use_avx2 && frame_header.dim_arm == 8) custom_conv_11_int32_avx2_8_X_X( m_mlpw_t, m_mlpb, &m_mlpwOUT, &m_mlpbOUT, @@ -596,7 +512,7 @@ void cc_frame_decoder::run_arm(struct cc_bs_frame *frame_symbols) h_grid, w_grid, (dest->stride-w_grid)/2, bac_context ); - else if (frame_header.dim_arm == 16) + else if (m_use_avx2 && frame_header.dim_arm == 16) custom_conv_11_int32_avx2_16_X_X( m_mlpw_t, m_mlpb, &m_mlpwOUT, &m_mlpbOUT, @@ -605,7 +521,7 @@ void cc_frame_decoder::run_arm(struct cc_bs_frame *frame_symbols) h_grid, w_grid, (dest->stride-w_grid)/2, bac_context ); - else if (frame_header.dim_arm == 24) + else if (m_use_avx2 && frame_header.dim_arm == 24) custom_conv_11_int32_avx2_24_X_X( m_mlpw_t, m_mlpb, &m_mlpwOUT, &m_mlpbOUT, @@ -614,7 +530,7 @@ void cc_frame_decoder::run_arm(struct cc_bs_frame *frame_symbols) h_grid, w_grid, (dest->stride-w_grid)/2, bac_context ); - else if (frame_header.dim_arm == 32) + else if (m_use_avx2 && frame_header.dim_arm == 32) custom_conv_11_int32_avx2_32_X_X( m_mlpw_t, m_mlpb, &m_mlpwOUT, &m_mlpbOUT, @@ -636,46 +552,30 @@ void cc_frame_decoder::run_arm(struct cc_bs_frame *frame_symbols) bac_context ); - //dest->print_ranges("armout", 1, ARM_PRECISION); - -#if UPS_PRECISION > ARM_PRECISION - if (layer_number == 0) - { - // in-place precision change for high-resolution layer. - // upsampling is not used here, which normally changes the .8 to .12. - // we do it by hand. - int32_t *src = dest->plane_origin(0); - int const lshift = UPS_PRECISION-ARM_PRECISION; - for (int y = 0; y < h_grid; y++, src += dest->stride-w_grid) - { - for (int x = 0; x < w_grid; x++) - { - *src++ <<= lshift; - } - } -#else - what system is this -#endif - } - const auto time_arm_done = std::chrono::steady_clock::now(); const std::chrono::duration arm_elapsed = (time_arm_done-time_arm_start); time_arm_seconds += (float)arm_elapsed.count(); } // layer. - printf("arm done!\n"); + if (m_verbosity >= 100) + { + printf("ARM OUTPUTS\n"); + for (int p = 0; p < frame_header.n_latent_n_resolutions; p++) + { + printf("ARMPYRAMID %d ", p); + m_plane_pyramid[p].print_start(0, "ARM", -1, ARM_PRECISION); + } + } } void cc_frame_decoder::run_ups(struct cc_bs_frame *frame_symbols) { struct cc_bs_frame_header &frame_header = frame_symbols->m_frame_header; - // NEW UPSAMPLE const auto time_ups_start = std::chrono::steady_clock::now(); - int ups_ks = frame_header.upsampling_kern_size; - + // full-res down to lowest-res. refine & upsample each as necessary to full res. for (int layer_number = 0, h_grid = m_gop_header.img_h, w_grid = m_gop_header.img_w; layer_number < frame_header.n_latent_n_resolutions; layer_number++, h_grid = (h_grid+1)/2, w_grid = (w_grid+1)/2) @@ -683,38 +583,93 @@ void cc_frame_decoder::run_ups(struct cc_bs_frame *frame_symbols) if (m_zero_layer[layer_number]) { // no need to upsample. just zero the final content. - m_ups_1.zero_plane_content(layer_number); + m_syn_1.zero_plane_content(layer_number); continue; } + // layer_number 0: hb_layer is (nlatents-1)%nhblayers. + // n_resolution-2 is number of hb layers max. + int preconcat_layer = (frame_header.n_latent_n_resolutions-2-layer_number)%m_ups_n_preconcat; if (layer_number == 0) { - // full res. if non-null, already present in m_ups_layers[0]. + // full res. just a refinement directly to m_syn_1, no upsampling. +#ifdef CCDECAPI_AVX2 + if (m_use_avx2) + { + ups_refine_avx2(frame_header.ups_preconcat_k_size, m_upsw_preconcat[preconcat_layer].data, m_plane_pyramid[0], m_syn_1, ARM_PRECISION, m_ups_hw); + } +#endif +#if defined(CCDECAPI_AVX2_OPTIONAL) + else +#endif +#if defined(CCDECAPI_AVX2_OPTIONAL) || defined(CCDECAPI_CPU) + { + ups_refine_cpu(frame_header.ups_preconcat_k_size, m_upsw_preconcat[preconcat_layer].data, m_plane_pyramid[0], m_syn_1, ARM_PRECISION, m_ups_hw); + } +#endif continue; } - // continually scale up to the final resolution. We upsample through lower layer numbers, - // which have already been done until the final full res destination in m_ups_1. - for (int target_layer = layer_number-1; target_layer >= 0; target_layer--) + frame_memory *ups_src = NULL; + int ups_prec = 0; + + if (layer_number == frame_header.n_latent_n_resolutions-1) + { + // just upsample, no refinement. + ups_src = &m_plane_pyramid[layer_number]; + ups_prec = ARM_PRECISION; + } + else { - frame_memory *dest; - int dest_plane; + // refine, then upsample. +#ifdef CCDECAPI_AVX2 + if (m_use_avx2) + { + ups_refine_avx2(frame_header.ups_preconcat_k_size, m_upsw_preconcat[preconcat_layer].data, m_plane_pyramid[layer_number], m_ups_h2w2, ARM_PRECISION, m_ups_hw); + } +#endif +#if defined(CCDECAPI_AVX2_OPTIONAL) + else +#endif +#if defined(CCDECAPI_AVX2_OPTIONAL) || defined(CCDECAPI_CPU) + { + ups_refine_cpu(frame_header.ups_preconcat_k_size, m_upsw_preconcat[preconcat_layer].data, m_plane_pyramid[layer_number], m_ups_h2w2, ARM_PRECISION, m_ups_hw); + } +#endif + ups_src = &m_ups_h2w2; + ups_prec = UPS_PRECISION; + } + + for (int target_layer = layer_number-1; target_layer >= 0; ups_src = &m_plane_pyramid[target_layer], target_layer--, ups_prec = UPS_PRECISION) + { + // upsample layer index to use. + int ups_layer = (frame_header.n_latent_n_resolutions-2-target_layer)%m_ups_n; + // upsample, either to next pyramid level up or, instead of to [0], to m_syn_1[layer_number]. + frame_memory *ups_dst = NULL; + int dst_plane; if (target_layer == 0) { - dest = &m_ups_1; - dest_plane = layer_number; + ups_dst = &m_syn_1; + dst_plane = layer_number; } else { - dest = &m_plane_pyramid[target_layer]; - dest_plane = 0; + ups_dst = &m_plane_pyramid[target_layer]; + dst_plane = 0; } - frame_memory *lo_res = &m_plane_pyramid[target_layer+1]; - int pad = ups_ks/2/2; - lo_res->custom_pad_replicate_plane_in_place_i(0, pad); - //printf("upslayer%dtarget%d ", layer_number, target_layer); lo_res->print_ranges("ups_in", 1, target_layer == layer_number-1 ? ARM_PRECISION : UPS_PRECISION); - // our output is not the direct origin -- we need an extra crop to point at pad start. - custom_upsample_4(ups_ks, m_upsw_t_4x4.data, m_h_pyramid[target_layer+1]+2*pad, m_w_pyramid[target_layer+1]+2*pad, lo_res->stride, - lo_res->pad_origin(pad), dest->stride, dest->pad_origin(dest_plane, 1), target_layer == layer_number-1); +#ifdef CCDECAPI_AVX2 + if (m_use_avx2) + { + ups_upsample_avx2(frame_header.ups_k_size, m_upsw[ups_layer].data, *ups_src, *ups_dst, dst_plane, ups_prec, m_ups_hw); + } +#endif +#if defined(CCDECAPI_AVX2_OPTIONAL) + else +#endif +#if defined(CCDECAPI_AVX2_OPTIONAL) || defined(CCDECAPI_CPU) + { + ups_upsample_cpu(frame_header.ups_k_size, m_upsw[ups_layer].data, *ups_src, *ups_dst, dst_plane, ups_prec, m_ups_hw); + } +#endif } } @@ -723,24 +678,162 @@ void cc_frame_decoder::run_ups(struct cc_bs_frame *frame_symbols) time_ups_seconds += hand_ups_elapsed_seconds.count(); } -// result points to either m_ups_1 or m_ups_2 -frame_memory *cc_frame_decoder::run_syn(struct cc_bs_frame *frame_symbols) +// blend an output into an accumulator. +// syn_out = syn_out + syn_in*blend[branch_no] +frame_memory *cc_frame_decoder::run_syn_blend1(struct cc_bs_frame *frame_symbols, int n_planes, int branch_no, frame_memory *syn_in, frame_memory *syn_out) +{ + const auto time_blend_start = std::chrono::steady_clock::now(); +#if 0 +#ifdef CCDECAPI_AVX2 + if (m_use_avx2) + { + syn_blend1_avx2(syn_in->h, syn_in->w, syn_in->stride, syn_in->plane_stride, n_planes, syn_in->origin(), m_syn_blends.data[branch_no], syn_out->origin()); + } +#endif +#if defined(CCDECAPI_AVX2_OPTIONAL) + else +#endif +#if defined(CCDECAPI_AVX2_OPTIONAL) || defined(CCDECAPI_CPU) + { + syn_blend1(syn_in->h, syn_in->w, syn_in->stride, syn_in->plane_stride, n_planes, syn_in->origin(), m_syn_blends.data[branch_no], syn_out->origin()); + } +#endif +#endif + + if (m_verbosity >= 100) + { + printf("BLEND1 INPUT\n"); + for (int p = 0; p < n_planes; p++) + { + printf("BLENDINPLANE %d ", p); + syn_in->print_start(p, "BLENDIN", -1, UPS_PRECISION); + } + } + + syn_blend1(syn_in->h, syn_in->w, syn_in->stride, syn_in->plane_stride, n_planes, syn_in->origin(), m_syn_blends.data[branch_no], syn_out->origin()); + const auto time_blend_end = std::chrono::steady_clock::now(); + const std::chrono::duration hand_blend_elapsed_seconds = time_blend_end - time_blend_start; + time_blend_seconds += hand_blend_elapsed_seconds.count(); + return syn_out; +} + +// blend two outputs to initialilze an accumulator. +// the 'in' is the 2nd syn output, the 'out' is the first syn output. +// we keep updating 'out' +// syn_out = syn_out*blend[branch_no-1] + syn_in*blend[branch_no] +frame_memory *cc_frame_decoder::run_syn_blend2(struct cc_bs_frame *frame_symbols, int n_planes, int branch_no, frame_memory *syn_in, frame_memory *syn_out) +{ + const auto time_blend_start = std::chrono::steady_clock::now(); +#if 0 +#ifdef CCDECAPI_AVX2 + if (m_use_avx2) + { + syn_blend2_avx2(syn_in->h, syn_in->w, syn_in->stride, syn_in->plane_stride, n_planes, syn_in->origin(), m_syn_blends.data[branch_no], syn_out->origin(), m_syn_blends.data[branch_no-1]); + } +#endif +#if defined(CCDECAPI_AVX2_OPTIONAL) + else +#endif +#if defined(CCDECAPI_AVX2_OPTIONAL) || defined(CCDECAPI_CPU) + { + syn_blend2(syn_in->h, syn_in->w, syn_in->stride, syn_in->plane_stride, n_planes, syn_in->origin(), m_syn_blends.data[branch_no], syn_out->origin(), m_syn_blends.data[branch_no-1]); + } +#endif +#endif + + if (m_verbosity >= 100) + { + printf("BLEND2 INPUTS\n"); + printf("1: w=%d(%f)\n", m_syn_blends.data[branch_no], (m_syn_blends.data[branch_no]/((1<print_start(p, "BLENDIN1", -1, UPS_PRECISION); + } + printf("2: w=%d(%f)\n", m_syn_blends.data[branch_no-1], (m_syn_blends.data[branch_no]/((1<print_start(p, "BLENDIN2", -1, UPS_PRECISION); + } + } + + syn_blend2(syn_in->h, syn_in->w, syn_in->stride, syn_in->plane_stride, n_planes, syn_in->origin(), m_syn_blends.data[branch_no], syn_out->origin(), m_syn_blends.data[branch_no-1]); + const auto time_blend_end = std::chrono::steady_clock::now(); + const std::chrono::duration hand_blend_elapsed_seconds = time_blend_end - time_blend_start; + time_blend_seconds += hand_blend_elapsed_seconds.count(); + return syn_out; +} + +// operate synth layers starting from syn_in as input. +// syn_out non-NULL implies the first syn layer will EXPLICITLY target syn_out, thereafter remaining in syn_out if possible (alternating syn_out, syn_tmp if necessary) +// return value is either syn_out or syn_tmp. +// syn_out NULL implies syn processing will remain in syn_in as much as possible (alternating syn_in, syn_tmp if necessary) +// return value is either syn_in or syn_tmp. +frame_memory *cc_frame_decoder::run_syn_branch(struct cc_bs_frame *frame_symbols, int branch_no, frame_memory *syn_in, frame_memory *syn_out, frame_memory *syn_tmp) { struct cc_bs_frame_header &frame_header = frame_symbols->m_frame_header; - // NEW SYNTHESIS: IN-PLACE CONV int n_syn_in = frame_header.n_latent_n_resolutions; - frame_memory *syn_in_i = &m_ups_1; - frame_memory *syn_out_i = &m_ups_2; // if cannot do syn in-place. + frame_memory *syn_in_i = NULL; // internal in/out + frame_memory *syn_out_i = NULL; // internal in/out + if (syn_out == syn_in) + syn_out = NULL; + const auto time_syn_start = std::chrono::steady_clock::now(); for (int syn_idx = 0; syn_idx < frame_header.n_syn_layers; syn_idx++) { - int n_syn_out = frame_header.layers_synthesis[syn_idx].n_out_ft; - printf("SYN%d: k=%d #in=%d #out=%d", syn_idx, frame_header.layers_synthesis[syn_idx].ks, n_syn_in, n_syn_out); auto time_this_syn_start = std::chrono::steady_clock::now(); + int n_syn_out = frame_header.layers_synthesis[syn_idx].n_out_ft; + int syn_wb_idx = get_syn_idx(branch_no, syn_idx); + if (m_verbosity >= 3) + printf("SYN%d: k=%d #in=%d #out=%d", syn_idx, frame_header.layers_synthesis[syn_idx].ks, n_syn_in, n_syn_out); + + bool possible_in_place; +#ifdef CCDECAPI_AVX2 + if (m_use_avx2) + { + possible_in_place = frame_header.layers_synthesis[syn_idx].ks == 3 + || (syn_idx == 0 && can_fuse(frame_symbols)) + || (frame_header.layers_synthesis[syn_idx].ks == 1 && n_syn_out <= 4); + } +#endif +#if defined(CCDECAPI_AVX2_OPTIONAL) + else +#endif +#if defined(CCDECAPI_AVX2_OPTIONAL) || defined(CCDECAPI_CPU) + { + possible_in_place = true; + } +#endif + + // forcing output to not in-place for first layer? + bool in_place = possible_in_place; + if (syn_idx == 0 && syn_out != NULL) + { + in_place = false; // explit targeting of non-input for output. We will switch between syn_out, syn_tmp after as necessary. + } + + if (syn_idx == 0) + { + syn_in_i = syn_in; + syn_out_i = syn_out != NULL ? syn_out : (in_place ? syn_in_i : syn_tmp); + } + else + { + if (syn_out != NULL) + { + // syn_in_i is out or tmp. determine syn_out_i + syn_out_i = in_place ? syn_in_i : (syn_in_i == syn_out ? syn_tmp : syn_out); + } + else + { + // syn_in_i is in or tmp. determine syn_out_i + syn_out_i = in_place ? syn_in_i : (syn_in_i == syn_in ? syn_tmp : syn_in); + } + } - //syn_in_i->print_ranges("syn_in", n_syn_in, SYN_MUL_PRECISION); if (frame_header.layers_synthesis[syn_idx].ks > 1) { // pad syn_in, and direct output to syn_out. @@ -751,81 +844,88 @@ frame_memory *cc_frame_decoder::run_syn(struct cc_bs_frame *frame_symbols) int pad_origin_idx = origin_idx-pad*syn_in_i->stride-pad; for (int l = 0; l < n_syn_in; l++) syn_in_i->custom_pad_replicate_plane_in_place_i(l, pad); - printf(" padded "); + if (m_verbosity >= 3) + printf(" padded "); #ifdef CCDECAPI_AVX2 - bool in_place = false; - if (frame_header.layers_synthesis[syn_idx].ks == 3) + if (m_use_avx2) { - // optimized 3x3 with line buffer. - in_place = true; - m_syn3x3_linebuffer.update_to(2*n_syn_out*m_gop_header.img_w); // two lines for now, could eventually be 1 line plus a few pixels. - if (n_syn_in == 3 && n_syn_out == 3) - custom_conv_ks3_in3_out3_lb_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_idx].data, m_synb[syn_idx].data, - h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, NULL, - m_syn3x3_linebuffer.data, - !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); - else if (n_syn_in == 6 && n_syn_out == 6) - custom_conv_ks3_in6_out6_lb_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_idx].data, m_synb[syn_idx].data, - h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, NULL, - m_syn3x3_linebuffer.data, - !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); - else if (n_syn_in == 9 && n_syn_out == 6) - custom_conv_ks3_in9_out6_lb_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_idx].data, m_synb[syn_idx].data, - h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, NULL, - m_syn3x3_linebuffer.data, - !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); - else if (n_syn_in == 9 && n_syn_out == 9) - custom_conv_ks3_in9_out9_lb_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_idx].data, m_synb[syn_idx].data, - h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, NULL, - m_syn3x3_linebuffer.data, - !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); + if (frame_header.layers_synthesis[syn_idx].ks == 3) + { + // optimized 3x3 with line buffer. + // in_place = true; + m_syn3x3_linebuffer.update_to(2*n_syn_out*m_gop_header.img_w); // two lines for now, could eventually be 1 line plus a few pixels. + if (n_syn_in == 3 && n_syn_out == 3) + custom_conv_ks3_in3_out3_lb_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, + h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, syn_out_i->origin(), + m_syn3x3_linebuffer.data, + !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); + else if (n_syn_in == 6 && n_syn_out == 6) + custom_conv_ks3_in6_out6_lb_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, + h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, syn_out_i->origin(), + m_syn3x3_linebuffer.data, + !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); + else if (n_syn_in == 9 && n_syn_out == 6) + custom_conv_ks3_in9_out6_lb_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, + h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, syn_out_i->origin(), + m_syn3x3_linebuffer.data, + !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); + else if (n_syn_in == 9 && n_syn_out == 9) + custom_conv_ks3_in9_out9_lb_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, + h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, syn_out_i->origin(), + m_syn3x3_linebuffer.data, + !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); + else + custom_conv_ks3_inX_outX_lb_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, + h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, syn_out_i->origin(), + m_syn3x3_linebuffer.data, + !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); + } else - custom_conv_ks3_inX_outX_lb_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_idx].data, m_synb[syn_idx].data, - h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, NULL, - m_syn3x3_linebuffer.data, + { + custom_conv_ksX_inX_outX_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, + h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); + } } +#endif +#if defined(CCDECAPI_AVX2_OPTIONAL) else - custom_conv_ksX_inX_outX_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_idx].data, m_synb[syn_idx].data, - h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, syn_out_i->origin(), - !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); -#else - bool in_place = false; - if (frame_header.layers_synthesis[syn_idx].ks == 3) +#endif +#if defined(CCDECAPI_AVX2_OPTIONAL) || defined(CCDECAPI_CPU) { - in_place = true; - m_syn3x3_linebuffer.update_to(2*n_syn_out*m_gop_header.img_w); // two lines for now, could eventually be 1 line plus a few pixels. - custom_conv_ks3_inX_outX_lb(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_idx].data, m_synb[syn_idx].data, - h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, NULL, - m_syn3x3_linebuffer.data, - !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); + if (frame_header.layers_synthesis[syn_idx].ks == 3) + { + m_syn3x3_linebuffer.update_to(2*n_syn_out*m_gop_header.img_w); // two lines for now, could eventually be 1 line plus a few pixels. + custom_conv_ks3_inX_outX_lb(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, + h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, syn_out_i->origin(), + m_syn3x3_linebuffer.data, + !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); + } + else + { + custom_conv_ksX_inX_outX(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, + h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, syn_out_i->origin(), + !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); + } } - else - custom_conv_ksX_inX_outX(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_idx].data, m_synb[syn_idx].data, - h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, syn_out_i->origin(), - !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); #endif if (!in_place) { - // swap our in & out. - frame_memory *tmp = syn_in_i; - syn_in_i = syn_out_i; - syn_out_i = tmp; + // swap our in & out. accounting for any forced-first-time non-in-place. + // we just update syn_in_i here, syn_out_i is calculated every loop. + syn_in_i = syn_out_i; } } else { // possible in-place if ATATIME is big enough to buffer all outputs for single spatial pixel. // assuming max 4 here (avx2) and infinite (cpu) - bool in_place = n_syn_out <= 4; // !!! hardwired, not too many outputs. - // possible fusion. bool fused = syn_idx == 0 && can_fuse(frame_symbols); if (fused) { // compatible! - in_place = true; int n_hidden = n_syn_out; n_syn_out = frame_header.layers_synthesis[syn_idx+1].n_out_ft; @@ -833,55 +933,82 @@ frame_memory *cc_frame_decoder::run_syn(struct cc_bs_frame *frame_symbols) int32_t synw_fused[n_syn_in*n_hidden]; int kidx = 0; for (int kx = 0; kx < n_syn_in; kx++) + { for (int ky = 0; ky < n_hidden; ky++) { - synw_fused[kidx++] = m_synw[syn_idx].data[ky*n_syn_in+kx]; + synw_fused[kidx++] = m_synw[syn_wb_idx].data[ky*n_syn_in+kx]; } -#ifdef CCDECAPI_AVX2 - if (n_hidden == 40) + } +#if defined(CCDECAPI_AVX2) + if (m_use_avx2 && n_syn_in == 7) { - if (n_syn_out == 3) - custom_conv_ks1_in7_hidden40_out3_avx2(1, synw_fused, m_synb[syn_idx].data, m_synw[syn_idx+1].data, m_synb[syn_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_in_i->origin()); - else if (n_syn_out == 6) - custom_conv_ks1_in7_hidden40_out6_avx2(1, synw_fused, m_synb[syn_idx].data, m_synw[syn_idx+1].data, m_synb[syn_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_in_i->origin()); + if (n_hidden == 48 && n_syn_out == 3) + custom_conv_ks1_in7_hidden48_out3_avx2(1, synw_fused, m_synb[syn_wb_idx].data, m_synw[syn_wb_idx+1].data, m_synb[syn_wb_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_out_i->origin()); + else if (n_hidden == 40) + { + if (n_syn_out == 3) + custom_conv_ks1_in7_hidden40_out3_avx2(1, synw_fused, m_synb[syn_wb_idx].data, m_synw[syn_wb_idx+1].data, m_synb[syn_wb_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_out_i->origin()); + else if (n_syn_out == 6) + custom_conv_ks1_in7_hidden40_out6_avx2(1, synw_fused, m_synb[syn_wb_idx].data, m_synw[syn_wb_idx+1].data, m_synb[syn_wb_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_out_i->origin()); + else if (n_syn_out == 9) + custom_conv_ks1_in7_hidden40_out9_avx2(1, synw_fused, m_synb[syn_wb_idx].data, m_synw[syn_wb_idx+1].data, m_synb[syn_wb_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_out_i->origin()); + else + custom_conv_ks1_inX_hiddenX_outX(1, synw_fused, m_synb[syn_wb_idx].data, m_synw[syn_wb_idx+1].data, m_synb[syn_wb_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_out_i->origin()); + } + else if (n_hidden == 32 && n_syn_out == 3) + custom_conv_ks1_in7_hidden32_out3_avx2(1, synw_fused, m_synb[syn_wb_idx].data, m_synw[syn_wb_idx+1].data, m_synb[syn_wb_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_out_i->origin()); + else if (n_hidden == 16 && n_syn_out == 3) + custom_conv_ks1_in7_hidden16_out3_avx2(1, synw_fused, m_synb[syn_wb_idx].data, m_synw[syn_wb_idx+1].data, m_synb[syn_wb_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_out_i->origin()); + else if (n_hidden == 8 && n_syn_out == 3) + custom_conv_ks1_in7_hidden8_out3_avx2(1, synw_fused, m_synb[syn_wb_idx].data, m_synw[syn_wb_idx+1].data, m_synb[syn_wb_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_out_i->origin()); else - custom_conv_ks1_in7_hidden40_out9_avx2(1, synw_fused, m_synb[syn_idx].data, m_synw[syn_idx+1].data, m_synb[syn_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_in_i->origin()); + custom_conv_ks1_inX_hiddenX_outX(1, synw_fused, m_synb[syn_wb_idx].data, m_synw[syn_wb_idx+1].data, m_synb[syn_wb_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_out_i->origin()); } - else if (n_hidden == 16) - custom_conv_ks1_in7_hidden16_out3_avx2(1, synw_fused, m_synb[syn_idx].data, m_synw[syn_idx+1].data, m_synb[syn_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_in_i->origin()); +#endif +#if defined(CCDECAPI_AVX2_OPTIONAL) else - custom_conv_ks1_in7_hidden8_out3_avx2(1, synw_fused, m_synb[syn_idx].data, m_synw[syn_idx+1].data, m_synb[syn_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_in_i->origin()); -#else - custom_conv_ks1_inX_hiddenX_outX(1, synw_fused, m_synb[syn_idx].data, m_synw[syn_idx+1].data, m_synb[syn_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_in_i->origin()); +#endif +#if defined(CCDECAPI_AVX2_OPTIONAL) || defined(CCDECAPI_CPU) + custom_conv_ks1_inX_hiddenX_outX(1, synw_fused, m_synb[syn_wb_idx].data, m_synw[syn_wb_idx+1].data, m_synb[syn_wb_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_out_i->origin()); #endif } else { // 1x1 kernel but not fused. #ifdef CCDECAPI_AVX2 - if (n_syn_in == 7 && n_syn_out == 9) - custom_conv_ks1_in7_out9_avx2(1, m_synw[syn_idx].data, m_synb[syn_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, in_place ? syn_in_i->origin() : syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); - else if (n_syn_in == 9 && n_syn_out == 3) - custom_conv_ks1_in9_out3_avx2(1, m_synw[syn_idx].data, m_synb[syn_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, in_place ? syn_in_i->origin() : syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); - else if (n_syn_in == 9 && n_syn_out == 6) - custom_conv_ks1_in9_out6_avx2(1, m_synw[syn_idx].data, m_synb[syn_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, in_place ? syn_in_i->origin() : syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); - else if (n_syn_in == 9 && n_syn_out == 9) - custom_conv_ks1_in9_out9_avx2(1, m_synw[syn_idx].data, m_synb[syn_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, in_place ? syn_in_i->origin() : syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); - else if (n_syn_in == 7 && n_syn_out == 40) - custom_conv_ks1_in7_out40_avx2(1, m_synw[syn_idx].data, m_synb[syn_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, in_place ? syn_in_i->origin() : syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); - else if (n_syn_in == 7 && n_syn_out == 16) - custom_conv_ks1_in7_out16_avx2(1, m_synw[syn_idx].data, m_synb[syn_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, in_place ? syn_in_i->origin() : syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); - else if ( n_syn_out == 3) - custom_conv_ks1_inX_out3_avx2(1, m_synw[syn_idx].data, m_synb[syn_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, in_place ? syn_in_i->origin() : syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); - else if ( n_syn_out == 6) - custom_conv_ks1_inX_out6_avx2(1, m_synw[syn_idx].data, m_synb[syn_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, in_place ? syn_in_i->origin() : syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); - else if ( n_syn_out == 9) - custom_conv_ks1_inX_out9_avx2(1, m_synw[syn_idx].data, m_synb[syn_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, in_place ? syn_in_i->origin() : syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); - else - custom_conv_ks1_inX_outX_avx2(1, m_synw[syn_idx].data, m_synb[syn_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, in_place ? syn_in_i->origin() : syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); -#else - in_place = true; - custom_conv_ks1_inX_outX(1, m_synw[syn_idx].data, m_synb[syn_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, in_place ? syn_in_i->origin() : syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); + if (m_use_avx2) + { + if (n_syn_in == 7 && n_syn_out == 9) + custom_conv_ks1_in7_out9_avx2(1, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); + else if (n_syn_in == 9 && n_syn_out == 3) + custom_conv_ks1_in9_out3_avx2(1, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); + else if (n_syn_in == 9 && n_syn_out == 6) + custom_conv_ks1_in9_out6_avx2(1, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); + else if (n_syn_in == 9 && n_syn_out == 9) + custom_conv_ks1_in9_out9_avx2(1, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); + else if (n_syn_in == 7 && n_syn_out == 40) + custom_conv_ks1_in7_out40_avx2(1, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); + else if (n_syn_in == 7 && n_syn_out == 16) + custom_conv_ks1_in7_out16_avx2(1, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); + else if ( n_syn_out == 3) + custom_conv_ks1_inX_out3_avx2(1, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); + else if ( n_syn_out == 6) + custom_conv_ks1_inX_out6_avx2(1, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); + else if ( n_syn_out == 9) + custom_conv_ks1_inX_out9_avx2(1, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); + else + custom_conv_ks1_inX_outX_avx2(1, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); + } +#endif +#if defined(CCDECAPI_AVX2_OPTIONAL) + else +#endif +#if defined(CCDECAPI_AVX2_OPTIONAL) || defined(CCDECAPI_CPU) + { + // can buffer everything in cpu-mode. + // in_place = true; + custom_conv_ks1_inX_outX(1, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu); + } #endif } if (fused) @@ -891,10 +1018,9 @@ frame_memory *cc_frame_decoder::run_syn(struct cc_bs_frame *frame_symbols) } if (!in_place) { - // swap our in & out. - frame_memory *tmp = syn_in_i; + // swap our in & out. accounting for any forced-first-time non-in-place. + // we just update syn_in_i here, syn_out_i is calculated every loop. syn_in_i = syn_out_i; - syn_out_i = tmp; } } @@ -904,9 +1030,9 @@ frame_memory *cc_frame_decoder::run_syn(struct cc_bs_frame *frame_symbols) const auto time_this_syn_end = std::chrono::steady_clock::now(); const std::chrono::duration this_syn_elapsed_seconds = time_this_syn_end - time_this_syn_start; - printf(" %g\n", (double)this_syn_elapsed_seconds.count()); + if (m_verbosity >= 3) + printf(" %g\n", (double)this_syn_elapsed_seconds.count()); } - //syn_in_i->print_ranges("final", n_syn_in, SYN_MUL_PRECISION); const auto time_syn_end = std::chrono::steady_clock::now(); const std::chrono::duration syn_elapsed_seconds = time_syn_end - time_syn_start; @@ -915,6 +1041,113 @@ frame_memory *cc_frame_decoder::run_syn(struct cc_bs_frame *frame_symbols) return syn_in_i; } +frame_memory *cc_frame_decoder::run_syn(struct cc_bs_frame *frame_symbols) +{ + // multiple synthesiser runs. + // we have m_syn_1 containing upsampled output, ready to send to synthesisers. + // + // ASSUMING ALL SYNTHS POSSIBLE IN-PLACE. IE, ALL KS ARE 1 or 3. + // with n branches 1..n: + // for branch1: + // synth m_syn_1 -> m_syn_2. (ie, 1st syn layer takes 1->2, remaining layers in 2) + // for branch 2..n-1: + // synth m_syn_1 -> m_syn_3. (ie, 1st syn layer takes 1->3, remaining layers in 3) + // blend result m_syn_3 into m_syn_2. + // for branch n: + // synth m_syn_1 in place. (ie, all layers in 1) + // blend result m_syn_1 into m_syn_2. + // + // result is m_syn_2. + + // IF NOT IN-PLACE (KS > 3) + // with n branches 1..n: + // for branch1: + // synth m_syn_1 -> m_syn_2 or m_syn_x. (ie, 1st syn layer 1->2, remaining layers in 2/x) + // for branch 2..n-1: + // synth m_syn_1 -> m_syn_3 or m_syn_x. (ie, 1st syn layer 1->3, remaining layers in 3/x) + // blend result m_syn_3/x into m_syn_2. + // for branch n: + // synth m_syn_1 in place or m_syn_x . (ie, all layers in 1) + // blend result m_syn_1 into m_syn_2. + + frame_memory *p_syn_1 = &m_syn_1; // initial upsample output -- preserved until last branch. + frame_memory *p_syn_2 = &m_syn_2; // initial branch output, blended into by subsequent branches. + frame_memory *p_syn_3 = &m_syn_3; // subsequence branch outputs + frame_memory *p_syn_tmp = &m_syn_tmp; + + if (m_verbosity >= 100) + { + struct cc_bs_frame_header &frame_header = frame_symbols->m_frame_header; + printf("SYNTHESIS INPUTS\n"); + for (int p = 0; p < frame_header.n_latent_n_resolutions; p++) + { + printf("SYNPLANE %d ", p); + m_syn_1.print_start(p, "SYNIN", -1, UPS_PRECISION); + } + } + + frame_memory *result = NULL; + for (int b = 0; b < m_syn_n_branches; b++) + { + if (b == 0 && m_syn_n_branches > 1) + { + result = run_syn_branch(frame_symbols, b, p_syn_1, p_syn_2, p_syn_tmp); + // result is either m_syn_2 or m_syn_tmp. + if (result != p_syn_2) + { + // switch our notion of m_syn_2 and m_syn_tmp. + p_syn_tmp = p_syn_2; + p_syn_2 = result; + } + } + else if (b < m_syn_n_branches-1) + { + result = run_syn_branch(frame_symbols, b, p_syn_1, p_syn_3, p_syn_tmp); + // result is either m_syn_3 or m_syn_tmp. + if (result != p_syn_3) + { + // switch our notion of m_syn_3 and m_syn_tmp. + p_syn_tmp = p_syn_3; + p_syn_3 = result; + } + if (b == 1) + run_syn_blend2(frame_symbols, 3, b, p_syn_3, p_syn_2); + else + run_syn_blend1(frame_symbols, 3, b, p_syn_3, p_syn_2); + } + else + { + // NULL for out implies we can run in-place as much as we like, destroying p_syn_1. + result = run_syn_branch(frame_symbols, b, p_syn_1, NULL, p_syn_tmp); + // result is either m_syn_1 or m_syn_tmp. + if (result != p_syn_1) + { + // switch our notion of m_syn_1 and m_syn_tmp. + p_syn_tmp = p_syn_1; + p_syn_1 = result; + } + if (b == 0) + { + ; // we don't have any blending, single branch. + } + else if (b == 1) + { + // 2 branches being blended to create result. + run_syn_blend2(frame_symbols, 3, b, p_syn_1, p_syn_2); + result = p_syn_2; + } + else + { + // subsequent branch being blended into result. + run_syn_blend1(frame_symbols, 3, b, p_syn_1, p_syn_2); + result = p_syn_2; + } + } + } + + return result; +} + // returned value should be transformed or otherwise copied before calling decode_frame again. struct frame_memory *cc_frame_decoder::decode_frame(struct cc_bs_frame *frame_symbols) { @@ -922,8 +1155,11 @@ struct frame_memory *cc_frame_decoder::decode_frame(struct cc_bs_frame *frame_sy read_ups(frame_symbols); read_syn(frame_symbols); - printf("loaded weights\n"); - fflush(stdout); + if (m_verbosity >= 3) + { + printf("loaded weights\n"); + fflush(stdout); + } check_allocations(frame_symbols); diff --git a/coolchic/cpp/cc-frame-decoder.h b/coolchic/cpp/cc-frame-decoder.h index 51368941..9cc719ef 100644 --- a/coolchic/cpp/cc-frame-decoder.h +++ b/coolchic/cpp/cc-frame-decoder.h @@ -5,20 +5,52 @@ class cc_frame_decoder { public: - cc_frame_decoder(struct cc_bs_gop_header &gop_header) + cc_frame_decoder(struct cc_bs_gop_header &gop_header, int output_bitdepth, int output_chroma_format +#if defined(CCDECAPI_AVX2_OPTIONAL) + , bool use_avx2 +#endif + , int verbosity + ) : m_gop_header(gop_header), + m_output_bitdepth(output_bitdepth), + m_output_chroma_format(output_chroma_format), m_mlpw_t(NULL), m_mlpb(NULL), m_mlp_n_hidden_layers_arm(-1), + m_ups_n(-1), + m_upsw(NULL), + m_ups_n_preconcat(-1), + m_upsw_preconcat(NULL), + m_syn_n_branches(-1), + m_syn_n_layers(-1), m_synw(NULL), - m_synb(NULL), - m_syn_n_layers(-1) - {}; - - ~cc_frame_decoder() {delete[] m_mlpw_t; delete[] m_mlpb; delete[] m_synw; delete[] m_synb;}; + m_synb(NULL) +#if defined(CCDECAPI_AVX2) +#if defined(CCDECAPI_AVX2_OPTIONAL) + ,m_use_avx2(use_avx2) +#else + ,m_use_avx2(true) +#endif +#endif + ,m_verbosity(verbosity) + ,m_arm_pad(0) + ,m_ups_pad(0) + ,m_max_pad(0) + { + if (m_output_bitdepth == 0) + m_output_bitdepth = m_gop_header.bitdepth; + if (m_output_chroma_format == 0) + m_output_chroma_format = m_gop_header.frame_data_type == 1 ? 420 : 444; + }; + + ~cc_frame_decoder() {delete[] m_mlpw_t; delete[] m_mlpb; delete[] m_upsw; delete[] m_upsw_preconcat; delete[] m_synw; delete[] m_synb;}; +public: + int get_syn_idx(int branch_idx, int layer_idx) { return branch_idx*m_syn_n_layers + layer_idx; } public: struct cc_bs_gop_header &m_gop_header; + int m_output_bitdepth; + int m_output_chroma_format; public: struct frame_memory *decode_frame(struct cc_bs_frame *frame_symbols); @@ -30,14 +62,24 @@ class cc_frame_decoder weights_biases m_mlpbOUT; int m_mlp_n_hidden_layers_arm; - weights_biases m_upsw_t_4x4; + int m_ups_n; + weights_biases *m_upsw; + int m_ups_n_preconcat; + weights_biases *m_upsw_preconcat; - weights_biases *m_synw; - weights_biases *m_synb; + int m_syn_n_branches; int m_syn_n_layers; + weights_biases m_syn_blends; // m_syn_n_branches blend values. + weights_biases *m_synw; // [branchidx*m_syn_n_layers+synidx] + weights_biases *m_synb; // [branchidx*m_syn_n_layers+synidx] buffer m_syn3x3_linebuffer; +#if defined(CCDECAPI_AVX2) + bool m_use_avx2; +#endif + int m_verbosity; + private: // for arm decode destination, and upsampling. std::vector m_h_pyramid; @@ -49,8 +91,12 @@ class cc_frame_decoder int m_ups_pad; int m_max_pad; - frame_memory m_ups_1; // !!! really for_syn_0 - frame_memory m_ups_2; // !!! really for_syn_1 + frame_memory m_ups_h2w2; // internal refinement target prior to upsample. + frame_memory m_ups_hw; // during 2-pass refinement and 2-padd upsample. + frame_memory m_syn_1; + frame_memory m_syn_2; + frame_memory m_syn_3; + frame_memory m_syn_tmp; private: // set by arm to indicate no content, avoid ups for no content. @@ -65,6 +111,9 @@ class cc_frame_decoder void run_arm(struct cc_bs_frame *frame_symbols); void run_ups(struct cc_bs_frame *frame_symbols); + frame_memory *run_syn_blend1(struct cc_bs_frame *frame_symbols, int n_planes, int branch_no, frame_memory *syn_in, frame_memory *syn_out); + frame_memory *run_syn_blend2(struct cc_bs_frame *frame_symbols, int n_planes, int branch_no, frame_memory *syn_in, frame_memory *syn_out); + frame_memory *run_syn_branch(struct cc_bs_frame *frame_symbols, int branch_no, frame_memory *syn_in, frame_memory *syn_out, frame_memory *syn_tmp); frame_memory *run_syn(struct cc_bs_frame *frame_symbols); }; diff --git a/coolchic/cpp/ccdecapi.hpp b/coolchic/cpp/ccdecapi.cpp similarity index 77% rename from coolchic/cpp/ccdecapi.hpp rename to coolchic/cpp/ccdecapi.cpp index 905d9610..e754766c 100644 --- a/coolchic/cpp/ccdecapi.hpp +++ b/coolchic/cpp/ccdecapi.cpp @@ -4,8 +4,14 @@ * #define CCDECAPI_CPU * or * #define CCDECAPI_AVX2 + * or + * #define CCDECAPI_AVX2_OPTIONAL */ +#if defined(CCDECAPI_AVX2_OPTIONAL) +#define CCDECAPI_AVX2 +#endif + #include #include #include @@ -28,55 +34,97 @@ float time_bac_seconds = 0.0; float time_arm_seconds = 0.0; float time_ups_seconds = 0.0; float time_syn_seconds = 0.0; +float time_blend_seconds = 0.0; float time_warp_seconds = 0.0; float time_bpred_seconds = 0.0; float time_all_seconds = 0.0; +// like std::byteswap, but that's only in c++23 onwards! +inline unsigned short byteswap(unsigned short x) +{ + unsigned short hi = x&0xFF00; + unsigned short lo = x&0x00FF; + return (lo<<8)|(hi>>8); +} + +// like ends_with in c++20 and onwards. +inline bool ends_with(const std::string& a, const std::string& b) +{ + if (b.size() > a.size()) + return false; + return std::equal(a.begin() + a.size() - b.size(), a.end(), b.begin()); +} + // only use nc == 3 -void ppm_out(int nc, int pixel_depth, struct frame_memory &in_info, char const *outname) +void ppm_out(int nc, int bit_depth, struct frame_memory &in_info, FILE *fout) { int const h = in_info.h; int const w = in_info.w; int const stride = in_info.stride; int const plane_stride = in_info.plane_stride; + int const max_sample_val = (1<> SYN_LAYER_PRECISION; - int g = ((*inG++)*255+(1<<(SYN_LAYER_PRECISION-1))) >> SYN_LAYER_PRECISION; - int b = ((*inB++)*255+(1<<(SYN_LAYER_PRECISION-1))) >> SYN_LAYER_PRECISION; + int r = ((*inR++)*max_sample_val+(1<<(SYN_LAYER_PRECISION-1))) >> SYN_LAYER_PRECISION; + int g = ((*inG++)*max_sample_val+(1<<(SYN_LAYER_PRECISION-1))) >> SYN_LAYER_PRECISION; + int b = ((*inB++)*max_sample_val+(1<<(SYN_LAYER_PRECISION-1))) >> SYN_LAYER_PRECISION; if (r < 0) r = 0; if (g < 0) g = 0; if (b < 0) b = 0; - if (r > 255) r = 255; - if (g > 255) g = 255; - if (b > 255) b = 255; + if (r > max_sample_val) r = max_sample_val; + if (g > max_sample_val) g = max_sample_val; + if (b > max_sample_val) b = max_sample_val; unsigned char pix[3]; pix[0] = r; pix[1] = g; pix[2] = b; - fwrite(pix, 3, 1, fout); + fwrite(pix, 3, sizeof(pix[0]), fout); + } + } + else + { + for (int y = 0; y < h; y++, inR += stride-w, inG += stride-w, inB += stride-w) + for (int x = 0; x < w; x++) + { + // precision is SYN_LAYER_PRECISION + int r = ((*inR++)*max_sample_val+(1<<(SYN_LAYER_PRECISION-1))) >> SYN_LAYER_PRECISION; + int g = ((*inG++)*max_sample_val+(1<<(SYN_LAYER_PRECISION-1))) >> SYN_LAYER_PRECISION; + int b = ((*inB++)*max_sample_val+(1<<(SYN_LAYER_PRECISION-1))) >> SYN_LAYER_PRECISION; + if (r < 0) r = 0; + if (g < 0) g = 0; + if (b < 0) b = 0; + if (r > max_sample_val) r = max_sample_val; + if (g > max_sample_val) g = max_sample_val; + if (b > max_sample_val) b = max_sample_val; + unsigned short pix[3]; + pix[0] = byteswap(r); + pix[1] = byteswap(g); + pix[2] = byteswap(b); + fwrite(pix, 3, sizeof(pix[0]), fout); } } - fclose(fout); + fflush(fout); } // we 'stabilise' the yuv to (0..255) as integers, as well as uv subsample. @@ -294,7 +342,7 @@ unsigned short *get_raw_444_10b(struct frame_memory &in) // incoming 420 is planar, outgoing 444 is planar. void convert_420_444_8b(frame_memory &out, unsigned char *in, int h, int w) { - out.update_to(h, w, 3, 0); + out.update_to(h, w, 0, 3); unsigned char *src = in; int32_t *dst = out.plane_origin(0); @@ -329,7 +377,7 @@ void convert_420_444_8b(frame_memory &out, unsigned char *in, int h, int w) // incoming 420 is planar, outgoing 444 is planar. void convert_420_444_10b(frame_memory &out, unsigned short *in, int h, int w) { - out.update_to(h, w, 3, 0); + out.update_to(h, w, 0, 3); unsigned short *src = in; int32_t *dst = out.plane_origin(0); @@ -397,7 +445,7 @@ void dump_yuv444_10b(unsigned short *frame_444, int h, int w, FILE *fout, int fr // incoming 444 is planar, outgoing 444 is planar. void store_444_8b(frame_memory &out, unsigned char *in, int h, int w) { - out.update_to(h, w, 3, 0); + out.update_to(h, w, 0, 3); unsigned char *src = in; int32_t *dst = out.plane_origin(0); @@ -421,7 +469,7 @@ void store_444_8b(frame_memory &out, unsigned char *in, int h, int w) // incoming 444 is planar, outgoing 444 is planar. void store_444_10b(frame_memory &out, unsigned short *in, int h, int w) { - out.update_to(h, w, 3, 0); + out.update_to(h, w, 0, 3); unsigned short *src = in; int32_t *dst = out.plane_origin(0); @@ -449,7 +497,7 @@ void warp(struct frame_memory &warp_result, struct frame_memory &raw_info, int r { const auto time_warp_start = std::chrono::steady_clock::now(); - warp_result.update_to(ref.h, ref.w, 3, 0); + warp_result.update_to(ref.h, ref.w, 0, 3); int32_t *src = ref.origin(); int const src_stride = ref.stride; @@ -560,7 +608,7 @@ void bpred(struct frame_memory &bpred_result, struct frame_memory &raw_info, int int raw_stride = raw_info.stride; int raw_plane_stride = raw_info.plane_stride; - bpred_result.update_to(h, w, 3, 0); + bpred_result.update_to(h, w, 0, 3); int32_t *out = bpred_result.origin(); int32_t *src0 = pred0.origin(); @@ -622,14 +670,14 @@ void process_inter(struct frame_memory &pred, struct frame_memory &raw_cc_output } } -#ifdef CCDECAPI_CPU -int cc_decode_cpu(std::string &bitstream_filename, std::string &out_filename) -#else -#ifdef CCDECAPI_AVX2 -int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename) +#if defined(CCDECAPI_CPU) +int cc_decode_cpu(std::string &bitstream_filename, std::string &out_filename, int output_bitdepth, int output_chroma_format, int verbosity) +#elif defined(CCDECAPI_AVX2_OPTIONAL) +int cc_decode_avx2_optional(std::string &bitstream_filename, std::string &out_filename, int output_bitdepth, int output_chroma_format, bool use_avx2, int verbosity) +#elif defined(CCDECAPI_AVX2) +int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename, int output_bitdepth, int output_chroma_format, int verbosity) #else -#error must have one of CCDECAPI_CPU or CCDECAPI_AVX2 defined. -#endif +#error must have one of CCDECAPI_CPU, CCDECAPI_AVX2 or CCDECAPI_AVX2_OPTIONAL defined. #endif { if (bitstream_filename == "") @@ -641,7 +689,7 @@ int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename) const auto time_all_start = std::chrono::steady_clock::now(); cc_bs bs; - if (!bs.open(bitstream_filename)) + if (!bs.open(bitstream_filename, verbosity)) { printf("cannot open %s for reading\n", bitstream_filename.c_str()); return 1; @@ -661,7 +709,11 @@ int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename) } } - struct cc_frame_decoder frame_decoder(gop_header); +#if defined(CCDECAPI_AVX2_OPTIONAL) + struct cc_frame_decoder frame_decoder(gop_header, output_bitdepth, output_chroma_format, use_avx2, verbosity); +#else + struct cc_frame_decoder frame_decoder(gop_header, output_bitdepth, output_chroma_format, verbosity); +#endif for (int frame_coding_idx = 0; frame_coding_idx <= gop_header.intra_period; frame_coding_idx++) { // raw_cc_output either @@ -670,7 +722,7 @@ int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename) // [9] (B: residue+xy+alpha+xy+beta) // from bitstream. - cc_bs_frame *frame_symbols = bs.decode_frame(); + cc_bs_frame *frame_symbols = bs.decode_frame(verbosity); if (frame_symbols == NULL) { return 1; @@ -706,7 +758,7 @@ int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename) if (frames_444[idx].raw() != NULL) { ref_prev = &frames_444[idx]; - printf("refprev: %d\n", idx); + //printf("refprev: %d\n", idx); break; } // find next if B. @@ -715,7 +767,7 @@ int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename) if (frames_444[idx].raw() != NULL) { ref_next = &frames_444[idx]; - printf("refnext: %d\n", idx); + //printf("refnext: %d\n", idx); break; } process_inter(frame_444, *raw_cc_output, gop_header.img_h, gop_header.img_w, ref_prev, ref_next, frame_header.flow_gain); @@ -723,9 +775,9 @@ int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename) } // YUV 420 - if (gop_header.frame_data_type == 1) + if (ends_with(out_filename, ".yuv") && frame_decoder.m_output_chroma_format == 420) { - if (gop_header.bitdepth == 8) + if (frame_decoder.m_output_bitdepth == 8) { unsigned char *frame_420 = convert_444_420_8b(*frame_444p); if (out_filename != "") @@ -733,7 +785,7 @@ int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename) convert_420_444_8b(frames_444[frame_header.display_index], frame_420, gop_header.img_h, gop_header.img_w); delete[] frame_420; } - else if (gop_header.bitdepth == 10) + else if (frame_decoder.m_output_bitdepth == 10) { unsigned short *frame_420 = convert_444_420_10b(*frame_444p); if (out_filename != "") @@ -744,13 +796,14 @@ int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename) } else { - printf("Unkown YUV bitdepth %d. Should be 8 or 10.", gop_header.bitdepth); + printf("Unknown YUV bitdepth %d. Should be 8 or 10.\n", frame_decoder.m_output_bitdepth); + exit(1); } } // YUV 444 - else if (gop_header.frame_data_type == 2) + else if (ends_with(out_filename, ".yuv") && frame_decoder.m_output_chroma_format == 444) { - if (gop_header.bitdepth == 8) + if (frame_decoder.m_output_bitdepth == 8) { unsigned char *raw_frame_444 = get_raw_444_8b(*frame_444p); if (out_filename != "") @@ -758,7 +811,7 @@ int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename) store_444_8b(frames_444[frame_header.display_index], raw_frame_444, gop_header.img_h, gop_header.img_w); } - else if (gop_header.bitdepth == 10) + else if (frame_decoder.m_output_bitdepth == 10) { unsigned short *raw_frame_444 = get_raw_444_10b(*frame_444p); if (out_filename != "") @@ -767,15 +820,15 @@ int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename) } else { - printf("Unkown YUV bitdepth %d. Should be 8 or 10.", gop_header.bitdepth); + printf("Unknown YUV bitdepth %d. Should be 8 or 10.\n", frame_decoder.m_output_bitdepth); + exit(1); } } - else { // rgb if (out_filename != "") - ppm_out(3, 3, *frame_444p, out_filename.c_str()); + ppm_out(3, frame_decoder.m_output_bitdepth, *frame_444p, fout); if (gop_header.intra_period > 0) { printf("do not want to copy rgb in rgb video\n"); @@ -788,23 +841,39 @@ int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename) const std::chrono::duration elapsed_all = (time_all_end-time_all_start); time_all_seconds = (float)elapsed_all.count(); - printf("time: arm %g ups %g syn %g warp %g bpred %g all %g\n", - time_arm_seconds, time_ups_seconds, time_syn_seconds, time_warp_seconds, time_bpred_seconds, time_all_seconds); + if (verbosity >= 1) + printf("time: arm %g ups %g syn %g blend %g warp %g bpred %g all %g\n", + time_arm_seconds, time_ups_seconds, time_syn_seconds, time_blend_seconds, time_warp_seconds, time_bpred_seconds, time_all_seconds); fflush(stdout); } if (fout != NULL) + { fclose(fout); - - printf("decode done\n"); + printf("%s created\n", out_filename.c_str()); + } return 0; } -#if 0 +#ifdef CCDEC_EXE + +// A main() and associated parameters for standalone executable. +#ifndef CCDECAPI_AVX2_OPTIONAL +// we are expecting 'optional' -- we have a run-time parameter for choosing cpu or avx2 or auto. +#error CCDEC_EXE needs CCDECAPI_AVX2_OPTIONAL +no. +#endif + char const *param_bitstream = "--input="; char const *param_out = "--output="; +char const *param_bitdepth = "--output_bitdepth="; +char const *param_chroma = "--output_chroma_format="; +char const *param_cpu = "--cpu"; +char const *param_auto = "--auto"; +char const *param_avx2 = "--avx2"; +char const *param_v = "--v="; void usage(char const *msg = NULL) { @@ -812,7 +881,11 @@ void usage(char const *msg = NULL) printf("%s\n", msg); printf("Usage:\n"); printf(" %s: .hevc to decode\n", param_bitstream); - printf(" %s: reconstruction if desired: ppm (image) or yuv420 (video) only\n", param_out); + printf(" [%s]: reconstruction if desired: ppm (image) or yuv420 (video) only\n", param_out); + printf(" [%s]: output bitdepth 0 => take from bitstream\n", param_bitdepth); + printf(" [%s]: output chroma subsampling 420 or 444 for yuv; 0 => take from bitstream\n", param_chroma); + printf(" [%s|%s|%s]: optimized instruction set with which to decode\n", param_cpu, param_avx2, param_auto); + printf(" [%s]: verbosity\n", param_v); exit(msg != NULL); } @@ -821,6 +894,12 @@ int main(int argc, const char* argv[]) std::string bitstream_filename; std::string out = "out"; + int output_bitdepth = 0; + int output_chroma_format = 0; + bool explicit_instruction_set = false; + bool use_avx2 = false; + bool use_auto = false; + int verbosity = 0; for (int i = 1; i < argc; i++) { @@ -828,6 +907,27 @@ int main(int argc, const char* argv[]) bitstream_filename = argv[i]+strlen(param_bitstream); else if (strncmp(argv[i], param_out, strlen(param_out)) == 0) out = argv[i]+strlen(param_out); + else if (strncmp(argv[i], param_bitdepth, strlen(param_bitdepth)) == 0) + output_bitdepth = atoi(argv[i]+strlen(param_bitdepth)); + else if (strncmp(argv[i], param_chroma, strlen(param_chroma)) == 0) + output_chroma_format = atoi(argv[i]+strlen(param_chroma)); + else if (strncmp(argv[i], param_avx2, strlen(param_avx2)) == 0 + || strncmp(argv[i], param_cpu, strlen(param_cpu)) == 0 + || strncmp(argv[i], param_auto, strlen(param_auto)) == 0) + { + if (explicit_instruction_set) + { + printf("%s: only a single --auto, --avx2 or --cpu\n", argv[i]); + exit(1); + } + explicit_instruction_set = true; + use_avx2 = strncmp(argv[i], param_avx2, strlen(param_avx2)) == 0; + use_auto = strncmp(argv[i], param_auto, strlen(param_auto)) == 0; + } + else if (strncmp(argv[i], param_v, strlen(param_v)) == 0) + { + verbosity = atoi(argv[i]+strlen(param_v)); + } else { std::string error_message = "unknown parameter "; @@ -839,12 +939,16 @@ int main(int argc, const char* argv[]) if (bitstream_filename == "") usage("must specify a bitstream"); -#ifdef CCDECAPI_CPU - int result = cc_decode_cpu(bitstream_filename, out); -#else - int result = cc_decode_avx2(bitstream_filename, out); -#endif - printf("decode exit code: %d\n", result); + if (!explicit_instruction_set) + use_auto = true; + if (use_auto) + { + if (__builtin_cpu_supports("avx2")) + { + use_avx2 = true; + } + } + int result = cc_decode_avx2_optional(bitstream_filename, out, output_bitdepth, output_chroma_format, use_avx2, verbosity); return result; } #endif diff --git a/coolchic/cpp/ccdecapi_avx2.cpp b/coolchic/cpp/ccdecapi_avx2.cpp index adb0f078..9f9fd3b6 100644 --- a/coolchic/cpp/ccdecapi_avx2.cpp +++ b/coolchic/cpp/ccdecapi_avx2.cpp @@ -19,13 +19,18 @@ namespace py = pybind11; // encode latents layer to a file. int cc_decode_avx2( std::string &in_bitstream_filename, - std::string &out_ppm_filename); + std::string &out_ppm_filename, + int output_bitdepth = 0, + int output_chroma_format = 0, + int verbosity = 0); PYBIND11_MODULE(ccdecapi_avx2, m) { m.doc() = "ccdecoding"; // optional module docstring m.def("cc_decode_avx2", &cc_decode_avx2, "decode a bitstream"); } +#ifndef CCDECAPI_AVX2 #define CCDECAPI_AVX2 -#include "ccdecapi.hpp" +#endif +#include "ccdecapi.cpp" #undef CCDECAPI_AVX2 diff --git a/coolchic/cpp/ccdecapi_cpu.cpp b/coolchic/cpp/ccdecapi_cpu.cpp index 918c8c44..e5a7bbca 100644 --- a/coolchic/cpp/ccdecapi_cpu.cpp +++ b/coolchic/cpp/ccdecapi_cpu.cpp @@ -19,13 +19,18 @@ namespace py = pybind11; // encode latents layer to a file. int cc_decode_cpu( std::string &in_bitstream_filename, - std::string &out_ppm_filename); + std::string &out_ppm_filename, + int output_bitdepth = 0, + int output_chroma_format = 0, + int verbosity = 0); PYBIND11_MODULE(ccdecapi_cpu, m) { m.doc() = "ccdecoding"; // optional module docstring m.def("cc_decode_cpu", &cc_decode_cpu, "decode a bitstream"); } +#ifndef CCDECAPI_CPU #define CCDECAPI_CPU -#include "ccdecapi.hpp" +#endif +#include "ccdecapi.cpp" #undef CCDECAPI_CPU diff --git a/coolchic/cpp/ccencapi.cpp b/coolchic/cpp/ccencapi.cpp index 08272ec0..ce255184 100644 --- a/coolchic/cpp/ccencapi.cpp +++ b/coolchic/cpp/ccencapi.cpp @@ -96,6 +96,7 @@ void code_val(TEncBinCABAC &layer_BAC, MuSigGTs *coding_ctxs, int val_to_code) // we return the best index. int cc_code_wb_bac(std::string &out_name, std::vector &xs, int use_count) { +#if 0 // !!! check for all zero, emit empty file. bool all_zero = true; for (int i = 0; i < (int)xs.size(); i++) @@ -106,10 +107,11 @@ int cc_code_wb_bac(std::string &out_name, std::vector &xs, int use_count) break; } } - if (all_zero) - { - printf("all weights/biases zero -- think of empty file\n"); - } + //if (all_zero) + //{ + // printf("all weights/biases zero -- think of empty file\n"); + //} +#endif TEncBinCABAC layer_BAC; @@ -120,6 +122,16 @@ int cc_code_wb_bac(std::string &out_name, std::vector &xs, int use_count) int test_max = 12; if (use_count >= 0) test_min = test_max = use_count; +#if 0 + if (1) // xs.size() == 5 || xs.size() == 8) + { + printf("encoding weight/bias vector %d to %s:", xs.size(), out_name.c_str()); + for (int i = 0; i < xs.size(); i++) + printf(" %d", xs[i]); + printf("\n"); + fflush(stdout); + } +#endif for (int exgolomb_count = test_min; exgolomb_count <= test_max; exgolomb_count++) { //auto layer_BAC = CABACEncoder(); @@ -141,7 +153,7 @@ int cc_code_wb_bac(std::string &out_name, std::vector &xs, int use_count) { best_exgolomb_count = exgolomb_count; best_exgolomb_bytes = bsBAC.getFifo(); - printf("better exgolomb bytes %d at count=%d\n", (int)best_exgolomb_bytes.size(), best_exgolomb_count); + //printf("better exgolomb bytes %d at count=%d\n", (int)best_exgolomb_bytes.size(), best_exgolomb_count); } } @@ -158,7 +170,7 @@ int cc_code_wb_bac(std::string &out_name, std::vector &xs, int use_count) exit(1); } fclose(fout); - printf("%s created\n", out_name.c_str()); + //printf("%s created\n", out_name.c_str()); // return best sig index used for coding return best_exgolomb_count; @@ -172,8 +184,6 @@ void cc_code_latent_layer_bac( int layer_height, int layer_width, int hls_sig_blksize) { - printf("called cc_code_latent_layer_bac: file=%s\n", out_name.c_str()); - // get significant blocks. bool hls_sig_update = hls_sig_blksize < 0; if (hls_sig_update) @@ -245,7 +255,7 @@ void cc_code_latent_layer_bac( // want to bother? For significant blocks, we now say no. // SIG - printf("nz %d vs %d\n", n_zero, nby*nbx); + // printf("nz %d vs %d\n", n_zero, nby*nbx); //if (n_zero <= nby*nbx/20) if (1) // no longer use significance blocks. { @@ -259,7 +269,7 @@ void cc_code_latent_layer_bac( else { // signal block significance. - printf("sig1 %s\n", hls_sig_update ? "(update)" : "(noupdate)"); + // printf("sig1 %s\n", hls_sig_update ? "(update)" : "(noupdate)"); layer_BAC.encodeBinEP(1); auto ctx = BinProbModel_Std(PROBA_50_STATE); for (int by = 0; by < nby; by++) @@ -281,7 +291,7 @@ void cc_code_latent_layer_bac( // FLAT? - printf("nflat %d vs %d\n", n_flat, nby*nbx); + // printf("nflat %d vs %d\n", n_flat, nby*nbx); if (n_flat <= nby*nbx/20) { layer_BAC.encodeBinEP(0); @@ -292,7 +302,7 @@ void cc_code_latent_layer_bac( layer_BAC.encodeBinEP(1); // signal flat for sig blocks. auto ctx = BinProbModel_Std(PROBA_50_STATE); - printf("flat1\n"); + // printf("flat1\n"); for (int by = 0; by < nby; by++) { for (int bx = 0; bx < nbx; bx++) @@ -367,7 +377,7 @@ void cc_code_latent_layer_bac( exit(1); } fclose(fout); - printf("%s created\n", out_name.c_str()); + // printf("%s created\n", out_name.c_str()); delete[] blk_sig; delete[] blk_flat; diff --git a/coolchic/cpp/frame-memory.cpp b/coolchic/cpp/frame-memory.cpp index 9428c818..6861eb0e 100644 --- a/coolchic/cpp/frame-memory.cpp +++ b/coolchic/cpp/frame-memory.cpp @@ -1,24 +1,24 @@ #include "frame-memory.h" -void frame_memory::custom_pad_replicate_plane_in_place_i(int plane, int pad) +void frame_memory::custom_pad_replicate_plane_in_place_i(int plane, int padlr, int padtb) { int32_t *scan_in = plane_origin(plane); - int32_t *scan_out = scan_in - pad*stride-pad; - int w_in_padded = w+2*pad; + int32_t *scan_out = scan_in - padtb*stride-padlr; + int w_in_padded = w+2*padlr; // leading lines - for (int y = 0; y < pad; y++) + for (int y = 0; y < padtb; y++) { if (y == 0) { // construct first padded line. - for (int x = 0; x < pad; x++) + for (int x = 0; x < padlr; x++) { *scan_out++ = *scan_in; } memcpy(scan_out, scan_in, w*sizeof(scan_out[0])); scan_out += w; - for (int x = 0; x < pad; x++) + for (int x = 0; x < padlr; x++) { *scan_out++ = scan_in[w-1]; } @@ -31,31 +31,74 @@ void frame_memory::custom_pad_replicate_plane_in_place_i(int plane, int pad) scan_out += stride; } } + // internal lines: pad left and right edges. - for (int y = 0; y < h; y++) + if (padlr > 0) { - scan_out = scan_in-pad; - for (int x = 0; x < pad; x++) - { - *scan_out++ = *scan_in; - } - scan_out += w; - for (int x = 0; x < pad; x++) + for (int y = 0; y < h; y++) { - *scan_out++ = scan_in[w-1]; + scan_out = scan_in-padlr; + for (int x = 0; x < padlr; x++) + { + *scan_out++ = *scan_in; + } + scan_out += w; + for (int x = 0; x < padlr; x++) + { + *scan_out++ = scan_in[w-1]; + } + scan_in += stride; + scan_out += stride-w_in_padded; } - scan_in += stride; + } + else + { + scan_in += h*stride; + scan_out += h*stride; } // trailing lines -- we copy the previous padded. scan_in -= stride; // beginning of last line, padding to the left. - scan_out = scan_in+stride-pad; - for (int y = 0; y < pad; y++) + scan_out = scan_in+stride-padlr; + for (int y = 0; y < padtb; y++) { memcpy(scan_out, scan_out-stride, w_in_padded*sizeof(scan_out[0])); scan_out += stride; } } +void frame_memory::custom_pad_zero_plane_in_place_i(int plane, int padlr, int padtb) +{ + int32_t *scan_out = plane_origin(plane) - padtb*stride-padlr; + int w_in_padded = w+2*padlr; + // leading lines + for (int y = 0; y < padtb; y++) + { + memset(scan_out, 0, w_in_padded*sizeof(scan_out[0])); + scan_out += stride; + } + + // internal lines: pad left and right edges. + if (padlr > 0) + { + for (int y = 0; y < h; y++) + { + memset(scan_out, 0, padlr*sizeof(scan_out[0])); + memset(scan_out+padlr+w, 0, padlr*sizeof(scan_out[0])); + scan_out += stride; + } + } + else + { + scan_out += h*stride; + } + + for (int y = 0; y < padtb; y++) + { + memset(scan_out, 0, w_in_padded*sizeof(scan_out[0])); + scan_out += stride; + } +} + void frame_memory::zero_pad(int plane, int pad) { // !!! just the padding. diff --git a/coolchic/cpp/frame-memory.h b/coolchic/cpp/frame-memory.h index 0b55c1a3..f0308dea 100644 --- a/coolchic/cpp/frame-memory.h +++ b/coolchic/cpp/frame-memory.h @@ -7,6 +7,8 @@ struct frame_memory int origin_idx; int h; int w; + int pad; + int planes; int stride; int plane_stride; public: @@ -14,15 +16,19 @@ struct frame_memory origin_idx(0), h(0), w(0), + pad(0), + planes(0), stride(0), plane_stride(0) {}; ~frame_memory() {}; public: - void update_to(int frame_h, int frame_w, int frame_planes, int frame_pad) + void update_to(int frame_h, int frame_w, int frame_pad, int frame_planes) { h = frame_h; w = frame_w; + pad = frame_pad; + planes = frame_planes; stride = frame_pad+frame_w+frame_pad; plane_stride = (frame_pad+frame_h+frame_pad)*stride; origin_idx = frame_pad*stride+frame_pad; @@ -30,9 +36,10 @@ struct frame_memory } int32_t *raw() { return raw_frame.data; } int32_t *origin() { return raw_frame.data+origin_idx; } - int32_t *pad_origin(int pad) { return raw_frame.data+origin_idx - pad*stride - pad; } int32_t *plane_origin(int plane = 0) { return origin()+plane*plane_stride; } - int32_t *pad_origin(int plane, int pad) { return raw_frame.data + origin_idx-pad*stride-pad + plane*plane_stride; } + int32_t *plane_pixel(int plane, int y, int x) { return plane_origin(plane)+y*stride+x; } + int32_t *pad_origin(int plane, int pad) { return pad_origin(plane, pad, pad); } + int32_t *pad_origin(int plane, int padlr, int padtb) { return raw_frame.data + origin_idx-padtb*stride-padlr + plane*plane_stride; } void print_ranges(char const *msg, int nplanes, int precision) { int minval = plane_origin(0)[0]; @@ -48,9 +55,22 @@ struct frame_memory } printf("%s: minval=%g maxval=%g\n", msg == NULL ? "" : msg, (float)minval/(1< 0 ? n : h, n > 0 ? n : w, msg == NULL ? "" : msg); + for (int y = 0; y < (n > 0 ? n : h); y++) + { + for (int x = 0; x < (n > 0 ? n : w); x++) + printf(" %g", plane_origin(plane)[y*stride+x]/((1< - -#include "TDecBinCoderCABAC.h" -#include "cc-contexts.h" -#include "common.h" -#include "rest_avx2.h" - -// we have timing stuff here as well. -#include // timing. - -// FLOAT -float hsum_ps_sse3(__m128 v) { - __m128 shuf = _mm_movehdup_ps(v); // broadcast elements 3,1 to 2,0 - __m128 sums = _mm_add_ps(v, shuf); - shuf = _mm_movehl_ps(shuf, sums); // high half -> low half - sums = _mm_add_ss(sums, shuf); - return _mm_cvtss_f32(sums); -} - -float hsum256_ps_avx(__m256 v) { - __m128 vlow = _mm256_castps256_ps128(v); - __m128 vhigh = _mm256_extractf128_ps(v, 1); // high 128 - vlow = _mm_add_ps(vlow, vhigh); // add the low 128 - return hsum_ps_sse3(vlow); // and inline the sse3 version, which is optimal for AVX - // (no wasted instructions, and all of them are the 4B minimum) -} - -// INT -uint32_t hsum_epi32_avx(__m128i x) -{ - __m128i hi64 = _mm_unpackhi_epi64(x, x); // 3-operand non-destructive AVX lets us save a byte without needing a movdqa - __m128i sum64 = _mm_add_epi32(hi64, x); - __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); // Swap the low two elements - __m128i sum32 = _mm_add_epi32(sum64, hi32); - return _mm_cvtsi128_si32(sum32); // movd -} - -// only needs AVX2 -uint32_t hsum_8x32(__m256i v) -{ - __m128i sum128 = _mm_add_epi32( - _mm256_castsi256_si128(v), - _mm256_extracti128_si256(v, 1)); // silly GCC uses a longer AXV512VL instruction if AVX512 is enabled :/ - return hsum_epi32_avx(sum128); -} - -// we are applying 4 4x4 filters on a non-expanded source. -// -void custom_conv_ups_4x41_avx2(float *kw, int h_in, int w_in, float *in, float *out) -{ - int const ks = 4; - int const stride = 1; - int offs0 = 0; - - int32_t indexes[] = { 0, 1, 2, 3, w_in, w_in+1, w_in+2, w_in+3 }; - __m256i ind_intel = _mm256_loadu_si256((__m256i *)&indexes[0]); - - __m256 kernel_v0h0_top = _mm256_loadu_ps(kw+(0*2+0)*(ks*ks)+0); - __m256 kernel_v0h0_bot = _mm256_loadu_ps(kw+(0*2+0)*(ks*ks)+8); - __m256 kernel_v0h1_top = _mm256_loadu_ps(kw+(0*2+1)*(ks*ks)+0); - __m256 kernel_v0h1_bot = _mm256_loadu_ps(kw+(0*2+1)*(ks*ks)+8); - - __m256 kernel_v1h0_top = _mm256_loadu_ps(kw+(1*2+0)*(ks*ks)+0); - __m256 kernel_v1h0_bot = _mm256_loadu_ps(kw+(1*2+0)*(ks*ks)+8); - __m256 kernel_v1h1_top = _mm256_loadu_ps(kw+(1*2+1)*(ks*ks)+0); - __m256 kernel_v1h1_bot = _mm256_loadu_ps(kw+(1*2+1)*(ks*ks)+8); - - - for (int y = 0; y < h_in-ks+1; y += stride, offs0 += w_in) - { - int offs = offs0; - float *out_next = out+2*(w_in-ks+1); - - for (int x = 0; x < w_in-ks+1; x += stride, offs += stride) - { - __m256 in_0 = _mm256_i32gather_ps(&in[offs+0*w_in], ind_intel, sizeof(float)); // 8 32-bit floats - __m256 in_1 = _mm256_i32gather_ps(&in[offs+2*w_in], ind_intel, sizeof(float)); // 8 32-bit floats - - __m256 sum_0 = kernel_v0h0_top*in_0; - sum_0 += kernel_v0h0_bot*in_1; - __m256 sum_1 = kernel_v0h1_top*in_0; - sum_1 += kernel_v0h1_bot*in_1; - - *out++ = hsum256_ps_avx(sum_0); - *out++ = hsum256_ps_avx(sum_1); - - sum_0 = kernel_v1h0_top*in_0; - sum_0 += kernel_v1h0_bot*in_1; - sum_1 = kernel_v1h1_top*in_0; - sum_1 += kernel_v1h1_bot*in_1; - - *out_next++ = hsum256_ps_avx(sum_0); - *out_next++ = hsum256_ps_avx(sum_1); - } - - // we've done two output lines. - out = out_next; - } -} diff --git a/coolchic/cpp/rest_avx2.h b/coolchic/cpp/rest_avx2.h deleted file mode 100644 index 01e7db8a..00000000 --- a/coolchic/cpp/rest_avx2.h +++ /dev/null @@ -1,11 +0,0 @@ -/* - Software Name: Cool-Chic - SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange - SPDX-License-Identifier: BSD 3-Clause "New" - - This software is distributed under the BSD-3-Clause license. - Authors: see CONTRIBUTORS.md -*/ - - -void custom_conv_ups_4x41_avx2(float *kw, int h_in, int w_in, float *in, float *out); diff --git a/coolchic/cpp/syn_avx2.cpp b/coolchic/cpp/syn_avx2.cpp index afea11e5..276e9edc 100644 --- a/coolchic/cpp/syn_avx2.cpp +++ b/coolchic/cpp/syn_avx2.cpp @@ -99,6 +99,13 @@ #define SYN_NAME custom_conv_ksX_inX_outX_avx2 #include "syn_avx2.hpp" +#define SYN_NAME custom_conv_ks1_in7_hidden48_out3_avx2 +#define SYN_KS 1 +#define SYN_N_IN 7 +#define SYN_N_HIDDEN 48 +#define SYN_N_OUT 3 +#include "synfused_avx2.hpp" + #define SYN_NAME custom_conv_ks1_in7_hidden40_out3_avx2 #define SYN_KS 1 #define SYN_N_IN 7 @@ -120,6 +127,13 @@ #define SYN_N_OUT 9 #include "synfused_avx2.hpp" +#define SYN_NAME custom_conv_ks1_in7_hidden32_out3_avx2 +#define SYN_KS 1 +#define SYN_N_IN 7 +#define SYN_N_HIDDEN 32 +#define SYN_N_OUT 3 +#include "synfused_avx2.hpp" + #define SYN_NAME custom_conv_ks1_in7_hidden16_out3_avx2 #define SYN_KS 1 #define SYN_N_IN 7 @@ -161,3 +175,5 @@ #define SYN_NAME custom_conv_ks3_inX_outX_lb_avx2 #define SYN_KS 3 #include "synlb_avx2.hpp" + +#include "synblend_avx2.hpp" diff --git a/coolchic/cpp/syn_avx2.h b/coolchic/cpp/syn_avx2.h index 15e0cf3e..4b217a92 100644 --- a/coolchic/cpp/syn_avx2.h +++ b/coolchic/cpp/syn_avx2.h @@ -22,9 +22,11 @@ void custom_conv_ks1_inX_out9_avx2(int KS, int32_t *kw, int32_t *kb, int h_in, i void custom_conv_ksX_inX_outX_avx2(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_in, int plane_stride_in, int residue_origin_offset, int N_IN, int32_t *in, int N_OUT, int32_t *out, int residue, int relu); +void custom_conv_ks1_in7_hidden48_out3_avx2(int KS, int32_t *kw7_40, int32_t *kb40, int32_t *kw40_3, int32_t *kb3, int h_in, int w_in, int pad_in, int plane_stride_in, int N_IN, int N_HIDDEN, int32_t *in, int N_OUT, int32_t *out); void custom_conv_ks1_in7_hidden40_out3_avx2(int KS, int32_t *kw7_40, int32_t *kb40, int32_t *kw40_3, int32_t *kb3, int h_in, int w_in, int pad_in, int plane_stride_in, int N_IN, int N_HIDDEN, int32_t *in, int N_OUT, int32_t *out); void custom_conv_ks1_in7_hidden40_out6_avx2(int KS, int32_t *kw7_40, int32_t *kb40, int32_t *kw40_3, int32_t *kb3, int h_in, int w_in, int pad_in, int plane_stride_in, int N_IN, int N_HIDDEN, int32_t *in, int N_OUT, int32_t *out); void custom_conv_ks1_in7_hidden40_out9_avx2(int KS, int32_t *kw7_40, int32_t *kb40, int32_t *kw40_3, int32_t *kb3, int h_in, int w_in, int pad_in, int plane_stride_in, int N_IN, int N_HIDDEN, int32_t *in, int N_OUT, int32_t *out); +void custom_conv_ks1_in7_hidden32_out3_avx2(int KS, int32_t *kw7_40, int32_t *kb40, int32_t *kw40_3, int32_t *kb3, int h_in, int w_in, int pad_in, int plane_stride_in, int N_IN, int N_HIDDEN, int32_t *in, int N_OUT, int32_t *out); void custom_conv_ks1_in7_hidden16_out3_avx2(int KS, int32_t *kw7_40, int32_t *kb40, int32_t *kw40_3, int32_t *kb3, int h_in, int w_in, int pad_in, int plane_stride_in, int N_IN, int N_HIDDEN, int32_t *in, int N_OUT, int32_t *out); void custom_conv_ks1_in7_hidden8_out3_avx2(int KS, int32_t *kw7_40, int32_t *kb40, int32_t *kw40_3, int32_t *kb3, int h_in, int w_in, int pad_in, int plane_stride_in, int N_IN, int N_HIDDEN, int32_t *in, int N_OUT, int32_t *out); @@ -33,3 +35,6 @@ void custom_conv_ks3_in6_out6_lb_avx2(int KS, int32_t *kw, int32_t *kb, int h_in void custom_conv_ks3_in9_out6_lb_avx2(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_in, int plane_stride_in, int residue_origin_offset, int N_IN, int32_t *in, int N_OUT, int32_t *out, int32_t *line_buffer, int residue, int relu); void custom_conv_ks3_in9_out9_lb_avx2(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_in, int plane_stride_in, int residue_origin_offset, int N_IN, int32_t *in, int N_OUT, int32_t *out, int32_t *line_buffer, int residue, int relu); void custom_conv_ks3_inX_outX_lb_avx2(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_in, int plane_stride_in, int residue_origin_offset, int N_IN, int32_t *in, int N_OUT, int32_t *out, int32_t *line_buffer, int residue, int relu); + +void syn_blend1_avx2(int h_in, int w_in, int stride_in, int plane_stride_in, int N_INOUT, int32_t *in, int32_t blend_val_in, int32_t *out); +void syn_blend2_avx2(int h_in, int w_in, int stride_in, int plane_stride_in, int N_INOUT, int32_t *in, int32_t blend_val_in, int32_t *out, int32_t blend_val_out); diff --git a/coolchic/cpp/syn_avx2.hpp b/coolchic/cpp/syn_avx2.hpp index 8d8d5816..2f2c1886 100644 --- a/coolchic/cpp/syn_avx2.hpp +++ b/coolchic/cpp/syn_avx2.hpp @@ -61,21 +61,10 @@ void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_i in_layer[0] = in; for (int i = 1; i < std::max(n_in, n_out); i++) in_layer[i] = in_layer[i-1]+plane_stride_in; -//#if SYN_KS == 1 -//// possible in-place for ks==1 -- we do not check here. -//#define out_layer in_layer -// if (out != NULL && out != in) -// { -// printf("%s: ks=%d n_in=%d n_out=%d: bad call: should be in-place, but out supplied\n", xstr(SYN_NAME), KS, N_IN, N_OUT); -// exit(1); -// } -//#else - //not in-place int32_t *out_layer[N_OUT]; out_layer[0] = out; for (int i = 1; i < N_OUT; i++) out_layer[i] = out_layer[i-1]+plane_stride_in; -//#endif const __m256i z = _mm256_setzero_si256(); @@ -159,32 +148,44 @@ void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_i } } } + int n_outs = (x_blk < xlim_blk-1) ? 8 : xlast_blk_size; + int32_t store[8]; if (relu) { + // -ves to zero, >> for rest. for (int ol = 0; ol < atatime; ol++) { out_avx2[ol] = _mm256_blendv_epi8(z, out_avx2[ol], _mm256_cmpgt_epi32(out_avx2[ol], z)); - } - } - if (x_blk < xlim_blk-1) - { - for (int ol = 0; ol < atatime; ol++) - { out_avx2[ol] = _mm256_srai_epi32(out_avx2[ol], SYN_MUL_PRECISION); - _mm256_storeu_si256((__m256i_u*)&out_layer[olbase+ol][offso], out_avx2[ol]); + if (n_outs == 8) + { + _mm256_storeu_si256((__m256i_u*)&out_layer[olbase+ol][offso], out_avx2[ol]); + } + else + { + _mm256_storeu_si256((__m256i_u*)&store[0], out_avx2[ol]); + memcpy(&out_layer[olbase+ol][offso], &store[0], n_outs*sizeof(store[0])); + } } } else { - int n_outs = xlast_blk_size; - // partial last line. + // need different treatment for -ve and +ve. for (int ol = 0; ol < atatime; ol++) { - int32_t store[8]; - out_avx2[ol] = _mm256_srai_epi32(out_avx2[ol], SYN_MUL_PRECISION); - _mm256_storeu_si256((__m256i_u*)&store[0], out_avx2[ol]); - memcpy(&out_layer[olbase+ol][offso], &store[0], n_outs*sizeof(store[0])); + __m256i sr = _mm256_srai_epi32(out_avx2[ol], SYN_MUL_PRECISION); + __m256i negsrneg = _mm256_sub_epi32(z, _mm256_srai_epi32(_mm256_sub_epi32(z, out_avx2[ol]), SYN_MUL_PRECISION)); + out_avx2[ol] = _mm256_blendv_epi8(negsrneg, sr, _mm256_cmpgt_epi32(out_avx2[ol], z)); + if (n_outs == 8) + { + _mm256_storeu_si256((__m256i_u*)&out_layer[olbase+ol][offso], out_avx2[ol]); + } + else + { + _mm256_storeu_si256((__m256i_u*)&store[0], out_avx2[ol]); + memcpy(&out_layer[olbase+ol][offso], &store[0], n_outs*sizeof(store[0])); + } } } } // olbase diff --git a/coolchic/cpp/syn_cpu.cpp b/coolchic/cpp/syn_cpu.cpp index 7169ae33..d29117a9 100644 --- a/coolchic/cpp/syn_cpu.cpp +++ b/coolchic/cpp/syn_cpu.cpp @@ -10,7 +10,6 @@ #include #include -#include // !!! abs #include #include "common.h" @@ -30,3 +29,5 @@ #define SYN_NAME custom_conv_ks3_inX_outX_lb #define SYN_KS 3 #include "synlb_cpu.hpp" + +#include "synblend_cpu.hpp" diff --git a/coolchic/cpp/syn_cpu.h b/coolchic/cpp/syn_cpu.h index 6998fdd7..526c1210 100644 --- a/coolchic/cpp/syn_cpu.h +++ b/coolchic/cpp/syn_cpu.h @@ -13,3 +13,6 @@ void custom_conv_ksX_inX_outX(int KS, int32_t *kw, int32_t *kb, int h_in, int w_ void custom_conv_ks1_inX_hiddenX_outX(int KS, int32_t *kw7_40, int32_t *kb40, int32_t *kw40_3, int32_t *kb3, int h_in, int w_in, int pad_in, int plane_stride_in, int N_IN, int N_HIDDEN, int32_t *in, int N_OUT, int32_t *out); void custom_conv_ks3_inX_outX_lb(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_in, int plane_stride_in, int residue_origin_offset, int N_IN, int32_t *in, int N_OUT, int32_t *out, int32_t *line_buffer, int residue, int relu); + +void syn_blend1(int h_in, int w_in, int stride_in, int plane_stride_in, int N_INOUT, int32_t *in, int32_t blend_val_in, int32_t *out); +void syn_blend2(int h_in, int w_in, int stride_in, int plane_stride_in, int N_INOUT, int32_t *in, int32_t blend_val_in, int32_t *out, int32_t blend_val_out); diff --git a/coolchic/cpp/syn_cpu.hpp b/coolchic/cpp/syn_cpu.hpp index 56e55bb6..4ca52901 100644 --- a/coolchic/cpp/syn_cpu.hpp +++ b/coolchic/cpp/syn_cpu.hpp @@ -20,7 +20,7 @@ // stride and plane_stride are assumed the same for in and out. void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_in, int plane_stride_in, int residue_origin_offset, int N_IN, int32_t *in, int N_OUT, int32_t *out, int residue, int relu) { - printf("%s(ks=%d N_IN=%d N_OUT=%d, residue=%d relu=%d\n", xstr(SYN_NAME), KS, N_IN, N_OUT, residue, relu); + //printf("%s(ks=%d N_IN=%d N_OUT=%d, residue=%d relu=%d\n", xstr(SYN_NAME), KS, N_IN, N_OUT, residue, relu); int const kstride = 1; #ifdef SYN_KS @@ -90,9 +90,16 @@ void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_i sum += xxres; } } - sum >>= SYN_MUL_PRECISION; // take multiplied sum to output. // !!! check sign? - if (relu && sum < 0) - sum = 0; + // take multiplied sum to output after reluing. + if (sum < 0) + { + if (relu) + sum = 0; + else + sum = -(-sum >> SYN_MUL_PRECISION); + } + else + sum >>= SYN_MUL_PRECISION; out_cache[ol] = sum; } // flush. diff --git a/coolchic/cpp/synblend_avx2.hpp b/coolchic/cpp/synblend_avx2.hpp new file mode 100644 index 00000000..f428ce05 --- /dev/null +++ b/coolchic/cpp/synblend_avx2.hpp @@ -0,0 +1,12 @@ + +void syn_blend1_avx2(int h_in, int w_in, int stride_in, int plane_stride_in, int N_INOUT, int32_t *in, int32_t blend_val, int32_t *out) +{ + printf("should not be here\n"); + exit(1); +} + +void syn_blend2_avx2(int h_in, int w_in, int stride_in, int plane_stride_in, int N_INOUT, int32_t *in, int32_t blend_val_in, int32_t *out, int32_t blend_val_out) +{ + printf("should not be here\n"); + exit(1); +} diff --git a/coolchic/cpp/synblend_cpu.hpp b/coolchic/cpp/synblend_cpu.hpp new file mode 100644 index 00000000..21d4dd85 --- /dev/null +++ b/coolchic/cpp/synblend_cpu.hpp @@ -0,0 +1,50 @@ + +// out = out + in*blend_val +void syn_blend1(int h_in, int w_in, int stride_in, int plane_stride_in, int N_INOUT, int32_t *in, int32_t blend_val, int32_t *out) +{ + for (int p = 0; p < N_INOUT; p++) + { + int32_t *src = in+plane_stride_in*p; + int32_t *dst = out+plane_stride_in*p; + for (int y = 0; y < h_in; y++, src += stride_in-w_in, dst += stride_in-w_in) + { + for (int x = 0; x < w_in; x++, src++, dst++) + { + int x0 = *src; + if (x0 < 0) + x0 = 0; + else if (x0 > (1<> SYN_LAYER_PRECISION); + } + } + } +} + +// out = out*blend_val_out + in*blend_val_in +void syn_blend2(int h_in, int w_in, int stride_in, int plane_stride_in, int N_INOUT, int32_t *in, int32_t blend_val_in, int32_t *out, int32_t blend_val_out) +{ + for (int p = 0; p < N_INOUT; p++) + { + int32_t *src = in+plane_stride_in*p; + int32_t *dst = out+plane_stride_in*p; + for (int y = 0; y < h_in; y++, src += stride_in-w_in, dst += stride_in-w_in) + { + for (int x = 0; x < w_in; x++, src++, dst++) + { + int x0 = *src; + if (x0 < 0) + x0 = 0; + else if (x0 > (1< (1<> SYN_LAYER_PRECISION; + } + } + } +} diff --git a/coolchic/cpp/synfused_avx2.hpp b/coolchic/cpp/synfused_avx2.hpp index 2b602086..d22d8e6b 100644 --- a/coolchic/cpp/synfused_avx2.hpp +++ b/coolchic/cpp/synfused_avx2.hpp @@ -13,12 +13,12 @@ // N_IN (7) // N_HIDDEN (16 or 40) // N_OUT (3) -// HIDDEN always RELU, OUT always NONE +// HIDDEN always RELU void SYN_NAME(int KS, int32_t *kw7_40, int32_t *kb40, int32_t *kw40_3, int32_t *kb3, int h_in, int w_in, int stride_in, int plane_stride_in, int N_IN, int N_HIDDEN, int32_t *in, int N_OUT, int32_t *out) { - printf("%s(ks=%d N_IN=%d N_HIDDEN=%d N_OUT=%d\n", xstr(SYN_NAME), KS, N_IN, N_HIDDEN, N_OUT); + //printf("%s(ks=%d N_IN=%d N_HIDDEN=%d N_OUT=%d\n", xstr(SYN_NAME), KS, N_IN, N_HIDDEN, N_OUT); #ifdef SYN_KS int const ks = SYN_KS; @@ -46,6 +46,7 @@ void SYN_NAME(int KS, int const n_hidden8 = n_hidden/8; int32_t *src = in; + int32_t *dst = out; if (KS != ks) { @@ -71,8 +72,8 @@ void SYN_NAME(int KS, const __m256i rotate_right = _mm256_setr_epi32(1, 2, 3, 4, 5, 6, 7, 0); const __m256i z = _mm256_setzero_si256(); - for (int y = 0; y < h_in; y++, src += src_pad+src_pad) // eol of this, and bol of next. - for (int x = 0; x < w_in; x++, src++) + for (int y = 0; y < h_in; y++, src += src_pad+src_pad, dst += src_pad+src_pad) // eol of this, and bol of next. + for (int x = 0; x < w_in; x++, src++, dst++) { __m256i input_avx2_src = _mm256_i32gather_epi32(&src[0], ind_intel_0, sizeof(int32_t)); __m256i hidden_avx2[n_hidden/8]; @@ -123,8 +124,11 @@ void SYN_NAME(int KS, // horizontal sum int32_t sum = kb[ol] + hsum_8x32(out_avx2); - sum >>= SYN_MUL_PRECISION; // take multiplied sum to output. // !!! check sign? - src[ol*plane_stride_in] = sum; + if (sum < 0) + sum = -(-sum >> SYN_MUL_PRECISION); + else + sum >>= SYN_MUL_PRECISION; + dst[ol*plane_stride_in] = sum; } } // x, y } diff --git a/coolchic/cpp/synfused_cpu.hpp b/coolchic/cpp/synfused_cpu.hpp index 2c90932d..849b0a8b 100644 --- a/coolchic/cpp/synfused_cpu.hpp +++ b/coolchic/cpp/synfused_cpu.hpp @@ -13,12 +13,12 @@ // N_IN (7) // N_HIDDEN (16 or 40) // N_OUT (3) -// HIDDEN always RELU, OUT always NONE +// HIDDEN always RELU void SYN_NAME(int KS, int32_t *kw7_40, int32_t *kb40, int32_t *kw40_3, int32_t *kb3, int h_in, int w_in, int stride_in, int plane_stride_in, int N_IN, int N_HIDDEN, int32_t *in, int N_OUT, int32_t *out) { - printf("%s(ks=%d N_IN=%d N_HIDDEN=%d N_OUT=%d\n", xstr(SYN_NAME), KS, N_IN, N_HIDDEN, N_OUT); + //printf("%s(ks=%d N_IN=%d N_HIDDEN=%d N_OUT=%d\n", xstr(SYN_NAME), KS, N_IN, N_HIDDEN, N_OUT); #ifdef SYN_KS int const ks = SYN_KS; @@ -40,7 +40,6 @@ void SYN_NAME(int KS, #else int const n_out = N_OUT; #endif - int const pad_in = (stride_in-w_in)/2; if (KS != ks) { @@ -48,11 +47,12 @@ void SYN_NAME(int KS, exit(1); } int32_t *src = in; + int32_t *dst = out; int32_t elements_7[n_in]; // in int32_t elements_40[n_hidden]; // hidden - for (int y = 0; y < h_in; y++, src += pad_in+pad_in) // pads are: eol of this, and bol of next. - for (int x = 0; x < w_in; x++, src++) + for (int y = 0; y < h_in; y++, src += stride_in-w_in, dst += stride_in-w_in) + for (int x = 0; x < w_in; x++, src++, dst++) { int32_t *inputs = &elements_7[0]; int32_t *hidden = &elements_40[0]; @@ -98,8 +98,12 @@ void SYN_NAME(int KS, for (int il = 0; il < n_hidden; il++) sum += hidden[il]*kw[il]; // no relu - sum >>= SYN_MUL_PRECISION; // take multiplied sum to output. // !!! check sign? - src[ol*plane_stride_in] = sum; + // take multiplied sum to output. + if (sum < 0) + sum = -(-sum >> SYN_MUL_PRECISION); + else + sum >>= SYN_MUL_PRECISION; + dst[ol*plane_stride_in] = sum; } } // x, y } diff --git a/coolchic/cpp/synlb_avx2.hpp b/coolchic/cpp/synlb_avx2.hpp index 8423de88..ac3be67f 100644 --- a/coolchic/cpp/synlb_avx2.hpp +++ b/coolchic/cpp/synlb_avx2.hpp @@ -55,13 +55,17 @@ void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_i #endif #endif - int32_t *in_layer[std::max(n_in, n_out)]; + int32_t *in_layer[n_in]; in_layer[0] = in; - for (int i = 1; i < std::max(n_in, n_out); i++) + for (int i = 1; i < n_in; i++) in_layer[i] = in_layer[i-1]+plane_stride_in; + int32_t *out_layer[n_out]; + out_layer[0] = out; + for (int i = 1; i < n_out; i++) + out_layer[i] = out_layer[i-1]+plane_stride_in; + // in-place, must have line buffer. - int32_t **out_layer; int32_t *lb[2]; // two line buffer pointers. int h_out = h_in-ks+1; int w_out = w_in-ks+1; @@ -70,23 +74,14 @@ void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_i printf("%s: bad call: in-place lb must have ks=3\n", xstr(SYN_NAME)); exit(1); } - if (out == NULL || out == in) - { - // must have a line buffer. - if (line_buffer == NULL) - { - printf("%s: bad call, no line buffer supplied\n", xstr(SYN_NAME)); - exit(1); - } - out_layer = in_layer; - lb[0] = line_buffer; - lb[1] = line_buffer+w_out*n_out; - } - else + // must have a line buffer. + if (line_buffer == NULL) { - printf("%s: bad call should have lb and in-place\n", xstr(SYN_NAME)); + printf("%s: bad call, no line buffer supplied\n", xstr(SYN_NAME)); exit(1); } + lb[0] = line_buffer; + lb[1] = line_buffer+w_out*n_out; const __m256i z = _mm256_setzero_si256(); @@ -170,32 +165,43 @@ void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_i } } } + int n_outs = (x_blk < xlim_blk-1) ? 8 : xlast_blk_size; + int32_t store[8]; if (relu) { + // -ves to zero, >> for rest. for (int ol = 0; ol < atatime; ol++) { out_avx2[ol] = _mm256_blendv_epi8(z, out_avx2[ol], _mm256_cmpgt_epi32(out_avx2[ol], z)); - } - } - if (x_blk < xlim_blk-1) - { - for (int ol = 0; ol < atatime; ol++) - { out_avx2[ol] = _mm256_srai_epi32(out_avx2[ol], SYN_MUL_PRECISION); - //_mm256_storeu_si256((__m256i_u*)&out_layer[olbase+ol][offso], out_avx2[ol]); - _mm256_storeu_si256((__m256i_u*)&lb[y%2][(olbase+ol)*w_out+x_blk*8], out_avx2[ol]); + if (n_outs == 8) + { + _mm256_storeu_si256((__m256i_u*)&lb[y%2][(olbase+ol)*w_out+x_blk*8], out_avx2[ol]); + } + else + { + _mm256_storeu_si256((__m256i_u*)&store[0], out_avx2[ol]); + memcpy(&lb[y%2][(olbase+ol)*w_out+x_blk*8], &store[0], n_outs*sizeof(store[0])); + } } } else { - int n_outs = xlast_blk_size; - // partial last line. + // need different treatment for -ve and +ve. for (int ol = 0; ol < atatime; ol++) { - int32_t store[8]; - out_avx2[ol] = _mm256_srai_epi32(out_avx2[ol], SYN_MUL_PRECISION); - _mm256_storeu_si256((__m256i_u*)&store[0], out_avx2[ol]); - memcpy(&lb[y%2][(olbase+ol)*w_out+x_blk*8], &store[0], n_outs*sizeof(store[0])); + __m256i sr = _mm256_srai_epi32(out_avx2[ol], SYN_MUL_PRECISION); + __m256i negsrneg = _mm256_sub_epi32(z, _mm256_srai_epi32(_mm256_sub_epi32(z, out_avx2[ol]), SYN_MUL_PRECISION)); + out_avx2[ol] = _mm256_blendv_epi8(negsrneg, sr, _mm256_cmpgt_epi32(out_avx2[ol], z)); + if (n_outs == 8) + { + _mm256_storeu_si256((__m256i_u*)&lb[y%2][(olbase+ol)*w_out+x_blk*8], out_avx2[ol]); + } + else + { + _mm256_storeu_si256((__m256i_u*)&store[0], out_avx2[ol]); + memcpy(&lb[y%2][(olbase+ol)*w_out+x_blk*8], &store[0], n_outs*sizeof(store[0])); + } } } } // olbase @@ -205,12 +211,12 @@ void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_i if (y >= 1) { for (int ol = 0; ol < n_out; ol++) - memcpy(&out_layer[ol][offsi_base-stride_in+residue_origin_offset], &lb[(y-1)%2][ol*w_out], w_out*sizeof(int32_t)); + memcpy(&out_layer[ol][offsi_base-stride_in], &lb[(y-1)%2][ol*w_out], w_out*sizeof(int32_t)); } } // y // flush final line. for (int ol = 0; ol < n_out; ol++) - memcpy(&out_layer[ol][offsi_base-stride_in+residue_origin_offset], &lb[(h_out-1)%2][ol*w_out], w_out*sizeof(int32_t)); + memcpy(&out_layer[ol][offsi_base-stride_in], &lb[(h_out-1)%2][ol*w_out], w_out*sizeof(int32_t)); } #undef tostr diff --git a/coolchic/cpp/synlb_cpu.hpp b/coolchic/cpp/synlb_cpu.hpp index 7fb53794..4092808e 100644 --- a/coolchic/cpp/synlb_cpu.hpp +++ b/coolchic/cpp/synlb_cpu.hpp @@ -21,7 +21,7 @@ // dedicated to 3x3 kernels that use a line-buffer for temporary storage, allowing in-place convolution. void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_in, int plane_stride_in, int residue_origin_offset, int N_IN, int32_t *in, int N_OUT, int32_t *out, int32_t *line_buffer, int residue, int relu) { - printf("%s(ks=%d N_IN=%d N_OUT=%d, residue=%d relu=%d\n", xstr(SYN_NAME), KS, N_IN, N_OUT, residue, relu); + //printf("%s(ks=%d N_IN=%d N_OUT=%d, residue=%d relu=%d\n", xstr(SYN_NAME), KS, N_IN, N_OUT, residue, relu); int const kstride = 1; #ifdef SYN_KS @@ -40,12 +40,17 @@ void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_i int const n_out = N_OUT; #endif - int32_t *in_layer[std::max(n_in, n_out)]; + int32_t *in_layer[n_in]; in_layer[0] = in; - for (int i = 1; i < std::max(n_in, n_out); i++) + for (int i = 1; i < n_in; i++) in_layer[i] = in_layer[i-1]+plane_stride_in; + + int32_t *out_layer[n_out]; + out_layer[0] = out; + for (int i = 1; i < n_out; i++) + out_layer[i] = out_layer[i-1]+plane_stride_in; + // in-place, must have line buffer. - int32_t **out_layer; int32_t *lb[2]; // two line buffer pointers. int h_out = h_in-ks+1; int w_out = w_in-ks+1; @@ -54,23 +59,14 @@ void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_i printf("%s: bad call: in-place lb must have ks=3\n", xstr(SYN_NAME)); exit(1); } - if (out == NULL || out == in) - { - // must have a line buffer. - if (line_buffer == NULL) - { - printf("%s: bad call, no line buffer supplied\n", xstr(SYN_NAME)); - exit(1); - } - out_layer = in_layer; - lb[0] = line_buffer; - lb[1] = line_buffer+w_out*n_out; - } - else + // must have a line buffer. + if (line_buffer == NULL) { - printf("%s: bad call should have lb and in-place\n", xstr(SYN_NAME)); + printf("%s: bad call, no line buffer supplied\n", xstr(SYN_NAME)); exit(1); } + lb[0] = line_buffer; + lb[1] = line_buffer+w_out*n_out; // here we collect the output during processing, and flush later. @@ -102,9 +98,16 @@ void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_i sum += xxres; } } - sum >>= SYN_MUL_PRECISION; // take multiplied sum to output. // !!! check sign? - if (relu && sum < 0) - sum = 0; + // take multiplied sum to output after reluing. + if (sum < 0) + { + if (relu) + sum = 0; + else + sum = -(-sum >> SYN_MUL_PRECISION); + } + else + sum >>= SYN_MUL_PRECISION; lb[y%2][ol*w_out+x] = sum; } } @@ -112,12 +115,12 @@ void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_i if (y >= 1) { for (int ol = 0; ol < n_out; ol++) - memcpy(&out_layer[ol][offs0-stride_in+residue_origin_offset], &lb[(y-1)%2][ol*w_out], w_out*sizeof(int32_t)); + memcpy(&out_layer[ol][offs0-stride_in], &lb[(y-1)%2][ol*w_out], w_out*sizeof(int32_t)); } } // flush final line. for (int ol = 0; ol < n_out; ol++) - memcpy(&out_layer[ol][offs0-stride_in+residue_origin_offset], &lb[(h_out-1)%2][ol*w_out], w_out*sizeof(int32_t)); + memcpy(&out_layer[ol][offs0-stride_in], &lb[(h_out-1)%2][ol*w_out], w_out*sizeof(int32_t)); } #undef tostr diff --git a/coolchic/cpp/ups_avx2.cpp b/coolchic/cpp/ups_avx2.cpp index 9f98b2c5..5646300f 100644 --- a/coolchic/cpp/ups_avx2.cpp +++ b/coolchic/cpp/ups_avx2.cpp @@ -1,35 +1,25 @@ -/* - Software Name: Cool-Chic - SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange - SPDX-License-Identifier: BSD 3-Clause "New" - This software is distributed under the BSD-3-Clause license. - Authors: see CONTRIBUTORS.md -*/ - - -#include -#include -#include -#include #include "common.h" +#include "frame-memory.h" +#include "ups_cpu.h" +#include + +#define KS 7 +#define UPSNAME ups_refine_ks7_avx2 +#include "ups_refine_avx2.hpp" -#define SYN_NAME ups_4x4x4_fromups_avx2 -#define SYN_KS 4 -#define UPS_MUL_PRECISION UPS_PRECISION -#include "ups_avx2.hpp" +#define UPSNAME ups_refine_ksX_avx2 +#include "ups_refine_avx2.hpp" -#define SYN_NAME ups_4x2x2_fromups_avx2 -#define SYN_KS 2 -#define UPS_MUL_PRECISION UPS_PRECISION -#include "ups_avx2.hpp" +#define KS 8 +#define UPS_SRC_PRECISION ARM_PRECISION +#define UPSNAME ups_upsample_ks8_ARMPREC_avx2 +#include "ups_upsample_avx2.hpp" -#define SYN_NAME ups_4x4x4_fromarm_avx2 -#define SYN_KS 4 -#define UPS_MUL_PRECISION ARM_PRECISION -#include "ups_avx2.hpp" +#define KS 8 +#define UPS_SRC_PRECISION UPS_PRECISION +#define UPSNAME ups_upsample_ks8_UPSPREC_avx2 +#include "ups_upsample_avx2.hpp" -#define SYN_NAME ups_4x2x2_fromarm_avx2 -#define SYN_KS 2 -#define UPS_MUL_PRECISION ARM_PRECISION -#include "ups_avx2.hpp" +#define UPSNAME ups_upsample_ksX_avx2 +#include "ups_upsample_avx2.hpp" diff --git a/coolchic/cpp/ups_avx2.h b/coolchic/cpp/ups_avx2.h index 87d29593..6d9a656e 100644 --- a/coolchic/cpp/ups_avx2.h +++ b/coolchic/cpp/ups_avx2.h @@ -8,7 +8,8 @@ */ -void ups_4x4x4_fromarm_avx2(int KS, int32_t *kw, int h_in, int w_in, int stride_in, int32_t *in, int stride_out, int32_t *out); -void ups_4x4x4_fromups_avx2(int KS, int32_t *kw, int h_in, int w_in, int stride_in, int32_t *in, int stride_out, int32_t *out); -void ups_4x2x2_fromarm_avx2(int KS, int32_t *kw, int h_in, int w_in, int stride_int, int32_t *in, int stride_out, int32_t *out); -void ups_4x2x2_fromups_avx2(int KS, int32_t *kw, int h_in, int w_in, int stride_int, int32_t *in, int stride_out, int32_t *out); +void ups_refine_ks7_avx2(int ks, int32_t *kw, frame_memory &in, frame_memory &out, int ups_src_precision, frame_memory &tmp); +void ups_refine_ksX_avx2(int ks, int32_t *kw, frame_memory &in, frame_memory &out, int ups_src_precision, frame_memory &tmp); +void ups_upsample_ks8_UPSPREC_avx2(int ksx2, int32_t *kw, frame_memory &in, frame_memory &out, int out_plane, int ups_src_precision, frame_memory &tmp); +void ups_upsample_ks8_ARMPREC_avx2(int ksx2, int32_t *kw, frame_memory &in, frame_memory &out, int out_plane, int ups_src_precision, frame_memory &tmp); +void ups_upsample_ksX_avx2(int ksx2, int32_t *kw, frame_memory &in, frame_memory &out, int out_plane, int ups_src_precision, frame_memory &tmp); diff --git a/coolchic/cpp/ups_avx2.hpp b/coolchic/cpp/ups_avx2.hpp deleted file mode 100644 index 057f3404..00000000 --- a/coolchic/cpp/ups_avx2.hpp +++ /dev/null @@ -1,142 +0,0 @@ -/* - Software Name: Cool-Chic - SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange - SPDX-License-Identifier: BSD 3-Clause "New" - - This software is distributed under the BSD-3-Clause license. - Authors: see CONTRIBUTORS.md -*/ - - -#define tostr(x) #x -#define xstr(x) tostr(x) - -// kw: weights for 4 4x4 kernels. -void SYN_NAME(int KS, int32_t *kw, int h_in, int w_in, int stride_in, int32_t *in, int stride_out, int32_t *out) -{ - //printf("%s(ks=%d N_IN=%d N_OUT=%d, residue=%d, relu=%d\n", xstr(SYN_NAME), KS, N_IN, N_OUT, residue, relu); - int const kstride = 1; -#ifdef SYN_KS - int const ks = SYN_KS; -#else - int const ks = KS; -#endif - int const n_out = 4; // number of kernels, actually, always 4. - - if (ks != KS) - { - printf("%s: ks=%d: bad call\n", xstr(SYN_NAME), KS); - exit(1); - } - - int const ks2 = ks*ks; - - // we generate 8 outputs at a time, per kernel. - // these outputs are emitted onto two output lines. - // we advance 8 pixels in the input after that. - int offsi_base = 0; - int offso_base = 0; - // the last block in x might not be full. - - int xlim_blk = ((w_in-ks+1)+7)/8; // pixels in number of horizontal blocks - int xlast_blk_size = 8-(xlim_blk*8 - (w_in-ks+1)); // number of pixels to emit per filter in last block. - for (int y = 0; y < h_in-ks+1; y += kstride, offsi_base += stride_in, offso_base += 2*stride_out) - { - int offsi_line = offsi_base; - int offso = offso_base; - for (int x_blk = 0; x_blk < xlim_blk; x_blk += kstride, offsi_line += kstride*8) - { - __m256i out_avx2[n_out]; - for (int ol = 0; ol < n_out; ol++) - { - out_avx2[ol] = _mm256_setzero_si256(); - } - int offsi_block = offsi_line; - int koffs = 0; - for (int yy = 0; yy < ks; yy++, offsi_block += stride_in-ks) - { - for (int xx = 0; xx < ks; xx++, offsi_block++, koffs++) - { - __m256i input = _mm256_loadu_si256((__m256i_u*)&in[offsi_block]); - for (int ol = 0; ol < n_out; ol++) - { - __m256i kk = _mm256_set1_epi32(kw[ol*ks2+koffs]); - __m256i mul = _mm256_mullo_epi32(input, kk); - out_avx2[ol] = _mm256_add_epi32(out_avx2[ol], mul); - } - } - } - // merge the outputs into the destination! - // We have: - // ol0 A0 A1 A2 A3.. A7 - // ol1 B0 B1 B2 B3.. B7 - // ol2 C0 C1 C2 C3.. C7 - // ol3 D0 D1 D2 D3.. D7 - // - // to go to 2 lines as: - // A0 B0 A1 B1 A2 B2.. A7 B7 - // C0 D0 C1 D1 C2 D2.. C7 D7 - - out_avx2[0] = _mm256_srai_epi32(out_avx2[0], UPS_MUL_PRECISION); - out_avx2[1] = _mm256_srai_epi32(out_avx2[1], UPS_MUL_PRECISION); - out_avx2[2] = _mm256_srai_epi32(out_avx2[2], UPS_MUL_PRECISION); - out_avx2[3] = _mm256_srai_epi32(out_avx2[3], UPS_MUL_PRECISION); - - // spread horizontally and over two lines. - //float outputs[4][8]; - __m256i outA0 = _mm256_unpacklo_epi32(out_avx2[0], out_avx2[1]); - __m256i outA1 = _mm256_unpackhi_epi32(out_avx2[0], out_avx2[1]); - __m256i outC0 = _mm256_unpacklo_epi32(out_avx2[2], out_avx2[3]); - __m256i outC1 = _mm256_unpackhi_epi32(out_avx2[2], out_avx2[3]); - - // We need to remangle A0,A1 and C0,C1, switching their high and low 128-bit halves. - __m256i outA00 = _mm256_permute2f128_si256(outA0, outA1, 0x20); - __m256i outA01 = _mm256_permute2f128_si256(outA0, outA1, 0x31); - __m256i outC00 = _mm256_permute2f128_si256(outC0, outC1, 0x20); - __m256i outC01 = _mm256_permute2f128_si256(outC0, outC1, 0x31); - if (x_blk < xlim_blk-1) - { - _mm256_storeu_si256((__m256i_u*)&out[offso+0], outA00); - _mm256_storeu_si256((__m256i_u*)&out[offso+8], outA01); - _mm256_storeu_si256((__m256i_u*)&out[offso+stride_out+0], outC00); - _mm256_storeu_si256((__m256i_u*)&out[offso+stride_out+8], outC01); - offso += 2*8; - } - else - { - int n_outs = 2*xlast_blk_size; // note -- doubled. - int32_t tmp[8]; - for (int line = 0; line < 2; line++) - { - if (n_outs >= 8) - { - _mm256_storeu_si256((__m256i_u*)&out[offso+line*stride_out+0], outA00); - if (n_outs >= 16) - _mm256_storeu_si256((__m256i_u*)&out[offso+line*stride_out+8], outA01); - else if (n_outs > 8) - { - // remainder over 8. - _mm256_storeu_si256((__m256i_u*)&tmp[0], outA01); - memcpy(&out[offso+line*stride_out+8], &tmp[0], (n_outs-8)*sizeof(tmp[0])); - } - } - else - { - // remainder over 0. - _mm256_storeu_si256((__m256i_u*)&tmp[0], outA00); - memcpy(&out[offso+line*stride_out+0], &tmp[0], n_outs*sizeof(tmp[0])); - } - outA00 = outC00; - outA01 = outC01; - } - offso += n_outs; - } - } - } -} - -#undef tostr -#undef xstr -#undef SYN_NAME -#undef SYN_KS -#undef UPS_MUL_PRECISION diff --git a/coolchic/cpp/ups_cpu.cpp b/coolchic/cpp/ups_cpu.cpp new file mode 100644 index 00000000..df012827 --- /dev/null +++ b/coolchic/cpp/ups_cpu.cpp @@ -0,0 +1,18 @@ + +#include "common.h" +#include "frame-memory.h" +#include "ups_cpu.h" + +#define KS 7 +#define UPSNAME ups_refine_ks7_cpu +#include "ups_refine_cpu.hpp" + +#define UPSNAME ups_refine_ksX_cpu +#include "ups_refine_cpu.hpp" + +#define KS 8 +#define UPSNAME ups_upsample_ks8_cpu +#include "ups_upsample_cpu.hpp" + +#define UPSNAME ups_upsample_ksX_cpu +#include "ups_upsample_cpu.hpp" diff --git a/coolchic/cpp/ups_cpu.h b/coolchic/cpp/ups_cpu.h new file mode 100644 index 00000000..dcbc6dc0 --- /dev/null +++ b/coolchic/cpp/ups_cpu.h @@ -0,0 +1,7 @@ + + +void ups_refine_ks7_cpu(int ks, int32_t *kw, frame_memory &in, frame_memory &out, int ups_src_precision, frame_memory &tmp); +void ups_refine_ksX_cpu(int ks, int32_t *kw, frame_memory &in, frame_memory &out, int ups_src_precision, frame_memory &tmp); + +void ups_upsample_ks8_cpu(int ksx2, int32_t *kw, frame_memory &in, frame_memory &out, int out_plane, int ups_src_precision, frame_memory &tmp); +void ups_upsample_ksX_cpu(int ksx2, int32_t *kw, frame_memory &in, frame_memory &out, int out_plane, int ups_src_precision, frame_memory &tmp); diff --git a/coolchic/cpp/ups_refine_avx2.hpp b/coolchic/cpp/ups_refine_avx2.hpp new file mode 100644 index 00000000..ed75da04 --- /dev/null +++ b/coolchic/cpp/ups_refine_avx2.hpp @@ -0,0 +1,135 @@ + +// defines expected: +// KS +// UPSNAME + +#define tostr(x) #x +#define xstr(x) tostr(x) + +// ups_src_precision is always ARM_PRECISION (from arm src) +void UPSNAME(int ks_param, int32_t *kw, frame_memory &in, frame_memory &out, int ups_src_precision, frame_memory &tmp) +{ +#ifdef KS + int const ks = KS; + if (ks != ks_param) + { + printf("%s: bad call, ks_param %d\n", xstr(UPSNAME), ks_param); + exit(1); + } + + // kernel in registers + __m256i KW[KS]; + for (int i = 0; i < KS; i++) + KW[i] = _mm256_set1_epi32(kw[i]); +#else + int const ks = ks_param; +#endif + // temporary check. + if (ups_src_precision != ARM_PRECISION) + { + printf("%s: bad call: input precision is %d, not %d\n", xstr(UPSNAME), ups_src_precision, ARM_PRECISION); + exit(1); + } + int const UPS_SRC_PRECISION = ARM_PRECISION; + int kshalf = ks/2; + int pad = kshalf; + // prepare output. + out.update_to(in.h, in.w, out.pad, 1); + // pad input. + in.custom_pad_zero_plane_in_place_i(0, pad, 0); // LR pad. + + // prepare temporary to hold h x w + tmp.update_to(in.h, in.w, tmp.pad, tmp.planes); + int32_t *src = in.pad_origin(0, pad, 0); + int32_t *dst = tmp.origin(); + const __m256i z = _mm256_setzero_si256(); + + // h (horizontal treatment) to temporary. + int xlim_blk = (in.w+7)/8; // pixels in number of horizontal blocks + int xlast_blk_size = 8-(xlim_blk*8 - in.w); // number of pixels to emit per filter in last block. + int32_t store[8]; + + for (int y = 0; y < in.h; y++, src += in.stride-xlim_blk*8, dst += tmp.stride-xlim_blk*8) + { + for (int x_blk = 0; x_blk < xlim_blk; x_blk++, src += 8, dst += 8) + { + int n_outs = (x_blk < xlim_blk-1) ? 8 : xlast_blk_size; + __m256i out_avx2 = _mm256_setzero_si256(); + + for (int xx = 0; xx < ks; xx++) + { + __m256i input = _mm256_loadu_si256((__m256i_u*)&src[xx]); +#ifdef KS + __m256i mul = _mm256_mullo_epi32(input, KW[xx]); +#else + __m256i kk = _mm256_set1_epi32(kw[xx]); + __m256i mul = _mm256_mullo_epi32(input, kk); +#endif + out_avx2 = _mm256_add_epi32(out_avx2, mul); + } + // need different treatment for -ve -(-(>>)) and +ve (>>). + __m256i sr = _mm256_srai_epi32(out_avx2, ARM_PRECISION); + __m256i negsrneg = _mm256_sub_epi32(z, _mm256_srai_epi32(_mm256_sub_epi32(z, out_avx2), ARM_PRECISION)); + out_avx2 = _mm256_blendv_epi8(negsrneg, sr, _mm256_cmpgt_epi32(out_avx2, z)); + + if (n_outs == 8) + { + _mm256_storeu_si256((__m256i_u*)dst, out_avx2); + } + else + { + _mm256_storeu_si256((__m256i_u*)&store[0], out_avx2); + memcpy(dst, &store[0], n_outs*sizeof(store[0])); + } + } + } + + // v (vertical treatment) to output. + tmp.custom_pad_zero_plane_in_place_i(0, 0, pad); // TB pad. + int32_t *hsrc = tmp.pad_origin(0, 0, pad); + src = in.origin(); + dst = out.origin(); + int residue_shift = UPS_PRECISION-UPS_SRC_PRECISION; // src is arm output at lower precision. + + for (int y = 0; y < in.h; y++, hsrc += tmp.stride-xlim_blk*8, src += in.stride-xlim_blk*8, dst += out.stride-xlim_blk*8) + { + for (int x_blk = 0; x_blk < xlim_blk; x_blk++, hsrc += 8, src += 8, dst += 8) + { + int n_outs = (x_blk < xlim_blk-1) ? 8 : xlast_blk_size; + __m256i out_avx2 = _mm256_loadu_si256((__m256i_u*)&src[0]); + out_avx2 = _mm256_slli_epi32(out_avx2, residue_shift+UPS_PRECISION); + + for (int yy = 0; yy < ks; yy++) + { + __m256i input = _mm256_loadu_si256((__m256i_u*)&hsrc[yy*tmp.stride]); +#ifdef KS + __m256i mul = _mm256_mullo_epi32(input, KW[yy]); +#else + __m256i kk = _mm256_set1_epi32(kw[yy]); + __m256i mul = _mm256_mullo_epi32(input, kk); +#endif + out_avx2 = _mm256_add_epi32(out_avx2, mul); + } + + // need different treatment for -ve -(-(>>)) and +ve (>>). + __m256i sr = _mm256_srai_epi32(out_avx2, UPS_PRECISION); + __m256i negsrneg = _mm256_sub_epi32(z, _mm256_srai_epi32(_mm256_sub_epi32(z, out_avx2), UPS_PRECISION)); + out_avx2 = _mm256_blendv_epi8(negsrneg, sr, _mm256_cmpgt_epi32(out_avx2, z)); + + if (n_outs == 8) + { + _mm256_storeu_si256((__m256i_u*)dst, out_avx2); + } + else + { + _mm256_storeu_si256((__m256i_u*)&store[0], out_avx2); + memcpy(dst, &store[0], n_outs*sizeof(store[0])); + } + } + } +} + +#undef KS +#undef UPSNAME +#undef tostr +#undef xstr diff --git a/coolchic/cpp/ups_refine_cpu.hpp b/coolchic/cpp/ups_refine_cpu.hpp new file mode 100644 index 00000000..47c0d4b8 --- /dev/null +++ b/coolchic/cpp/ups_refine_cpu.hpp @@ -0,0 +1,84 @@ + +// defines expected: +// KS +// UPSNAME + +#define tostr(x) #x +#define xstr(x) tostr(x) + +// ups_src_precision is always ARM_PRECISION (from arm src) +// tmp is guaranteed to hold a horizontally upsampled in. +void UPSNAME(int ks_param, int32_t *kw, frame_memory &in, frame_memory &out, int ups_src_precision, frame_memory &tmp) +{ +#ifdef KS + int const ks = KS; + if (ks != ks_param) + { + printf("%s: bad call, ks_param %d\n", xstr(UPSNAME), ks_param); + exit(1); + } +#else + int const ks = ks_param; +#endif + int kshalf = ks/2; + int pad = kshalf; + // prepare output. + out.update_to(in.h, in.w, out.pad, 1); + // pad input. + in.custom_pad_zero_plane_in_place_i(0, pad, 0); // LR pad. + + // prepare temporary to hold h x w + tmp.update_to(in.h, in.w, tmp.pad, tmp.planes); + int32_t *src = in.pad_origin(0, pad, 0); + int32_t *dst = tmp.origin(); + + // h (horizontal treatment) to temporary. + for (int y = 0; y < in.h; y++, src += in.stride-in.w, dst += tmp.stride-in.w) + { + for (int x = 0; x < in.w; x++, src++, dst++) + { +#if 1 + // ignore symmetric. + int sum = 0; + for (int xx = 0; xx < ks; xx++) + sum += src[xx]*kw[xx]; +#endif +#if 0 + // use symmetric. + int sum = src[kshalf]*kw[kshalf]; + for (int xx = 0; xx < kshalf; xx++) + sum += (src[xx]+src[ks-xx-1])*kw[xx]; +#endif + if (sum < 0) + *dst = -(-sum >> ups_src_precision); + else + *dst = sum >> ups_src_precision; + } + } + + // v (vertical treatment) to output. + tmp.custom_pad_zero_plane_in_place_i(0, 0, pad); // TB pad. + int32_t *hsrc = tmp.pad_origin(0, 0, pad); + src = in.origin(); + dst = out.origin(); + int residue_shift = UPS_PRECISION-ups_src_precision; // src is arm output at lower precision. + for (int y = 0; y < in.h; y++, hsrc += tmp.stride-in.w, dst += out.stride-in.w, src += in.stride-in.w) + { + for (int x = 0; x < in.w; x++, hsrc++, dst++, src++) + { + int sum = 0; + for (int yy = 0; yy < ks; yy++) + sum += hsrc[yy*tmp.stride]*kw[yy]; + sum += (*src<> UPS_PRECISION); + else + *dst = sum >> UPS_PRECISION; + } + } +} + +#undef KS +#undef UPSNAME +#undef tostr +#undef xstr diff --git a/coolchic/cpp/ups_upsample_avx2.hpp b/coolchic/cpp/ups_upsample_avx2.hpp new file mode 100644 index 00000000..3171b8df --- /dev/null +++ b/coolchic/cpp/ups_upsample_avx2.hpp @@ -0,0 +1,211 @@ + +// defines expected: +// KS +// optional UPS_SRC_PRECISION +// UPSNAME + + +#define tostr(x) #x +#define xstr(x) tostr(x) + +// ups_src_precision is either ARM_PRECISION (from arm src, unrefined lowest layer) or UPS_PRECISION (from ups refinement src) +// incoming kernel is actually two filters, interleavead, to produce two outputs. +// tmp is guaranteed to hold a horizontally upsampled in. +void UPSNAME(int ksx2_param, int32_t *kw, frame_memory &in, frame_memory &out, int out_plane, int ups_src_precision, frame_memory &tmp) +{ +#ifdef KS + int const ksx2 = KS; + if (ksx2 != ksx2_param) + { + printf("%s: bad call, ksx2_param %d\n", xstr(UPSNAME), ksx2_param); + exit(1); + } +#else + int const ksx2 = ksx2_param; +#endif + int const ks = ksx2/2; // 2 kernels of size ks. + int pad = ks/2; // incoming ks is two even-sized filters. + in.custom_pad_replicate_plane_in_place_i(0, pad, 0); // LR pad. + +#ifdef UPS_SRC_PRECISION + if (ups_src_precision != UPS_SRC_PRECISION) + { + printf("%s: bad call: input precision is %d, not %d\n", xstr(UPSNAME), ups_src_precision, UPS_SRC_PRECISION); + exit(1); + } +#endif + +#ifdef KS + // kernel in registers + __m256i KW_EVEN[KS]; + __m256i KW_ODD[KS]; + for (int i = 0; i < KS; i++) + { + KW_EVEN[i] = _mm256_set1_epi32(kw[i*2]); + KW_ODD[i] = _mm256_set1_epi32(kw[i*2+1]); + } +#else + int32_t kw_even[ks]; + int32_t kw_odd[ks]; + for (int i = 0; i < ks; i++) + { + kw_even[i] = kw[i*2]; + kw_odd[i] = kw[i*2+1]; + } +#endif + + // prepare temporary to hold h x 2*w + tmp.update_to(in.h, 2*in.w, tmp.pad, tmp.planes); + int32_t *src = in.pad_origin(0, pad, 0); + int32_t *dst = tmp.origin(); + const __m256i z = _mm256_setzero_si256(); + + // h (horizontal scale) to temporary. + int xlim_blk = (in.w+7)/8; // pixels in number of horizontal blocks + int xlast_blk_size = 8-(xlim_blk*8 - in.w); // number of pixels to emit per filter in last block. + int32_t store[8]; + + for (int y = 0; y < in.h; y++, src += in.stride-xlim_blk*8, dst += tmp.stride-xlim_blk*8*2) + { + for (int x_blk = 0; x_blk < xlim_blk; x_blk++, src += 8, dst += 8*2) + { + int n_outs = (x_blk < xlim_blk-1) ? 8 : xlast_blk_size; + __m256i out_even_avx2 = _mm256_setzero_si256(); + __m256i out_odd_avx2 = _mm256_setzero_si256(); + for (int xx = 0; xx < ks; xx++) + { + __m256i input_even = _mm256_loadu_si256((__m256i_u*)&src[xx]); + __m256i input_odd = _mm256_loadu_si256((__m256i_u*)&src[xx+1]); // !!! strange we need a +1 for odd. +#ifdef KS + __m256i mul = _mm256_mullo_epi32(input_even, KW_EVEN[xx]); + out_even_avx2 = _mm256_add_epi32(out_even_avx2, mul); + mul = _mm256_mullo_epi32(input_odd, KW_ODD[xx]); + out_odd_avx2 = _mm256_add_epi32(out_odd_avx2, mul); +#else + __m256i kk = _mm256_set1_epi32(kw_even[xx]); + __m256i mul = _mm256_mullo_epi32(input_even, kk); + out_even_avx2 = _mm256_add_epi32(out_even_avx2, mul); + kk = _mm256_set1_epi32(kw_odd[xx]); + mul = _mm256_mullo_epi32(input_odd, kk); + out_odd_avx2 = _mm256_add_epi32(out_odd_avx2, mul); +#endif + } + // need different treatment for -ve -(-(>>)) and +ve (>>). +#ifdef UPS_SRC_PRECISION + __m256i sr = _mm256_srai_epi32(out_even_avx2, UPS_SRC_PRECISION); + __m256i negsrneg = _mm256_sub_epi32(z, _mm256_srai_epi32(_mm256_sub_epi32(z, out_even_avx2), UPS_SRC_PRECISION)); + out_even_avx2 = _mm256_blendv_epi8(negsrneg, sr, _mm256_cmpgt_epi32(out_even_avx2, z)); + + sr = _mm256_srai_epi32(out_odd_avx2, UPS_SRC_PRECISION); + negsrneg = _mm256_sub_epi32(z, _mm256_srai_epi32(_mm256_sub_epi32(z, out_odd_avx2), UPS_SRC_PRECISION)); + out_odd_avx2 = _mm256_blendv_epi8(negsrneg, sr, _mm256_cmpgt_epi32(out_odd_avx2, z)); +#else + __m256i sr = _mm256_srai_epi32(out_even_avx2, ups_src_precision); + __m256i negsrneg = _mm256_sub_epi32(z, _mm256_srai_epi32(_mm256_sub_epi32(z, out_even_avx2), ups_src_precision)); + out_even_avx2 = _mm256_blendv_epi8(negsrneg, sr, _mm256_cmpgt_epi32(out_even_avx2, z)); + + sr = _mm256_srai_epi32(out_odd_avx2, ups_src_precision); + negsrneg = _mm256_sub_epi32(z, _mm256_srai_epi32(_mm256_sub_epi32(z, out_odd_avx2), ups_src_precision)); + out_odd_avx2 = _mm256_blendv_epi8(negsrneg, sr, _mm256_cmpgt_epi32(out_odd_avx2, z)); +#endif + + // spread horizontally + __m256i outA0 = _mm256_unpacklo_epi32(out_even_avx2, out_odd_avx2); + __m256i outA1 = _mm256_unpackhi_epi32(out_even_avx2, out_odd_avx2); + // We need to remangle A0,A1 switching their high and low 128-bit halves. + out_even_avx2 = _mm256_permute2f128_si256(outA0, outA1, 0x20); + out_odd_avx2 = _mm256_permute2f128_si256(outA0, outA1, 0x31); + if (n_outs == 8) + { + _mm256_storeu_si256((__m256i_u*)(dst+0), out_even_avx2); + _mm256_storeu_si256((__m256i_u*)(dst+8), out_odd_avx2); + } + else + { + n_outs *= 2; + if (n_outs >= 8) + { + _mm256_storeu_si256((__m256i_u*)(dst+0), out_even_avx2); + if (n_outs >= 16) + _mm256_storeu_si256((__m256i_u*)(dst+8), out_odd_avx2); + else if (n_outs > 8) + { + // remainder over 8. + _mm256_storeu_si256((__m256i_u*)&store[0], out_odd_avx2); + memcpy(&dst[8], &store[0], (n_outs-8)*sizeof(store[0])); + } + } + else + { + // remainder over 0. + _mm256_storeu_si256((__m256i_u*)&store[0], out_even_avx2); + memcpy(&dst[0], &store[0], n_outs*sizeof(store[0])); + } + } + } + } + + // v (vertical) to output. + tmp.custom_pad_replicate_plane_in_place_i(0, 0, pad); // TB pad. + int32_t *hsrc = tmp.pad_origin(0, 0, pad); + dst = out.plane_origin(out_plane); + + xlim_blk = (tmp.w+7)/8; // pixels in number of horizontal blocks + xlast_blk_size = 8-(xlim_blk*8 - tmp.w); // number of pixels to emit per filter in last block. + + for (int y = 0; y < out.h; y +=2, hsrc += tmp.stride-xlim_blk*8, dst += out.stride-xlim_blk*8+out.stride) + { + for (int x_blk = 0; x_blk < xlim_blk; x_blk++, hsrc += 8, dst += 8) + { + int n_outs = (x_blk < xlim_blk-1) ? 8 : xlast_blk_size; + __m256i out_even_avx2 = _mm256_setzero_si256(); + __m256i out_odd_avx2 = _mm256_setzero_si256(); + for (int yy = 0; yy < ks; yy++) + { + __m256i input_even = _mm256_loadu_si256((__m256i_u*)&hsrc[yy*tmp.stride]); + __m256i input_odd = _mm256_loadu_si256((__m256i_u*)&hsrc[(yy+1)*tmp.stride]); // !!! strange we need a +1 for odd. +#ifdef KS + __m256i mul = _mm256_mullo_epi32(input_even, KW_EVEN[yy]); + out_even_avx2 = _mm256_add_epi32(out_even_avx2, mul); + mul = _mm256_mullo_epi32(input_odd, KW_ODD[yy]); + out_odd_avx2 = _mm256_add_epi32(out_odd_avx2, mul); +#else + __m256i kk = _mm256_set1_epi32(kw_even[yy]); + __m256i mul = _mm256_mullo_epi32(input_even, kk); + out_even_avx2 = _mm256_add_epi32(out_even_avx2, mul); + kk = _mm256_set1_epi32(kw_odd[yy]); + mul = _mm256_mullo_epi32(input_odd, kk); + out_odd_avx2 = _mm256_add_epi32(out_odd_avx2, mul); +#endif + } + // need different treatment for -ve -(-(>>)) and +ve (>>). + __m256i sr = _mm256_srai_epi32(out_even_avx2, UPS_PRECISION); + __m256i negsrneg = _mm256_sub_epi32(z, _mm256_srai_epi32(_mm256_sub_epi32(z, out_even_avx2), UPS_PRECISION)); + out_even_avx2 = _mm256_blendv_epi8(negsrneg, sr, _mm256_cmpgt_epi32(out_even_avx2, z)); + + sr = _mm256_srai_epi32(out_odd_avx2, UPS_PRECISION); + negsrneg = _mm256_sub_epi32(z, _mm256_srai_epi32(_mm256_sub_epi32(z, out_odd_avx2), UPS_PRECISION)); + out_odd_avx2 = _mm256_blendv_epi8(negsrneg, sr, _mm256_cmpgt_epi32(out_odd_avx2, z)); + + if (n_outs == 8) + { + _mm256_storeu_si256((__m256i_u*)(dst+0*out.stride), out_even_avx2); + _mm256_storeu_si256((__m256i_u*)(dst+1*out.stride), out_odd_avx2); + } + else + { + _mm256_storeu_si256((__m256i_u*)&store[0], out_even_avx2); + memcpy(&dst[0], &store[0], n_outs*sizeof(store[0])); + _mm256_storeu_si256((__m256i_u*)&store[0], out_odd_avx2); + memcpy(&dst[out.stride], &store[0], n_outs*sizeof(store[0])); + } + } + } +} + + +#undef KS +#undef UPS_SRC_PRECISION +#undef UPSNAME +#undef tostr +#undef xstr diff --git a/coolchic/cpp/ups_upsample_cpu.hpp b/coolchic/cpp/ups_upsample_cpu.hpp new file mode 100644 index 00000000..bb404604 --- /dev/null +++ b/coolchic/cpp/ups_upsample_cpu.hpp @@ -0,0 +1,97 @@ + +// defines expected: +// KS +// UPSNAME + +#define tostr(x) #x +#define xstr(x) tostr(x) + +// ups_src_precision is either ARM_PRECISION (from arm src, unrefined lowest layer) or UPS_PRECISION (from ups refinement src) +// incoming kernel is actually two filters, interleavead, to produce two outputs. +// tmp is guaranteed to hold a horizontally upsampled in. +void UPSNAME(int ksx2_param, int32_t *kw, frame_memory &in, frame_memory &out, int out_plane, int ups_src_precision, frame_memory &tmp) +{ +#ifdef KS + int const ksx2 = KS; + if (ksx2 != ksx2_param) + { + printf("%s: bad call, ksx2_param %d\n", xstr(UPSNAME), ksx2_param); + exit(1); + } +#else + int const ksx2 = ksx2_param; +#endif + int const ks = ksx2/2; // 2 kernels of size ks. + int pad = ks/2; // incoming ks is two even-sized filters. + in.custom_pad_replicate_plane_in_place_i(0, pad, 0); // LR pad. + + int32_t kw_even[ks]; + int32_t kw_odd[ks]; + for (int i = 0; i < ks; i++) + { + kw_even[i] = kw[i*2]; + kw_odd[i] = kw[i*2+1]; + } + + // prepare temporary to hold h x 2*w + tmp.update_to(in.h, 2*in.w, tmp.pad, tmp.planes); + int32_t *src = in.pad_origin(0, pad, 0); + int32_t *dst = tmp.origin(); + + // h (horizontal scale) to temporary. + for (int y = 0; y < in.h; y++, src += in.stride-in.w, dst += tmp.stride-2*in.w) + { + for (int x = 0; x < in.w; x++) + { + // even + int sum_even = 0; + int sum_odd = 0; + for (int xx = 0; xx < ks; xx++) + { + sum_even += src[xx]*kw_even[xx]; + sum_odd += src[xx+1]*kw_odd[xx]; // !!! strange we need a +1 for odd. + } + if (sum_even < 0) + *dst++ = -(-sum_even >> ups_src_precision); + else + *dst++ = sum_even >> ups_src_precision; + if (sum_odd < 0) + *dst++ = -(-sum_odd >> ups_src_precision); + else + *dst++ = sum_odd >> ups_src_precision; + src++; + } + } + + // v (vertical) to output. + tmp.custom_pad_replicate_plane_in_place_i(0, 0, pad); // TB pad. + src = tmp.pad_origin(0, 0, pad); + dst = out.plane_origin(out_plane); + for (int y = 0; y < out.h; y += 2, src += tmp.stride-out.w, dst += out.stride-out.w+out.stride) + { + for (int x = 0; x < out.w; x++, src++, dst++) + { + int sum_even = 0; + int sum_odd = 0; + for (int yy = 0; yy < ks; yy++) + { + sum_even += src[yy*tmp.stride]*kw_even[yy]; + sum_odd += src[(yy+1)*tmp.stride]*kw_odd[yy]; // !!! strange we need a +1 for odd. + } + if (sum_even < 0) + dst[0] = -(-sum_even >> UPS_PRECISION); + else + dst[0] = sum_even >> UPS_PRECISION; + if (sum_odd < 0) + dst[0+out.stride] = -(-sum_odd >> UPS_PRECISION); + else + dst[0+out.stride] = sum_odd >> UPS_PRECISION; + } + } +} + + +#undef KS +#undef UPSNAME +#undef tostr +#undef xstr diff --git a/coolchic/dec/nn.py b/coolchic/dec/nn.py index 1f38f370..2930eee4 100644 --- a/coolchic/dec/nn.py +++ b/coolchic/dec/nn.py @@ -7,11 +7,10 @@ # Authors: see CONTRIBUTORS.md -import numpy as np import torch import torch.nn as nn -from enc.utils.misc import FIXED_POINT_FRACTIONAL_MULT, DescriptorNN +from enc.utils.misc import DescriptorNN from CCLIB.ccencapi import cc_decode_wb @@ -40,7 +39,7 @@ def decode_network( Returns: nn.Module: The decoded module """ - have_bias = q_step_nn.bias > 0 + have_bias = bitstream_path.bias != "" # Instantiate two range coder objects to decode simultaneously weight and bias bac_ctx_weight = cc_decode_wb(bitstream_path.weight) @@ -49,12 +48,10 @@ def decode_network( loaded_param = {} for k, v in empty_module.named_parameters(): - if k.endswith('.w') or k.endswith('.weight'): - cur_scale = scale_nn.weight + if "weight" in k: cur_q_step = q_step_nn.weight cur_param = bac_ctx_weight.decode_wb_continue(len(v.flatten()), scale_nn.weight) - elif k.endswith('.b') or k.endswith('.bias'): - cur_scale = scale_nn.bias + elif "bias" in k and have_bias: cur_q_step = q_step_nn.bias cur_param = bac_ctx_bias.decode_wb_continue(len(v.flatten()), scale_nn.bias) else: @@ -68,5 +65,5 @@ def decode_network( if "arm" in bitstream_path.weight: empty_module.set_param_from_float(loaded_param) else: - empty_module.load_state_dict(loaded_param) + empty_module.load_state_dict(loaded_param, strict = have_bias) return empty_module diff --git a/coolchic/decode.py b/coolchic/decode.py index 1e489504..c5fcbd2f 100644 --- a/coolchic/decode.py +++ b/coolchic/decode.py @@ -18,9 +18,39 @@ if __name__ == "__main__": # =========================== Parse arguments =========================== # parser = argparse.ArgumentParser() - parser.add_argument( "--input", "-i", type=str, default="./bitstream.cool", help="Bitstream path.") - parser.add_argument( "--output", "-o", default="", help="output ppm (rgb) or yuv") - parser.add_argument( "--no_avx2", action='store_true', help="Disable AVX2 support") + parser.add_argument( + "--input", "-i", type=str, default="./bitstream.cool", help="Bitstream path." + ) + parser.add_argument("--output", "-o", default="", help="output ppm (rgb) or yuv") + parser.add_argument("--no_avx2", action="store_true", help="Disable AVX2 usage") + parser.add_argument( + "--verbosity", type=int, default=0, + help="" + "0 does not output anything ; " + "1 prints the runtime of each step ;" + "2 is for debug." + ) + parser.add_argument( + "--output_chroma_format", + type=int, + default=0, + help= + "Use 0 to infer this from the bitstream header. " + " " + "Otherwise, specify '420' or '444' to change the chroma sampling for the " + "YUV output. " + " " + " Useless for RGB." + ) + parser.add_argument( + "--output_bitdepth", + type=int, + default=0, + help= + "Use 0 to infer this from the bitstream header. " + " " + "Otherwise, specify an integer in [8, 16] to set the output bitdepth." + ) args = parser.parse_args() # =========================== Parse arguments =========================== # @@ -38,7 +68,21 @@ if use_avx2: from CCLIB.ccdecapi_avx2 import cc_decode_avx2 - print("Using AVX2 instructions for faster decoding") - cc_decode_avx2(args.input, args.output) + + if args.verbosity >= 2: + print("Using AVX2 instructions for faster decoding") + cc_decode_avx2( + args.input, + args.output, + args.output_bitdepth, + args.output_chroma_format, + args.verbosity, + ) else: - cc_decode_cpu(args.input, args.output) \ No newline at end of file + cc_decode_cpu( + args.input, + args.output, + args.output_bitdepth, + args.output_chroma_format, + args.verbosity, + ) diff --git a/coolchic/enc/bitstream/armint.py b/coolchic/enc/bitstream/armint.py new file mode 100644 index 00000000..9767145d --- /dev/null +++ b/coolchic/enc/bitstream/armint.py @@ -0,0 +1,277 @@ +# Software Name: Cool-Chic +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange +# SPDX-License-Identifier: BSD 3-Clause "New" +# +# This software is distributed under the BSD-3-Clause license. +# +# Authors: see CONTRIBUTORS.md + + +"""Fixed point implementation of the ARM to avoid floating point drift.""" + +from typing import OrderedDict, Tuple +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +class ArmIntLinear(nn.Module): + """Create a Linear layer of the Auto-Regressive Module (ARM). This is a + wrapper around the usual ``nn.Linear`` layer of PyTorch, with a custom + initialization. It performs the following operations: + + * :math:`\\mathbf{x}_{out} = \\mathbf{W}\\mathbf{x}_{in} + \\mathbf{b}` if + ``residual`` is ``False`` + + * :math:`\\mathbf{x}_{out} = \\mathbf{W}\\mathbf{x}_{in} + \\mathbf{b} + + \\mathbf{x}_{in}` if ``residual`` is ``True``. + + The input :math:`\\mathbf{x}_{in}` is a :math:`[B, C_{in}]` tensor, the + output :math:`\\mathbf{x}_{out}` is a :math:`[B, C_{out}]` tensor. + + The layer weight and bias shapes are :math:`\\mathbf{W} \\in + \\mathbb{R}^{C_{out} \\times C_{in}}` and :math:`\\mathbf{b} \\in + \\mathbb{R}^{C_{out}}`. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + fpfm: int = 0, + pure_int: bool = False, + residual: bool = False, + ): + """ + Args: + in_channels: Number of input features :math:`C_{in}`. + out_channels: Number of output features :math:`C_{out}`. + fpfm: Internal stuff for integer computation. **No need to modify + this**. Defaults to 0. + residual: True to add a residual connexion to the layer. Defaults to + False. + """ + + super().__init__() + + self.fpfm = fpfm + self.pure_int = pure_int + self.residual = residual + + # -------- Instantiate empty parameters, set by a later load + if self.pure_int: + self.weight = nn.Parameter( + torch.empty((out_channels, in_channels), dtype=torch.int32), + requires_grad=False, + ) + self.bias = nn.Parameter( + torch.empty((out_channels), dtype=torch.int32), requires_grad=False + ) + else: + self.weight = nn.Parameter( + torch.empty((out_channels, in_channels), dtype=torch.float), + requires_grad=False, + ) + self.bias = nn.Parameter( + torch.empty((out_channels), dtype=torch.float), requires_grad=False + ) + # -------- Instantiate empty parameters, set by a later load + + def forward(self, x: Tensor) -> Tensor: + """Perform the forward pass of this layer. + + Args: + x: Input tensor of shape :math:`[B, C_{in}]`. + + Returns: + Tensor with shape :math:`[B, C_{out}]`. + """ + if self.residual: + xx = F.linear(x, self.weight, bias=self.bias) + x * self.fpfm + else: + xx = F.linear(x, self.weight, bias=self.bias) + + # Renorm by fpfm after our (x*fpfm)*(qw*fpfm) multiplication. + # WE MAKE INTEGER DIVISION OBEY C++ (TO-ZERO) SEMANTICS, NOT PYTHON (TO-NEGATIVE-INFINITY) SEMANTICS + if self.pure_int: + xx = xx + torch.sign(xx) * self.fpfm // 2 + # We separate out -ve and non-ve. + neg_result = -((-xx) // self.fpfm) + pos_result = xx // self.fpfm + result = torch.where(xx < 0, neg_result, pos_result) + else: + xx = xx + torch.sign(xx) * self.fpfm / 2 + # We separate out -ve and non-ve. + neg_result = -((-xx) / self.fpfm) + pos_result = xx / self.fpfm + result = torch.where(xx < 0, neg_result, pos_result) + result = result.to(torch.int32).to(torch.float) + + return result + + +class ArmInt(nn.Module): + """Instantiate an autoregressive probability module, modelling the + conditional distribution :math:`p_{\\psi}(\\hat{y}_i \\mid + \\mathbf{c}_i)` of a (quantized) latent pixel :math:`\\hat{y}_i`, + conditioned on neighboring already decoded context pixels + :math:`\\mathbf{c}_i \in \\mathbb{Z}^C`, where :math:`C` denotes the + number of context pixels. + + The distribution :math:`p_{\\psi}` is assumed to follow a Laplace + distribution, parameterized by an expectation :math:`\\mu` and a scale + :math:`b`, where the scale and the variance :math:`\\sigma^2` are + related as follows :math:`\\sigma^2 = 2 b ^2`. + + The parameters of the Laplace distribution for a given latent pixel + :math:`\\hat{y}_i` are obtained by passing its context pixels + :math:`\\mathbf{c}_i` through an MLP :math:`f_{\\psi}`: + + .. math:: + + p_{\\psi}(\\hat{y}_i \\mid \\mathbf{c}_i) \sim \mathcal{L}(\\mu_i, + b_i), \\text{ where } \\mu_i, b_i = f_{\\psi}(\\mathbf{c}_i). + + .. attention:: + + The MLP :math:`f_{\\psi}` has a few constraint on its architecture: + + * The width of all hidden layers (i.e. the output of all layers except + the final one) are identical to the number of pixel contexts + :math:`C`; + + * All layers except the last one are residual layers, followed by a + ``ReLU`` non-linearity; + + * :math:`C` must be at a multiple of 8. + + The MLP :math:`f_{\\psi}` is made of custom Linear layers instantiated + from the ``ArmLinear`` class. + """ + + def __init__( + self, dim_arm: int, n_hidden_layers_arm: int, fpfm: int, pure_int: bool + ): + """ + Args: + dim_arm: Number of context pixels AND dimension of all hidden + layers :math:`C`. + n_hidden_layers_arm: Number of hidden layers. Set it to 0 for + a linear ARM. + """ + super().__init__() + + assert dim_arm % 8 == 0, ( + f"ARM context size and hidden layer dimension must be " + f"a multiple of 8. Found {dim_arm}." + ) + + self.FPFM = fpfm # fixed-point: multiplication to get int. + self.pure_int = pure_int # weights and biases are actual int (cpu only), or just int values in floats (gpu friendly). + + # ======================== Construct the MLP ======================== # + layers_list = nn.ModuleList() + + # Construct the hidden layer(s) + for i in range(n_hidden_layers_arm): + layers_list.append( + ArmIntLinear(dim_arm, dim_arm, self.FPFM, self.pure_int, residual=True) + ) + layers_list.append(nn.ReLU()) + + # Construct the output layer. It always has 2 outputs (mu and scale) + layers_list.append( + ArmIntLinear(dim_arm, 2, self.FPFM, self.pure_int, residual=False) + ) + self.mlp = nn.Sequential(*layers_list) + # ======================== Construct the MLP ======================== # + + def set_param_from_float(self, float_param: OrderedDict[str, Tensor]) -> None: + # We take floating point values here, and convert them to ints. + + # floating point params. We convert to fixed-point integer and store them. + integerised_param = {} + for k in float_param: + if "weight" in k: + float_v = float_param[k] * self.FPFM + else: + float_v = float_param[k] * self.FPFM * self.FPFM + + float_v = float_v + torch.sign(float_v) * 0.5 + neg_result = -(-float_v).to(torch.int32) + pos_result = float_v.to(torch.int32) + int_v = torch.where(float_v < 0, neg_result, pos_result) + if not self.pure_int: + int_v = int_v.to(torch.float) + integerised_param[k] = nn.parameter.Parameter(int_v, requires_grad=False) + + self.load_state_dict(integerised_param, assign=True) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """Perform the auto-regressive module (ARM) forward pass. The ARM takes + as input a tensor of shape :math:`[B, C]` i.e. :math:`B` contexts with + :math:`C` context pixels. ARM outputs :math:`[B, 2]` values correspond + to :math:`\\mu, b` for each of the :math:`B` input pixels. + + .. warning:: + + Note that the ARM expects input to be flattened i.e. spatial + dimensions :math:`H, W` are collapsed into a single batch-like + dimension :math:`B = HW`, leading to an input of shape + :math:`[B, C]`, gathering the :math:`C` contexts for each of the + :math:`B` pixels to model. + + .. note:: + + The ARM MLP does not output directly the scale :math:`b`. Denoting + :math:`s` the raw output of the MLP, the scale is obtained as + follows: + + .. math:: + + b = e^{x - 4} + + Args: + x: Concatenation of all input contexts + :math:`\\mathbf{c}_i`. Tensor of shape :math:`[B, C]`. + + Returns: + Concatenation of all Laplace distributions param :math:`\\mu, b`. + Tensor of shape :math:([B]). Also return the *log scale* + :math:`s` as described above. Tensor of shape :math:`(B)` + """ + xint = x.clone().detach() + xint = xint * self.FPFM + if self.pure_int: + xint = xint.to(torch.int32) + + for idx_l, layer in enumerate(self.mlp.children()): + xint = layer(xint) + + # float the result. + raw_proba_param = xint / self.FPFM + + mu = raw_proba_param[:, 0] + log_scale = raw_proba_param[:, 1] + + # no scale smaller than exp(-4.6) = 1e-2 or bigger than exp(5.01) = 150 + scale = torch.exp(torch.clamp(log_scale - 4, min=-4.6, max=5.0)) + + return mu, scale, log_scale + + def get_param(self) -> OrderedDict[str, Tensor]: + """Return **a copy** of the weights and biases inside the module. + + Returns: + A copy of all weights & biases in the layers. + """ + # Detach & clone to create a copy + return OrderedDict({k: v.detach().clone() for k, v in self.named_parameters()}) + + def set_param(self, param: OrderedDict[str, Tensor]) -> None: + """Replace the current parameters of the module with param. + + Args: + param: Parameters to be set. + """ + self.load_state_dict(param) diff --git a/coolchic/enc/bitstream/encode.py b/coolchic/enc/bitstream/encode.py index 0ac8b1cd..01dc1511 100644 --- a/coolchic/enc/bitstream/encode.py +++ b/coolchic/enc/bitstream/encode.py @@ -6,7 +6,6 @@ # # Authors: see CONTRIBUTORS.md -import math import os import subprocess import time @@ -18,7 +17,7 @@ from enc.bitstream.utils import get_sub_bitstream_path from enc.bitstream.header import write_frame_header, write_gop_header from CCLIB.ccencapi import cc_code_latent_layer_bac, cc_code_wb_bac -from enc.component.core.arm import Arm, ArmInt +from enc.bitstream.armint import ArmInt from enc.component.core.synthesis import Synthesis from enc.component.core.upsampling import Upsampling from enc.component.frame import FrameEncoder @@ -54,7 +53,7 @@ def get_ac_max_val_nn(frame_encoder: FrameEncoder) -> int: # Retrieve all the weights and biases for the ARM MLP for k, v in module_to_encode.named_parameters(): - if k.endswith('.w') or k.endswith('.weight'): + if "weight" in k: cur_possible_q_step = POSSIBLE_Q_STEP.get(cur_module_name).get("weight") # Find the index of the closest quantization step in the list of @@ -74,12 +73,20 @@ def get_ac_max_val_nn(frame_encoder: FrameEncoder) -> int: torch.round((v/FIXED_POINT_FRACTIONAL_MULT) / cur_possible_q_step[cur_q_step_index]).flatten() ) else: + # No longer relevant without the bi-branch synthesis! + # # # Blending -- we get the transformed weight, not the underlying sigmoid parameter. + # # # plus: only if >1 branch. + # # if cur_module_name == "synthesis" and k.endswith(".parametrizations.weight.original"): + # # if "branch_blender" in k and frame_encoder.coolchic_encoder_param.n_synth_branch == 1: + # # continue # Do not emit unused blender weight. + # # xformed_weights = getattr(module_to_encode, k.replace(".parametrizations.weight.original", "")).weight + # # v = xformed_weights model_param_quant.append( torch.round(v / cur_possible_q_step[cur_q_step_index]).flatten() ) - elif k.endswith('.b') or k.endswith('.bias'): + elif "bias" in k: # Find the index of the closest quantization step in the list of # the possible quantization step. cur_possible_q_step = POSSIBLE_Q_STEP.get(cur_module_name).get("bias") @@ -158,9 +165,6 @@ def encode_video(video_encoder: VideoEncoder, bitstream_path: str, hls_sig_blksi # ======================== GOP HEADER ======================== # for idx_coding_order in range(video_encoder.coding_structure.get_number_of_frames()): - frame = video_encoder.coding_structure.get_frame_from_coding_order(idx_coding_order) - # assert frame.already_encoded, f'Frame {frame.display_order} has not been encoded yet!' - # Retrieve the frame encoder corresponding to the frame frame_encoder, _ = video_encoder.all_frame_encoders.get(str(idx_coding_order)) @@ -208,6 +212,12 @@ def encode_frame( frame_encoder.set_to_eval() frame_encoder.to_device('cpu') + # upsampling has bias parameters, but we do not use them. + have_bias = { "arm": True, + "upsampling": False, + "synthesis": True, + } + subprocess.call(f'rm -f {bitstream_path}', shell=True) # Load the references @@ -219,7 +229,6 @@ def encode_frame( # Move to pure-int Arm. Transfer the quantized weights from the fp Arm. arm_fp_param = frame_encoder.coolchic_encoder.arm.get_param() - print("recovered arm params", arm_fp_param.keys()) arm_int = ArmInt( frame_encoder.coolchic_encoder.param.dim_arm, frame_encoder.coolchic_encoder.param.n_hidden_layers_arm, @@ -228,7 +237,6 @@ def encode_frame( ) frame_encoder.coolchic_encoder.arm = arm_int frame_encoder.coolchic_encoder.arm.set_param_from_float(arm_fp_param) - print("set armint(pureint) params") # ================= Encode the MLP into a bitstream file ================ # ac_max_val_nn = get_ac_max_val_nn(frame_encoder) @@ -254,7 +262,7 @@ def encode_frame( Q_STEPS = POSSIBLE_Q_STEP.get(cur_module_name) - if k.endswith(".w") or k.endswith(".weight"): + if "weight" in k: # Find the index of the closest quantization step in the list of # the possible quantization step. cur_possible_q_step = POSSIBLE_Q_STEP.get(cur_module_name).get("weight") @@ -271,7 +279,6 @@ def encode_frame( # Quantize the weight with the actual quantization step and add it # to the list of (quantized) weights - # print(cur_module_name, k, v) if cur_module_name == "arm": # Our weights are stored as fixed point, we use shifts to get the integer values of quantized results. # Our int vals are int(floatval << FPFBITS) @@ -283,11 +290,19 @@ def encode_frame( v = torch.where(v < 0, neg_v, pos_v) weights.append(v.flatten()) else: + # No longer relevant without the bi-branch synth + # # # Blending -- we get the transformed weight, not the underlying sigmoid parameter. + # # # plus: only if >1 branch. + # # if cur_module_name == "synthesis" and k.endswith(".parametrizations.weight.original"): + # # if "branch_blender" in k and frame_encoder.coolchic_encoder_param.n_synth_branch == 1: + # # continue # Do not emit unused blender weight. + # # xformed_weights = getattr(module_to_encode, k.replace(".parametrizations.weight.original", "")).weight + # # v = xformed_weights weights.append( torch.round(v / cur_possible_q_step[cur_q_step_index]).flatten() ) - elif k.endswith(".b") or k.endswith(".bias"): + elif "bias" in k and have_bias[cur_module_name]: # Find the index of the closest quantization step in the list of # the Q_STEPS quantization step. cur_possible_q_step = POSSIBLE_Q_STEP.get(cur_module_name).get("bias") @@ -321,14 +336,15 @@ def encode_frame( # Gather them weights = torch.cat(weights).flatten() - have_bias = len(bias) != 0 - if have_bias: + if have_bias[cur_module_name]: bias = torch.cat(bias).flatten() + else: + q_step_index_nn[cur_module_name]['bias'] = 0 # we actually send this in the header. # ----------------- Actual entropy coding # It happens on cpu weights = weights.cpu() - if have_bias: + if have_bias[cur_module_name]: bias = bias.cpu() cur_bitstream_path = f'{bitstream_path}_{cur_module_name}_weight' @@ -346,7 +362,7 @@ def encode_frame( n_bytes_nn[cur_module_name]['weight'] = os.path.getsize(cur_bitstream_path) - if have_bias: + if have_bias[cur_module_name]: cur_bitstream_path = f'{bitstream_path}_{cur_module_name}_bias' # either code directly (normal), or search for best (backwards compatible). @@ -362,6 +378,7 @@ def encode_frame( n_bytes_nn[cur_module_name]['bias'] = os.path.getsize(cur_bitstream_path) else: + scale_index_nn[cur_module_name]['bias'] = 0 n_bytes_nn[cur_module_name]['bias'] = 0 # ================= Encode the MLP into a bitstream file ================ # @@ -386,18 +403,21 @@ def encode_frame( ) elif module_name == 'upsampling': empty_module = Upsampling( - frame_encoder.coolchic_encoder.param.upsampling_kernel_size, - frame_encoder.coolchic_encoder.param.static_upsampling_kernel + frame_encoder.coolchic_encoder.param.ups_k_size, + frame_encoder.coolchic_encoder.param.ups_preconcat_k_size, + # frame_encoder.coolchic_encoder.param.n_ups_kernel, + frame_encoder.coolchic_encoder.param.latent_n_grids - 1, + # frame_encoder.coolchic_encoder.param.n_ups_preconcat_kernel, + frame_encoder.coolchic_encoder.param.latent_n_grids - 1, ) Q_STEPS = POSSIBLE_Q_STEP.get(module_name) - have_bias = q_step_index_nn[module_name].get('bias') >= 0 loaded_module = decode_network( empty_module, DescriptorNN( weight = f'{bitstream_path}_{module_name}_weight', - bias = f'{bitstream_path}_{module_name}_bias' if have_bias else "", + bias = f'{bitstream_path}_{module_name}_bias' if have_bias[module_name] else "", ), DescriptorNN ( weight=Q_STEPS["weight"][q_step_index_nn[module_name]["weight"]], @@ -408,8 +428,7 @@ def encode_frame( bias=( scale_index_nn[module_name]["bias"] ) - if have_bias - else 0, + if have_bias[module_name] else 0, ), ac_max_val_nn ) @@ -456,7 +475,6 @@ def encode_frame( for index_lat_feature in range(c_i): y_this_ft = current_y[:, index_lat_feature, :, :].flatten().cpu() mu_this_ft = current_mu[:, index_lat_feature, :, :].flatten().cpu() - scale_this_ft = current_scale[:, index_lat_feature, :, :].flatten().cpu() log_scale_this_ft = current_log_scale[:, index_lat_feature, :, :].flatten().cpu() if y_this_ft.abs().max() == 0: diff --git a/coolchic/enc/bitstream/header.py b/coolchic/enc/bitstream/header.py index 63f0137f..0aeb88a9 100644 --- a/coolchic/enc/bitstream/header.py +++ b/coolchic/enc/bitstream/header.py @@ -98,7 +98,7 @@ # Quick & dirty copy and paste from utils.coding_structure _FRAME_DATA_TYPE = ["rgb", "yuv420", "yuv444"] -_POSSIBLE_BITDEPTH = [8, 10] +_POSSIBLE_BITDEPTH = [8, 9, 10, 11, 12, 13, 14, 15, 16] _POSSIBLE_SYNTHESIS_MODE = [k for k in Synthesis.possible_mode] _POSSIBLE_SYNTHESIS_NON_LINEARITY = [k for k in Synthesis.possible_non_linearity] @@ -107,7 +107,7 @@ class GopHeader(TypedDict): n_bytes_header: int # Number of bytes for the header img_size: Tuple[int, int] # Format: (height, width) frame_data_type: FRAME_DATA_TYPE # RGB, YUV 4:2:0, YUV 4:4:4 - bitdepth: POSSIBLE_BITDEPTH # 8 or 10 + bitdepth: POSSIBLE_BITDEPTH # 8 through 16 intra_period: int # See coding_structure.py p_period: int # See coding_structure.py @@ -291,8 +291,12 @@ def write_frame_header( n_bytes_header += ( 1 # Context size and Hidden layer dimension ARM, n. hidden layers ARM ) - n_bytes_header += 1 # Upsampling kernel size, static upsampling kernel flag - n_bytes_header += 1 # Number hidden layer Synthesis + + n_bytes_header += 1 # (n_ups_kernel << 4)|(ups_k_size) + n_bytes_header += 1 # (n_ups_preconcat_kernel << 4)|(ups_preconcat_k_size) + + n_bytes_header += 1 # Number of synthesis branches + n_bytes_header += 1 # Number hidden layer Synthesis per branch # Hidden Synthesis layer out#, kernelsz, mode+nonlinearity n_bytes_header += 3 * len(frame_encoder.coolchic_encoder_param.layers_synthesis) n_bytes_header += 1 # Flow gain @@ -357,17 +361,20 @@ def write_frame_header( + frame_encoder.coolchic_encoder_param.n_hidden_layers_arm ).to_bytes(1, byteorder="big", signed=False) - assert frame_encoder.coolchic_encoder_param.upsampling_kernel_size < 2**7, ( - f"Upsampling" - f" kernel size should be small than {2 ** 7}. Found" - f" {frame_encoder.coolchic_encoder_param.upsampling_kernel_size}" - ) - byte_to_write += ( - frame_encoder.coolchic_encoder_param.upsampling_kernel_size * 2**1 - + int(frame_encoder.coolchic_encoder_param.static_upsampling_kernel) + # (frame_encoder.coolchic_encoder_param.n_ups_kernel<<4)|(frame_encoder.coolchic_encoder_param.ups_k_size) + ((frame_encoder.coolchic_encoder_param.latent_n_grids-1)<<4)|(frame_encoder.coolchic_encoder_param.ups_k_size) + ).to_bytes(1, byteorder="big", signed=False) + byte_to_write += ( + # (frame_encoder.coolchic_encoder_param.n_ups_preconcat_kernel<<4)|(frame_encoder.coolchic_encoder_param.ups_preconcat_k_size) + ((frame_encoder.coolchic_encoder_param.latent_n_grids-1)<<4)|(frame_encoder.coolchic_encoder_param.ups_preconcat_k_size) ).to_bytes(1, byteorder="big", signed=False) + # Continue to send this byte for compatibility + _dummy_n_synth_branch = 1 + byte_to_write += ( + _dummy_n_synth_branch # frame_encoder.coolchic_encoder_param.n_synth_branch + ).to_bytes(1, byteorder="big", signed=False) byte_to_write += len( frame_encoder.coolchic_encoder_param.layers_synthesis ).to_bytes(1, byteorder="big", signed=False) @@ -464,161 +471,3 @@ def write_frame_header( print("expected", n_bytes_header) print("got", os.path.getsize(header_path)) exit(1) - - -def read_frame_header(bitstream: bytes) -> FrameHeader: - """Read the first few bytes of a bitstream file located at - and parse the different information. - - Args: - bitstream_path (str): Path where the bitstream is located. - - Returns: - FrameHeader: The parsed info from the bitstream. - """ - - ptr = 0 - n_bytes_header = int.from_bytes( - bitstream[ptr : ptr + 2], byteorder="big", signed=False - ) - ptr += 2 - - display_index = int.from_bytes( - bitstream[ptr : ptr + 1], byteorder="big", signed=False - ) - ptr += 1 - - raw_arm = int.from_bytes(bitstream[ptr : ptr + 1], byteorder="big", signed=False) - ptr += 1 - dim_arm = 8 * (raw_arm // (2**4)) - n_hidden_layers_arm = raw_arm % (2**4) - - raw_upsampling = int.from_bytes( - bitstream[ptr : ptr + 1], byteorder="big", signed=False - ) - ptr += 1 - upsampling_kernel_size = raw_upsampling // (2**1) - static_upsampling_kernel = bool(raw_upsampling % (2**1)) - - n_hidden_dim_synthesis = int.from_bytes( - bitstream[ptr : ptr + 1], byteorder="big", signed=False - ) - ptr += 1 - layers_synthesis = [] - for i in range(n_hidden_dim_synthesis): - out_ft = int.from_bytes(bitstream[ptr : ptr + 1], byteorder="big", signed=False) - ptr += 1 - kernel_size = int.from_bytes( - bitstream[ptr : ptr + 1], byteorder="big", signed=False - ) - ptr += 1 - mode_non_linearity = int.from_bytes( - bitstream[ptr : ptr + 1], byteorder="big", signed=False - ) - ptr += 1 - mode = _POSSIBLE_SYNTHESIS_MODE[mode_non_linearity // 16] - non_linearity = _POSSIBLE_SYNTHESIS_NON_LINEARITY[mode_non_linearity % 16] - layers_synthesis.append(f"{out_ft}-{kernel_size}-{mode}-{non_linearity}") - - flow_gain = int.from_bytes(bitstream[ptr : ptr + 1], byteorder="big", signed=False) - ptr += 1 - - ac_max_val_nn = int.from_bytes( - bitstream[ptr : ptr + 2], byteorder="big", signed=False - ) - ptr += 2 - ac_max_val_latent = int.from_bytes( - bitstream[ptr : ptr + 2], byteorder="big", signed=False - ) - ptr += 2 - hls_sig_blksize = int.from_bytes(bitstream[ptr: ptr + 1], byteorder='big', signed=True) - ptr += 1 - - q_step_index_nn: DescriptorCoolChic = {} - for nn_name in ["arm", "upsampling", "synthesis"]: - q_step_index_nn[nn_name] = {} - for nn_param in ["weight", "bias"]: - q_step_index_nn[nn_name][nn_param] = int.from_bytes( - bitstream[ptr : ptr + 1], byteorder="big", signed=False - ) - # # # Hack -- 255 -> -1 for upsampling bias. Indicating no bias. - # # if q_step_index_nn[nn_name][nn_param] == 255: - # # q_step_index_nn[nn_name][nn_param] = -1 - # # print("got -1:", nn_name, nn_param) - ptr += 1 - - scale_index_nn: DescriptorCoolChic = {} - for nn_name in ["arm", "upsampling", "synthesis"]: - scale_index_nn[nn_name] = {} - for nn_param in ["weight", "bias"]: - # if q_step_index_nn[nn_name][nn_param] < 0: - # scale_index_nn[nn_name][nn_param] = -1 - # else: - scale_index_nn[nn_name][nn_param] = int.from_bytes( - bitstream[ptr : ptr + 1], byteorder="big", signed=False - ) - ptr += 1 - - n_bytes_nn: DescriptorCoolChic = {} - for nn_name in ["arm", "upsampling", "synthesis"]: - n_bytes_nn[nn_name] = {} - for nn_param in ["weight", "bias"]: - # if q_step_index_nn[nn_name][nn_param] < 0: - # n_bytes_nn[nn_name][nn_param] = -1 - # else: - n_bytes_nn[nn_name][nn_param] = int.from_bytes( - bitstream[ptr : ptr + 2], byteorder="big", signed=False - ) - ptr += 2 - - latent_n_resolutions = int.from_bytes( - bitstream[ptr : ptr + 1], byteorder="big", signed=False - ) - ptr += 1 - latent_n_2d_grid = int.from_bytes( - bitstream[ptr : ptr + 1], byteorder="big", signed=False - ) - ptr += 1 - - n_ft_per_latent = [] - for _ in range(latent_n_resolutions): - n_ft_per_latent.append( - int.from_bytes(bitstream[ptr : ptr + 1], byteorder="big", signed=False) - ) - ptr += 1 - - n_bytes_per_latent = [] - for _ in range(latent_n_2d_grid): - n_bytes_per_latent.append( - int.from_bytes(bitstream[ptr : ptr + 3], byteorder="big", signed=False) - ) - ptr += 3 - - header_info: FrameHeader = { - "n_bytes_header": n_bytes_header, - "latent_n_resolutions": latent_n_resolutions, - "latent_n_2d_grid": latent_n_2d_grid, - "n_bytes_per_latent": n_bytes_per_latent, - "n_ft_per_latent": n_ft_per_latent, - "n_hidden_layers_arm": n_hidden_layers_arm, - "dim_arm": dim_arm, - "upsampling_kernel_size": upsampling_kernel_size, - "static_upsampling_kernel": static_upsampling_kernel, - "flow_gain": flow_gain, - "layers_synthesis": layers_synthesis, - "q_step_index_nn": q_step_index_nn, - "scale_index_nn": scale_index_nn, - "n_bytes_nn": n_bytes_nn, - "ac_max_val_nn": ac_max_val_nn, - "ac_max_val_latent": ac_max_val_latent, - "hls_sig_blksize": hls_sig_blksize, - "display_index": display_index, - } - - print("\nContent of the frame header:") - print("------------------------------") - for k, v in header_info.items(): - print(f"{k:>20}: {v}") - print(" ------------------------") - - return header_info diff --git a/coolchic/enc/component/coolchic.py b/coolchic/enc/component/coolchic.py index 13ee545a..dfb6a7e7 100644 --- a/coolchic/enc/component/coolchic.py +++ b/coolchic/enc/component/coolchic.py @@ -11,10 +11,12 @@ from dataclasses import dataclass, field, fields from typing import Any, Dict, List, Optional, OrderedDict, Tuple, TypedDict +from enc.visu.console import pretty_string_nn, pretty_string_ups from torch import nn, Tensor import torch from fvcore.nn import FlopCountAnalysis, flop_count_table + from enc.component.core.arm import ( Arm, _get_neighbor, @@ -72,11 +74,13 @@ class CoolChicEncoderParameter: n_hidden_layers_arm (int, Optional): Number of hidden layers in the ARM. Set ``n_hidden_layers_arm = 0`` for a linear ARM. Defaults to 2. - upsampling_kernel_size (int, Optional): Kernel size for the upsampler. - See the :doc:`upsampling documentation ` for more - information. Defaults to 8. - static_upsampling_kernel (bool, Optional): Set this flag to ``True`` to - prevent learning the upsampling kernel. Defaults to ``False``. + ups_k_size (int, Optional): Upsampling kernel size for the transposed + convolutions. See the :doc:`upsampling documentation ` + for more information. Defaults to 8. + ups_preconcat_k_size (int, Optional): Upsampling kernel size for the + pre-concatenation convolutions. See the + :doc:`upsampling documentation ` for more + information. Defaults to 7. encoder_gain (int, Optional): Multiply the latent by this value before quantization. See the documentation of Cool-chic forward pass. Defaults to 16. @@ -85,9 +89,9 @@ class CoolChicEncoderParameter: n_ft_per_res: List[int] dim_arm: int = 24 n_hidden_layers_arm: int = 2 - upsampling_kernel_size: int = 8 - static_upsampling_kernel: bool = False encoder_gain: int = 16 + ups_k_size: int = 8 + ups_preconcat_k_size: int = 7 # ==================== Not set by the init function ===================== # #: Automatically computed, number of different latent resolutions @@ -192,7 +196,13 @@ def __init__(self, param: CoolChicEncoderParameter): # ===================== Upsampling stuff ===================== # self.upsampling = Upsampling( - self.param.upsampling_kernel_size, self.param.static_upsampling_kernel + ups_k_size=self.param.ups_k_size, + ups_preconcat_k_size=self.param.ups_preconcat_k_size, + # Instantiate one different upsampling and pre-concatenation + # filters for each of the upsampling step. Could also be set to one + # to share the same filter across all latents. + n_ups_kernel=self.param.latent_n_grids - 1, + n_ups_preconcat_kernel=self.param.latent_n_grids - 1, ) # ===================== Upsampling stuff ===================== # @@ -239,17 +249,19 @@ def __init__(self, param: CoolChicEncoderParameter): self.arm = Arm(self.param.dim_arm, self.param.n_hidden_layers_arm) # ===================== ARM related stuff ==================== # + # Something like ['arm', 'synthesis', 'upsampling'] + self.modules_to_send = [tmp.name for tmp in fields(DescriptorCoolChic)] + # ======================== Monitoring ======================== # # Pretty string representing the decoder complexity self.flops_str = "" # Total number of multiplications to decode the image self.total_flops = 0.0 + self.flops_per_module = {k: 0 for k in self.modules_to_send} # Fill the two attributes aboves self.get_flops() # ======================== Monitoring ======================== # - # Something like ['arm', 'synthesis', 'upsampling'] - self.modules_to_send = [tmp.name for tmp in fields(DescriptorCoolChic)] # Track the quantization step of each neural network, None if the # module is not yet quantized @@ -413,6 +425,7 @@ def forward( additional_data["detailed_log_scale"] = [] additional_data["detailed_rate_bit"] = [] additional_data["detailed_centered_latent"] = [] + additional_data["hpfilters"] = [] # "Pointer" for the reading of the 1D scale, mu and rate cnt = 0 @@ -586,6 +599,9 @@ def get_flops(self) -> None: # print("Ignoring get_flops") # Count the number of floating point operations here. It must be done before # torch scripting the different modules. + + self = self.train(mode=False) + flops = FlopCountAnalysis( self, ( @@ -601,9 +617,14 @@ def get_flops(self) -> None: flops.uncalled_modules_warnings(False) self.total_flops = flops.total() + for k in self.flops_per_module: + self.flops_per_module[k] = flops.by_module()[k] + self.flops_str = flop_count_table(flops) del flops + self = self.train(mode=True) + def get_network_rate(self) -> DescriptorCoolChic: """Return the rate (in bits) associated to the parameters (weights and biases) of the different modules @@ -706,3 +727,63 @@ def to_device(self, device: POSSIBLE_DEVICE) -> None: if hasattr(layer, "qb"): if layer.qb is not None: self.arm.mlp[idx_layer].qb = layer.qb.to(device) + + def pretty_string(self) -> str: + """Get a pretty string representing the layer of a ``CoolChicEncoder``""" + + s = "" + + if not self.flops_str: + self.get_flops() + + n_pixels = self.param.img_size[-2] * self.param.img_size[-1] + total_mac_per_pix = self.get_total_mac_per_pixel() + + + title = f"Cool-chic architecture {total_mac_per_pix:.0f} MAC / pixel" + s += ( + f"\n{title}\n" + f"{'-' * len(title)}\n\n" + ) + + complexity = self.flops_per_module['upsampling'] / n_pixels + share_complexity = 100 * complexity / total_mac_per_pix + title = f"Upsampling {complexity:.0f} MAC/pixel ; {share_complexity:.1f} % of the complexity" + s += ( + f"{title}\n" + f"{'=' * len(title)}\n" + "Note: all upsampling layers are separable and symmetric " + "(transposed) convolutions.\n\n" + + ) + s += pretty_string_ups(self.upsampling, "") + + complexity = self.flops_per_module['arm'] / n_pixels + share_complexity = 100 * complexity / total_mac_per_pix + title = f"ARM {complexity:.0f} MAC/pixel ; {share_complexity:.1f} % of the complexity" + s += ( + f"\n\n\n{title}\n" + f"{'=' * len(title)}\n\n\n" + + ) + input_arm = f"{self.arm.dim_arm}-pixel context" + output_arm = "mu, log scale" + s += pretty_string_nn( + self.arm.mlp, "", input_arm, output_arm + ) + + complexity = self.flops_per_module['synthesis'] / n_pixels + share_complexity = 100 * complexity / total_mac_per_pix + title = f"Synthesis {complexity:.0f} MAC/pixel ; {share_complexity:.1f} % of the complexity" + s += ( + f"\n\n\n{title}\n" + f"{'=' * len(title)}\n\n\n" + + ) + input_syn = f"{self.synthesis.input_ft} features" + output_syn = "Decoded image" + s += pretty_string_nn( + self.synthesis.layers, "", input_syn, output_syn + ) + + return s diff --git a/coolchic/enc/component/core/arm.py b/coolchic/enc/component/core/arm.py index c5999faf..0878fd4f 100644 --- a/coolchic/enc/component/core/arm.py +++ b/coolchic/enc/component/core/arm.py @@ -50,6 +50,8 @@ def __init__( super().__init__() self.residual = residual + self.in_channels = in_channels + self.out_channels = out_channels # -------- Instantiate empty parameters, set by the initialize function self.weight = nn.Parameter( @@ -95,94 +97,6 @@ def forward(self, x: Tensor) -> Tensor: else: return F.linear(x, self.weight, bias=self.bias) -class ArmIntLinear(nn.Module): - """Create a Linear layer of the Auto-Regressive Module (ARM). This is a - wrapper around the usual ``nn.Linear`` layer of PyTorch, with a custom - initialization. It performs the following operations: - - * :math:`\\mathbf{x}_{out} = \\mathbf{W}\\mathbf{x}_{in} + \\mathbf{b}` if - ``residual`` is ``False`` - - * :math:`\\mathbf{x}_{out} = \\mathbf{W}\\mathbf{x}_{in} + \\mathbf{b} + - \\mathbf{x}_{in}` if ``residual`` is ``True``. - - The input :math:`\\mathbf{x}_{in}` is a :math:`[B, C_{in}]` tensor, the - output :math:`\\mathbf{x}_{out}` is a :math:`[B, C_{out}]` tensor. - - The layer weight and bias shapes are :math:`\\mathbf{W} \\in - \\mathbb{R}^{C_{out} \\times C_{in}}` and :math:`\\mathbf{b} \\in - \\mathbb{R}^{C_{out}}`. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - fpfm: int = 0, - pure_int: bool = False, - residual: bool = False, - ): - """ - Args: - in_channels: Number of input features :math:`C_{in}`. - out_channels: Number of output features :math:`C_{out}`. - fpfm: Internal stuff for integer computation. **No need to modify - this**. Defaults to 0. - residual: True to add a residual connexion to the layer. Defaults to - False. - """ - - super().__init__() - - self.fpfm = fpfm - self.pure_int = pure_int - self.residual = residual - - # -------- Instantiate empty parameters, set by a later load - if self.pure_int: - self.weight = nn.Parameter( - torch.empty((out_channels, in_channels), dtype=torch.int32), requires_grad=False - ) - self.bias = nn.Parameter(torch.empty((out_channels), dtype=torch.int32), requires_grad=False) - else: - self.weight = nn.Parameter( - torch.empty((out_channels, in_channels), dtype=torch.float), requires_grad=False - ) - self.bias = nn.Parameter(torch.empty((out_channels), dtype=torch.float), requires_grad=False) - # -------- Instantiate empty parameters, set by a later load - - - def forward(self, x: Tensor) -> Tensor: - """Perform the forward pass of this layer. - - Args: - x: Input tensor of shape :math:`[B, C_{in}]`. - - Returns: - Tensor with shape :math:`[B, C_{out}]`. - """ - if self.residual: - xx = F.linear(x, self.weight, bias=self.bias) + x*self.fpfm - else: - xx = F.linear(x, self.weight, bias=self.bias) - - # Renorm by fpfm after our (x*fpfm)*(qw*fpfm) multiplication. - # WE MAKE INTEGER DIVISION OBEY C++ (TO-ZERO) SEMANTICS, NOT PYTHON (TO-NEGATIVE-INFINITY) SEMANTICS - if self.pure_int: - xx = xx + torch.sign(xx)*self.fpfm//2 - # We separate out -ve and non-ve. - neg_result = -((-xx)//self.fpfm) - pos_result = xx//self.fpfm - result = torch.where(xx < 0, neg_result, pos_result) - else: - xx = xx + torch.sign(xx)*self.fpfm/2 - # We separate out -ve and non-ve. - neg_result = -((-xx)/self.fpfm) - pos_result = xx/self.fpfm - result = torch.where(xx < 0, neg_result, pos_result) - result = result.to(torch.int32).to(torch.float) - - return result class Arm(nn.Module): """Instantiate an autoregressive probability module, modelling the @@ -223,7 +137,6 @@ class Arm(nn.Module): from the ``ArmLinear`` class. """ - def __init__(self, dim_arm: int, n_hidden_layers_arm: int): """ Args: @@ -238,6 +151,7 @@ def __init__(self, dim_arm: int, n_hidden_layers_arm: int): f"ARM context size and hidden layer dimension must be " f"a multiple of 8. Found {dim_arm}." ) + self.dim_arm = dim_arm # ======================== Construct the MLP ======================== # layers_list = nn.ModuleList() @@ -317,165 +231,6 @@ def reinitialize_parameters(self) -> None: if isinstance(layer, ArmLinear): layer.initialize_parameters() -class ArmInt(nn.Module): - """Instantiate an autoregressive probability module, modelling the - conditional distribution :math:`p_{\\psi}(\\hat{y}_i \\mid - \\mathbf{c}_i)` of a (quantized) latent pixel :math:`\\hat{y}_i`, - conditioned on neighboring already decoded context pixels - :math:`\\mathbf{c}_i \in \\mathbb{Z}^C`, where :math:`C` denotes the - number of context pixels. - - The distribution :math:`p_{\\psi}` is assumed to follow a Laplace - distribution, parameterized by an expectation :math:`\\mu` and a scale - :math:`b`, where the scale and the variance :math:`\\sigma^2` are - related as follows :math:`\\sigma^2 = 2 b ^2`. - - The parameters of the Laplace distribution for a given latent pixel - :math:`\\hat{y}_i` are obtained by passing its context pixels - :math:`\\mathbf{c}_i` through an MLP :math:`f_{\\psi}`: - - .. math:: - - p_{\\psi}(\\hat{y}_i \\mid \\mathbf{c}_i) \sim \mathcal{L}(\\mu_i, - b_i), \\text{ where } \\mu_i, b_i = f_{\\psi}(\\mathbf{c}_i). - - .. attention:: - - The MLP :math:`f_{\\psi}` has a few constraint on its architecture: - - * The width of all hidden layers (i.e. the output of all layers except - the final one) are identical to the number of pixel contexts - :math:`C`; - - * All layers except the last one are residual layers, followed by a - ``ReLU`` non-linearity; - - * :math:`C` must be at a multiple of 8. - - The MLP :math:`f_{\\psi}` is made of custom Linear layers instantiated - from the ``ArmLinear`` class. - """ - - def __init__(self, dim_arm: int, n_hidden_layers_arm: int, fpfm: int, pure_int: bool): - """ - Args: - dim_arm: Number of context pixels AND dimension of all hidden - layers :math:`C`. - n_hidden_layers_arm: Number of hidden layers. Set it to 0 for - a linear ARM. - """ - super().__init__() - - assert dim_arm % 8 == 0, ( - f"ARM context size and hidden layer dimension must be " - f"a multiple of 8. Found {dim_arm}." - ) - - self.FPFM = fpfm # fixed-point: multiplication to get int. - self.pure_int = pure_int # weights and biases are actual int (cpu only), or just int values in floats (gpu friendly). - - # ======================== Construct the MLP ======================== # - layers_list = nn.ModuleList() - - # Construct the hidden layer(s) - for i in range(n_hidden_layers_arm): - layers_list.append(ArmIntLinear(dim_arm, dim_arm, self.FPFM, self.pure_int, residual=True)) - layers_list.append(nn.ReLU()) - - # Construct the output layer. It always has 2 outputs (mu and scale) - layers_list.append(ArmIntLinear(dim_arm, 2, self.FPFM, self.pure_int, residual=False)) - self.mlp = nn.Sequential(*layers_list) - # ======================== Construct the MLP ======================== # - - def set_param_from_float(self, float_param: OrderedDict[str, Tensor]) -> None: - # We take floating point values here, and convert them to ints. - - # floating point params. We convert to fixed-point integer and store them. - integerised_param = {} - for k in float_param: - if "weight" in k: - float_v = float_param[k]*self.FPFM - else: - float_v = float_param[k]*self.FPFM*self.FPFM - - float_v = float_v + torch.sign(float_v)*0.5 - neg_result = -(-float_v).to(torch.int32) - pos_result = float_v.to(torch.int32) - int_v = torch.where(float_v < 0, neg_result, pos_result) - if not self.pure_int: - int_v = int_v.to(torch.float) - integerised_param[k] = nn.parameter.Parameter(int_v, requires_grad=False) - - self.load_state_dict(integerised_param, assign=True) - - def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: - """Perform the auto-regressive module (ARM) forward pass. The ARM takes - as input a tensor of shape :math:`[B, C]` i.e. :math:`B` contexts with - :math:`C` context pixels. ARM outputs :math:`[B, 2]` values correspond - to :math:`\\mu, b` for each of the :math:`B` input pixels. - - .. warning:: - - Note that the ARM expects input to be flattened i.e. spatial - dimensions :math:`H, W` are collapsed into a single batch-like - dimension :math:`B = HW`, leading to an input of shape - :math:`[B, C]`, gathering the :math:`C` contexts for each of the - :math:`B` pixels to model. - - .. note:: - - The ARM MLP does not output directly the scale :math:`b`. Denoting - :math:`s` the raw output of the MLP, the scale is obtained as - follows: - - .. math:: - - b = e^{x - 4} - - Args: - x: Concatenation of all input contexts - :math:`\\mathbf{c}_i`. Tensor of shape :math:`[B, C]`. - - Returns: - Concatenation of all Laplace distributions param :math:`\\mu, b`. - Tensor of shape :math:([B]). Also return the *log scale* - :math:`s` as described above. Tensor of shape :math:`(B)` - """ - xint = x.clone().detach() - xint = xint*self.FPFM - if self.pure_int: - xint = xint.to(torch.int32) - - for idx_l, layer in enumerate(self.mlp.children()): - xint = layer(xint) - - # float the result. - raw_proba_param = xint / self.FPFM - - mu = raw_proba_param[:, 0] - log_scale = raw_proba_param[:, 1] - - # no scale smaller than exp(-4.6) = 1e-2 or bigger than exp(5.01) = 150 - scale = torch.exp(torch.clamp(log_scale - 4, min=-4.6, max=5.0)) - - return mu, scale, log_scale - - def get_param(self) -> OrderedDict[str, Tensor]: - """Return **a copy** of the weights and biases inside the module. - - Returns: - A copy of all weights & biases in the layers. - """ - # Detach & clone to create a copy - return OrderedDict({k: v.detach().clone() for k, v in self.named_parameters()}) - - def set_param(self, param: OrderedDict[str, Tensor]) -> None: - """Replace the current parameters of the module with param. - - Args: - param: Parameters to be set. - """ - self.load_state_dict(param) @torch.jit.script def _get_neighbor(x: Tensor, mask_size: int, non_zero_pixel_ctx_idx: Tensor) -> Tensor: @@ -565,7 +320,7 @@ def _get_non_zero_pixel_ctx_index(dim_arm: int) -> Tensor: Returns: Tensor: 1D tensor with the flattened index of the context pixels. """ - + # fmt: off if dim_arm == 8: return torch.tensor( [ 13, @@ -606,3 +361,4 @@ def _get_non_zero_pixel_ctx_index(dim_arm: int) -> Tensor: 36, 37, 38, 39, # ] ) + # fmt: on diff --git a/coolchic/enc/component/core/quantizer.py b/coolchic/enc/component/core/quantizer.py index 6193c2f7..25646187 100644 --- a/coolchic/enc/component/core/quantizer.py +++ b/coolchic/enc/component/core/quantizer.py @@ -169,7 +169,6 @@ def quantize( Quantized tensor """ # ----- Check user input - # TODO: How long is it to do such assert? assert quantizer_noise_type in typing.get_args(POSSIBLE_QUANTIZATION_NOISE_TYPE), ( f"quantizer_noise_type must be in {POSSIBLE_QUANTIZATION_NOISE_TYPE}" f" found {quantizer_noise_type}" @@ -226,7 +225,6 @@ def quantize( # From the forward point of view (i.e. entering into the torch.no_grad()), we have # y = softround(x) - softround(x) + round(x) = round(x). From the backward point of view # we have y = softround(x) meaning that dy / dx = d softround(x) / dx. - # TODO: check whether it works? y = softround(x, soft_round_temperature) with torch.no_grad(): y = y - softround(x, soft_round_temperature) + torch.round(x) diff --git a/coolchic/enc/component/core/synthesis.py b/coolchic/enc/component/core/synthesis.py index 0d56ad06..a4440fca 100644 --- a/coolchic/enc/component/core/synthesis.py +++ b/coolchic/enc/component/core/synthesis.py @@ -46,8 +46,11 @@ def __init__( """ super().__init__() - self.pad = int((kernel_size - 1) / 2) self.residual = residual + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.pad = int((kernel_size - 1) / 2) # -------- Instantiate empty parameters, set by the initialize function self.groups = 1 # Hardcoded for now @@ -168,6 +171,9 @@ def __init__(self, input_ft: int, layers_dim: List[str]): following the notation detailed above. """ super().__init__() + + self.synth_branches = nn.ModuleList() + self.input_ft = input_ft layers_list = nn.ModuleList() # Construct the hidden layer(s) @@ -211,6 +217,7 @@ def forward(self, x: Tensor) -> Tensor: """ return self.layers(x) + def get_param(self) -> OrderedDict[str, Tensor]: """Return **a copy** of the weights and biases inside the module. @@ -229,7 +236,7 @@ def set_param(self, param: OrderedDict[str, Tensor]): self.load_state_dict(param) def reinitialize_parameters(self) -> None: - """Re-initialize in place the parameters of all the SynthesisConv2d layer.""" + """Re-initialize in place the params of all the ``SynthesisConv2d`` layers.""" for layer in self.layers.children(): if isinstance(layer, SynthesisConv2d): layer.initialize_parameters() diff --git a/coolchic/enc/component/core/upsampling.py b/coolchic/enc/component/core/upsampling.py index f74ae844..c3b9b4f5 100644 --- a/coolchic/enc/component/core/upsampling.py +++ b/coolchic/enc/component/core/upsampling.py @@ -11,127 +11,293 @@ import torch import torch.nn.functional as F +import torch.nn.utils.parametrize as parametrize from einops import rearrange from torch import Tensor, nn -class UpsamplingConvTranspose2d(nn.Module): - """Wrapper around the usual ``nn.TransposeConv2d`` layer. It performs a 2x - upsampling of a latent variable with a **single** input and output channel. - It can be learned or not, depending on the flag - ``static_upsampling_kernel``. Its initialization depends on the requested - kernel size. If the kernel size is 4 or 6, we use the bilinear kernel with - zero padding if necessary. Otherwise, if the kernel size is 8 or bigger, we - rely on the bicubic kernel. - """ +class _Parameterization_Symmetric_1d(nn.Module): + """This module is not meant to be instantiated. It should rather be used + through the ``torch.nn.utils.parametrize.register_parametrization()`` + function to reparameterize a N-element vector into a 2N-element (or 2N+1) + symmetric vector. For instance: - kernel_bilinear = torch.tensor( - [ - [0.0625, 0.1875, 0.1875, 0.0625], - [0.1875, 0.5625, 0.5625, 0.1875], - [0.1875, 0.5625, 0.5625, 0.1875], - [0.0625, 0.1875, 0.1875, 0.0625], - ] - ) - - kernel_bicubic = torch.tensor( - [ - [ 0.0012359619 , 0.0037078857 ,-0.0092010498 ,-0.0308990479 ,-0.0308990479 ,-0.0092010498 , 0.0037078857 , 0.0012359619], - [ 0.0037078857 , 0.0111236572 ,-0.0276031494 ,-0.0926971436 ,-0.0926971436 ,-0.0276031494 , 0.0111236572 , 0.0037078857], - [-0.0092010498 ,-0.0276031494 , 0.0684967041 , 0.2300262451 , 0.2300262451 , 0.0684967041 ,-0.0276031494 ,-0.0092010498], - [-0.0308990479 ,-0.0926971436 , 0.2300262451 , 0.7724761963 , 0.7724761963 , 0.2300262451 ,-0.0926971436 ,-0.0308990479], - [-0.0308990479 ,-0.0926971436 , 0.2300262451 , 0.7724761963 , 0.7724761963 , 0.2300262451 ,-0.0926971436 ,-0.0308990479], - [-0.0092010498 ,-0.0276031494 , 0.0684967041 , 0.2300262451 , 0.2300262451 , 0.0684967041 ,-0.0276031494 ,-0.0092010498], - [ 0.0037078857 , 0.0111236572 ,-0.0276031494 ,-0.0926971436 ,-0.0926971436 ,-0.0276031494 , 0.0111236572 , 0.0037078857], - [ 0.0012359619 , 0.0037078857 ,-0.0092010498 ,-0.0308990479 ,-0.0308990479 ,-0.0092010498 , 0.0037078857 , 0.0012359619], - ] - ) + * x = a b c and target_k_size = 5 --> a b c b a + * x = a b c and target_k_size = 6 --> a b c c b a + Both these 5-element or 6-element vectors can be parameterize through + a 3-element representation (a, b, c). + """ - def __init__( - self, - upsampling_kernel_size: int, - static_upsampling_kernel: bool - ): + def __init__(self, target_k_size: int): """ Args: - upsampling_kernel_size: Upsampling kernel size. Should be >= 4 - and a multiple of two. - static_upsampling_kernel: If true, don't learn the upsampling kernel. + target_k_size: Target size of the kernel after reparameterization. """ - super().__init__() - assert upsampling_kernel_size >= 4, ( - f"Upsampling kernel size should be >= 4." f"Found {upsampling_kernel_size}" + super().__init__() + self.target_k_size = target_k_size + self.param_size = _Parameterization_Symmetric_1d.size_param_from_target( + self.target_k_size ) - assert upsampling_kernel_size % 2 == 0, ( - f"Upsampling kernel size should be even." f"Found {upsampling_kernel_size}" + def forward(self, x: Tensor) -> Tensor: + """Return a longer, symmetric vector by concatenating x with a flipped + version of itself. + + Args: + x (Tensor): [N] tensor. + + Returns: + Tensor: [2N] or [2N + 1] tensor, depending on self.target_k_size + """ + + # torch.fliplr requires to have a 2D kernel + x_reversed = torch.fliplr(x.view(1, -1)).view(-1) + + kernel = torch.cat( + [ + x, + # a b c c b a if n is even or a b c b a if n is odd + x_reversed[self.target_k_size % 2 :], + ], ) - self.upsampling_kernel_size = upsampling_kernel_size - self.static_upsampling_kernel = static_upsampling_kernel + return kernel + + + @classmethod + def size_param_from_target(cls, target_k_size: int) -> int: + """Return the size of the appropriate parameterization of a + symmetric tensor with target_k_size elements. For instance: + + target_k_size = 6 ; parameterization size = 3 e.g. (a b c c b a) + + target_k_size = 7 ; parameterization size = 4 e.g. (a b c d c b a) + + Args: + target_k_size (int): Size of the actual symmetric 1D kernel. + + Returns: + int: Size of the underlying parameterization. + """ + # For a kernel of size target_k_size = 2N, we need N values + # e.g. 3 params a b c to parameterize a b c c b a. + # For a kernel of size target_k_size = 2N + 1, we need N + 1 values + # e.g. 4 params a b c d to parameterize a b c d c b a. + return (target_k_size + 1) // 2 + + + +class UpsamplingSeparableSymmetricConv2d(nn.Module): + """ + A conv2D which has a separable and symmetric *odd* kernel. + + Separable means that the 2D-kernel :math:`\mathbf{w}_{2D}` can be expressed + as the outer product of a 1D kernel :math:`\mathbf{w}_{1D}`: + + .. math:: + + \mathbf{w}_{2D} = \mathbf{w}_{1D} \otimes \mathbf{w}_{1D}. + + The 1D kernel :math:`\mathbf{w}_{1D}` is also symmetric. That is, the 1D + kernel is something like :math:`\mathbf{w}_{1D} = \left(a\ b\ c\ b\ a\ + \\right).` + + The symmetric constraint is obtained through the module + ``_Parameterization_Symmetric_1d``. The separable constraint is obtained by + calling twice the 1D kernel. + """ + def __init__(self, kernel_size: int): + """ + kernel_size: Size of the kernel :math:`\mathbf{w}_{1D}` e.g. 7 to + obtain a symmetrical, separable 7x7 filter. Must be odd! + """ + super().__init__() + + assert ( + kernel_size % 2 == 1 + ), f"Upsampling kernel size must be odd, found {kernel_size}." + + self.target_k_size = kernel_size + self.param_size = _Parameterization_Symmetric_1d.size_param_from_target( + self.target_k_size + ) # -------- Instantiate empty parameters, set by the initialize function self.weight = nn.Parameter( - torch.empty(1, 1, upsampling_kernel_size, upsampling_kernel_size), - requires_grad=True, + torch.empty(self.param_size), requires_grad=True ) - self.bias = nn.Parameter(torch.empty((1)), requires_grad=True) + + self.bias = nn.Parameter(torch.empty(1), requires_grad=True) self.initialize_parameters() # -------- Instantiate empty parameters, set by the initialize function - # Keep initial weights if required by the self.static_upsampling kernel flag - if self.static_upsampling_kernel: - # register_buffer for automatic device management. We set persistent to false - # to simply use the "automatically move to device" function, without - # considering non_zero_pixel_ctx_index as a parameters (i.e. returned - # by self.parameters()) - self.register_buffer("static_kernel", self.weight.data.clone(), persistent=False) - else: - self.static_kernel = None + # Each time we call .weight, we'll call the forward of + # _Parameterization_Symmetric_1d to get a symmetric kernel. + parametrize.register_parametrization( + self, + "weight", + _Parameterization_Symmetric_1d(target_k_size=self.target_k_size), + # Unsafe because we change the data dimension, from N to 2N + 1 + unsafe=True, + ) def initialize_parameters(self) -> None: """ - Initialize **in-place ** the weights and the biases of the transposed - convolution layer performing the upsampling. + Initialize the weights and the biases of the transposed convolution + layer performing the upsampling. - - Biases are always set to zero. + * Biases are always set to zero. - - Weights are set to a (padded) bicubic kernel if kernel size is at - least 8. If kernel size is greater than or equal to 4, weights are - set to a (padded) bilinear kernel. + * Weights are set to :math:`(0,\ 0,\ 0,\ \ldots, 1)` so that when the + symmetric reparameterization is applied a Dirac kernel is obtained e.g. + :math:`(0,\ 0,\ 0,\ \ldots, 1, \ldots, 0,\ 0,\ 0,)`. """ - # -------- bias is always set to zero (and in fact never ever used) + # Zero everywhere except for the last coef + w = torch.zeros_like(self.weight) + w[-1] = 1 + self.weight = nn.Parameter(w, requires_grad=True) + self.bias = nn.Parameter(torch.zeros_like(self.bias), requires_grad=True) - # -------- Weights are initialized to bicubic or bilinear - # adapted filter size - K = self.upsampling_kernel_size - self.upsampling_padding = (K // 2, K // 2, K // 2, K // 2) - self.upsampling_crop = (3 * K - 2) // 2 + def forward(self, x: Tensor) -> Tensor: + """Perform a "normal" 2D convolution, except that the underlying kernel + is both separable & symmetrical. The actual implementation of the forward + depends on ``self.training``. + + If we're training, we use a non-separable implementation. That is, we + first compute the 2D kernel through an outer product and then use a + single 2D convolution. This is more stable. - if K < 8: - kernel_init = UpsamplingConvTranspose2d.kernel_bilinear + If we're not training, we use two successive 1D convolutions. + + .. warning:: + + There is a residual connexion in the forward. + + Args: + x: [B, 1, H, W] tensor to be filtered. Must have one + only channel. + + Returns: + Tensor: Filtered tensor [B, 1, H, W]. + """ + k = self.weight.size()[0] + weight = self.weight.view(1, -1) + padding = k // 2 + + # Train using non-separable (more stable) + if self.training: + # Kronecker product of (1 k) & (k 1) --> (k, k). + # Then, two dummy dimensions are added to be compliant with conv2d + # (k, k) --> (1, 1, k, k). + kernel_2d = torch.kron(weight, weight.T).view((1, 1, k, k)) + + # ! Note the residual connexion! + return F.conv2d(x, kernel_2d, bias=None, stride=1, padding=padding) + x + + # Test through separable (less complex, for the flop counter) else: - kernel_init = UpsamplingConvTranspose2d.kernel_bicubic - - # pad initial filter according to desired kernel size - tmpad = (K - kernel_init.size()[0]) // 2 - upsampling_kernel = F.pad( - kernel_init.clone().detach(), - (tmpad, tmpad, tmpad, tmpad), - mode="constant", - value=0.0, + yw = F.conv2d(x, weight.view((1, 1, 1, k)), padding=(0, padding)) + + # ! Note the residual connexion! + return F.conv2d(yw, weight.view((1, 1, k, 1)), padding=(padding, 0)) + x + + +class UpsamplingSeparableSymmetricConvTranspose2d(nn.Module): + """ + A TransposedConv2D which has a separable and symmetric *even* kernel. + + Separable means that the 2D-kernel :math:`\mathbf{w}_{2D}` can be expressed + as the outer product of a 1D kernel :math:`\mathbf{w}_{1D}`: + + .. math:: + + \mathbf{w}_{2D} = \mathbf{w}_{1D} \otimes \mathbf{w}_{1D}. + + The 1D kernel :math:`\mathbf{w}_{1D}` is also symmetric. That is, the 1D + kernel is something like :math:`\mathbf{w}_{1D} = \left(a\ b\ c\ c\ b\ a\ + \\right).` + + The symmetric constraint is obtained through the module + ``_Parameterization_Symmetric_1d``. The separable constraint is obtained by + calling twice the 1D kernel. + """ + + def __init__(self, kernel_size: int): + """ + Args: + kernel_size: Upsampling kernel size. Shall be even and >= 4. + """ + super().__init__() + + assert kernel_size >= 4 and not kernel_size % 2, ( + f"Upsampling kernel size shall be even and โ‰ฅ4. Found {kernel_size}" ) - # 4D kernel to be compatible with transpose convolution - upsampling_kernel = rearrange(upsampling_kernel, "k_h k_w -> 1 1 k_h k_w") - self.weight = nn.Parameter(upsampling_kernel, requires_grad=True) + self.target_k_size = kernel_size + self.param_size = _Parameterization_Symmetric_1d.size_param_from_target( + self.target_k_size + ) + + # -------- Instantiate empty parameters, set by the initialize function + self.weight = nn.Parameter( + torch.empty(self.param_size), requires_grad=True + ) + + self.bias = nn.Parameter(torch.empty(1), requires_grad=True) + self.initialize_parameters() + # -------- Instantiate empty parameters, set by the initialize function + + # Each time we call .weight, we'll call the forward of + # _Parameterization_Symmetric_1d to get a symmetric kernel. + parametrize.register_parametrization( + self, + "weight", + _Parameterization_Symmetric_1d(target_k_size=self.target_k_size), + # Unsafe because we change the data dimension, from N to 2N + 1 + unsafe=True, + ) + + def initialize_parameters(self) -> None: + """Initialize the parameters of a + ``UpsamplingSeparableSymmetricConvTranspose2d`` layer. + + * Biases are always set to zero. + + * Weights are initialize as a (possibly padded) bilinear filter when + ``target_k_size`` is 4 or 6, otherwise a bicubic filter is used. + """ + # For a target kernel size of 4 or 6, we use a bilinear kernel as the + # initialization. For bigger kernels, a bicubic kernel is used. In both + # case we just initialize the left half of the kernel since these + # filters are symmetrical + if self.target_k_size < 8: + kernel_core = torch.tensor([1.0 / 4.0, 3.0 / 4.0]) + else: + kernel_core = torch.tensor([0.0351562, 0.1054687, -0.2617187, -0.8789063]) + + # If target_k_size = 6, then param_size = 3 while kernel_core = 2 + # Thus we need to add zero_pad = 1 to the left of the kernel. + zero_pad = self.param_size - kernel_core.size()[0] + w = torch.zeros_like(self.weight) + w[zero_pad:] = kernel_core + self.weight = nn.Parameter(w, requires_grad=True) + + self.bias = nn.Parameter(torch.zeros_like(self.bias), requires_grad=True) def forward(self, x: Tensor) -> Tensor: """Perform the spatial upsampling (with scale 2) of an input with a - single channel. + single channel. Note that the upsampling filter is both symmetrical and + separable. The actual implementation of the forward depends on + ``self.training``. + + If we're training, we use a non-separable implementation. That is, we + first compute the 2D kernel through an outer product and then use a + single 2D convolution. This is more stable. + + If we're not training, we use two successive 1D convolutions. Args: x: Single channel input with shape :math:`(B, 1, H, W)` @@ -139,23 +305,47 @@ def forward(self, x: Tensor) -> Tensor: Returns: Upsampled version of the input with shape :math:`(B, 1, 2H, 2W)` """ - upsampling_weight = ( - self.static_kernel if self.static_upsampling_kernel else self.weight - ) - x_pad = F.pad(x, self.upsampling_padding, mode="replicate") - y_conv = F.conv_transpose2d(x_pad, upsampling_weight, stride=2) + k = self.target_k_size # kernel size + P0 = k // 2 # could be 0 or k//2 as in legacy implementation + C = 2 * P0 - 1 + k // 2 # crop side border k - 1 + k//2 (k=4, C=5 k=8, C=11) + + weight = self.weight.view(1, -1) + + if self.training: # training using non-separable (more stable) + kernel_2d = (torch.kron(weight, weight.T).view((1, 1, k, k))) + + x_pad = F.pad(x, (P0, P0, P0, P0), mode="replicate") + yc = F.conv_transpose2d(x_pad, kernel_2d, stride=2) + + # crop to remove padding in convolution + H, W = yc.size()[-2:] + y = yc[ + :, + :, + C : H - C, + C : W - C, + ] + + else: # testing through separable (less complex) + # horizontal filtering + x_pad = F.pad(x, (P0, P0, 0, 0), mode="replicate") + yc = F.conv_transpose2d(x_pad, weight.view((1, 1, 1, k)), stride=(1, 2)) + W = yc.size()[-1] + y = yc[ + :, + :, + :, + C : W - C, + ] - # crop to remove padding in convolution - H, W = y_conv.size()[-2:] - results = y_conv[ - :, - :, - self.upsampling_crop : H - self.upsampling_crop, - self.upsampling_crop : W - self.upsampling_crop, - ] + # vertical filtering + x_pad = F.pad(y, (0, 0, P0, P0), mode="replicate") + yc = F.conv_transpose2d(x_pad, weight.view((1, 1, k, 1)), stride=(2, 1)) + H = yc.size()[-2] + y = yc[:, :, C : H - C, :] - return results + return y class Upsampling(nn.Module): @@ -176,42 +366,103 @@ class Upsampling(nn.Module): \hat{\mathbf{z}} \\in \\mathbb{R}^{C \\times H \\times W} \\text { and } C = \\sum_i C_i. - The upsampling relies on a single custom transpose convolution - ``UpsamplingConvTranspose2d`` performing a 2x upsampling of a 1-channel - input. This transpose convolution is called over and over to upsampling - each channel of each resolution until they reach the required :math:`H - \\times W` dimensions. + For a toy example with 3 latent grids (``--n_ft_per_res=1,1,1``), the + overall diagram of the upsampling is as follows. + + .. code:: + + +---------+ + y0 -> | TConv2d | -----+ + +---------+ | + v + +--------+ +-----+ +---------+ + y1 -> | Conv2d | -> | cat | -> | TConv2d | -----+ + +--------+ +-----+ +---------+ | + v + +--------+ +-----+ +---------+ + y2 ----------------------------> | Conv2d | -> | cat | -> | TConv2d | -> dense + +--------+ +-----+ +---------+ + + Where ``y0`` has the smallest resolution, ``y1`` has a resolution double of + ``y0`` etc. + + There are two different sets of filters: + + * The TConvs filters actually perform the x2 upsampling. They are + referred to as upsampling filters. Implemented using + ``UpsamplingSeparableSymmetricConvTranspose2d``. + + * The Convs filters pre-process the signal prior to concatenation. They + are referred to as pre-concatenation filters. Implemented using + ``UpsamplingSeparableSymmetricConv2d``. + + Kernel sizes for the upsampling and pre-concatenation filters are modified + through the ``--ups_k_size`` and ``--ups_preconcat_k_size`` arguments. + + Each upsampling filter and each pre-concatenation filter is different. They + are all separable and symmetrical. - The kernel of the ``UpsamplingConvTranspose2d`` depending on the value - of the flag ``static_upsampling_kernel``. In either case, the kernel - initialization is based on well-known bilinear or bicubic kernel - depending on the requested ``upsampling_kernel_size``: + Upsampling convolutions are initialized with a bilinear or bicubic kernel + depending on the required requested ``ups_k_size``: - * If ``upsampling_kernel_size >= 4 and upsampling_kernel_size < 8``, a + * If ``ups_k_size >= 4 and ups_k_size < 8``, a bilinear kernel (with zero padding if necessary) is used an initialization. - * If ``upsampling_kernel_size >= 8``, a bicubic kernel (with zero padding if + * If ``ups_k_size >= 8``, a bicubic kernel (with zero padding if necessary) is used an initialization. - .. warning:: + Pre-concatenation convolutions are initialized with a Dirac kernel. - The ``upsampling_kernel_size`` must be at least 4 and a multiple of 2. - """ + .. warning:: + + * The ``ups_k_size`` must be at least 4 and a multiple of 2. - def __init__(self, upsampling_kernel_size: int, static_upsampling_kernel: bool): + * The ``ups_preconcat_k_size`` must be odd. + """ + def __init__( + self, + ups_k_size: int, + ups_preconcat_k_size: int, + n_ups_kernel: int, + n_ups_preconcat_kernel: int, + ): """ Args: - upsampling_kernel_size: Upsampling kernel size. Should be bigger or - equal to 4 and a multiple of two. - static_upsampling_kernel: If true, don't learn the upsampling - kernel. + ups_k_size: Upsampling (TransposedConv) kernel size. Should be + even and >= 4. + ups_preconcat_k_size: Pre-concatenation kernel size. Should be odd. + n_ups_kernel: Number of different upsampling kernels. Usually it is + set to the number of latent - 1 (because the full resolution + latent is not upsampled). But this can also be set to one to + share the same kernel across all variables. + n_ups_preconcat_kernel: Number of different pre-concatenation + filters. Usually it is set to the number of latent - 1 (because + the smallest resolution is not filtered prior to concat). + But this can also be set to one to share the same kernel across + all variables. """ super().__init__() - self.conv_transpose2d = UpsamplingConvTranspose2d( - upsampling_kernel_size, static_upsampling_kernel + # number of kernels for the lower and higher branches + self.n_ups_kernel = n_ups_kernel + self.n_ups_preconcat_kernel = n_ups_preconcat_kernel + + # Upsampling kernels = transpose conv2d + self.conv_transpose2ds = nn.ModuleList( + [ + UpsamplingSeparableSymmetricConvTranspose2d(ups_k_size) + for _ in range(n_ups_kernel) + ] + ) + + # Pre concatenation filters = conv2d + self.conv2ds = nn.ModuleList( + [ + UpsamplingSeparableSymmetricConv2d(ups_preconcat_k_size) + for _ in range(self.n_ups_preconcat_kernel) + ] ) def forward(self, decoder_side_latent: List[Tensor]) -> Tensor: @@ -231,14 +482,19 @@ def forward(self, decoder_side_latent: List[Tensor]) -> Tensor: # so that the same convolution is applied independently on the batch dimension. latent_reversed = list(reversed(decoder_side_latent)) upsampled_latent = latent_reversed[0] # start from smallest - for target_tensor in latent_reversed[1:]: + + for idx, target_tensor in enumerate(latent_reversed[1:]): # Our goal is to upsample to the same resolution than x = rearrange(upsampled_latent, "b c h w -> (b c) 1 h w") - x = self.conv_transpose2d(x) + x = self.conv_transpose2ds[idx % self.n_ups_kernel](x) + x = rearrange(x, "(b c) 1 h w -> b c h w", b=upsampled_latent.shape[0]) # Crop to comply with higher resolution feature maps size before concatenation x = x[:, :, : target_tensor.shape[-2], : target_tensor.shape[-1]] - upsampled_latent = torch.cat((target_tensor, x), dim=1) + + high_branch = self.conv2ds[idx % self.n_ups_preconcat_kernel](target_tensor) + upsampled_latent = torch.cat((high_branch, x), dim=1) + return upsampled_latent def get_param(self) -> OrderedDict[str, Tensor]: @@ -260,4 +516,7 @@ def set_param(self, param: OrderedDict[str, Tensor]): def reinitialize_parameters(self) -> None: """Re-initialize **in place** the parameters of the upsampling.""" - self.conv_transpose2d.initialize_parameters() + for i in range(len(self.conv_transpose2ds)): + self.conv_transpose2d[i].initialize_parameters() + for i in range(len(self.conv2ds)): + self.conv2ds[i].initialize_parameters() diff --git a/coolchic/enc/component/frame.py b/coolchic/enc/component/frame.py index 9b69c3dd..da3a77f4 100644 --- a/coolchic/enc/component/frame.py +++ b/coolchic/enc/component/frame.py @@ -7,7 +7,7 @@ # Authors: see CONTRIBUTORS.md -"""A frame encoder is composed of a CoolChicEncoder and a InterCodingModule.""" +"""A frame encoder is composed of a CoolChicEncoder and an InterCodingModule.""" import typing from dataclasses import dataclass, field @@ -24,16 +24,11 @@ POSSIBLE_QUANTIZER_TYPE, ) from enc.component.intercoding import InterCodingModule -from torch import Tensor, nn -from enc.utils.codingstructure import ( - FRAME_DATA_TYPE, - FRAME_TYPE, - POSSIBLE_BITDEPTH, - DictTensorYUV, - convert_444_to_420, -) +from enc.io.format.data_type import FRAME_DATA_TYPE, POSSIBLE_BITDEPTH +from enc.io.format.yuv import DictTensorYUV, convert_444_to_420, yuv_dict_clamp +from enc.utils.codingstructure import FRAME_TYPE from enc.utils.misc import POSSIBLE_DEVICE -from enc.utils.yuv import yuv_dict_clamp +from torch import Tensor, nn @dataclass @@ -274,15 +269,15 @@ def save(self) -> BytesIO: } if self.coolchic_encoder.full_precision_param is not None: - data_to_save["coolchic_full_precision_param"] = self.coolchic_encoder.full_precision_param + data_to_save["coolchic_full_precision_param"] = ( + self.coolchic_encoder.full_precision_param + ) torch.save(data_to_save, buffer) - # for k, v in self.coolchic_encoder.get_param().items(): - # print(f"{k:>30}: {v.abs().sum().item()}") - return buffer + def load_frame_encoder(raw_bytes: BytesIO) -> FrameEncoder: """From already loaded raw bytes, load & return a CoolChicEncoder @@ -295,7 +290,7 @@ def load_frame_encoder(raw_bytes: BytesIO) -> FrameEncoder: """ # Reset the stream position to the beginning of the BytesIO object & load it raw_bytes.seek(0) - loaded_data = torch.load(raw_bytes, map_location="cpu") + loaded_data = torch.load(raw_bytes, map_location="cpu", weights_only=False) # Create a frame encoder from the stored parameters frame_encoder = FrameEncoder( @@ -311,9 +306,13 @@ def load_frame_encoder(raw_bytes: BytesIO) -> FrameEncoder: # Check if coolchic_nn_expgol_cnt is present in loaded data for backward # compatibility. Not meant to stay very long. if "coolchic_nn_expgol_cnt" in loaded_data: - frame_encoder.coolchic_encoder.nn_expgol_cnt = loaded_data["coolchic_nn_expgol_cnt"] + frame_encoder.coolchic_encoder.nn_expgol_cnt = loaded_data[ + "coolchic_nn_expgol_cnt" + ] if "coolchic_full_precision_param" in loaded_data: - frame_encoder.coolchic_encoder.full_precision_param = loaded_data["coolchic_full_precision_param"] + frame_encoder.coolchic_encoder.full_precision_param = loaded_data[ + "coolchic_full_precision_param" + ] return frame_encoder diff --git a/coolchic/enc/component/video.py b/coolchic/enc/component/video.py index 7affdf96..6588ddc9 100644 --- a/coolchic/enc/component/video.py +++ b/coolchic/enc/component/video.py @@ -23,7 +23,7 @@ from enc.training.warmup import warmup from enc.utils.codingstructure import CodingStructure, Frame, FrameData from enc.utils.misc import POSSIBLE_DEVICE, TrainingExitCode, is_job_over, mem_info -from enc.utils.yuv import load_frame_data_from_file +from enc.io.io import load_frame_data_from_file class VideoEncoder(): @@ -168,6 +168,8 @@ def encode( + "-" * 80 ) + print("\n" + frame.data.to_string() + "\n") + # ----- Set the parameters for the frame frame_encoder_manager = copy.deepcopy( self.shared_frame_encoder_manager @@ -226,6 +228,8 @@ def encode( f_out.write(str(list_candidates[0].coolchic_encoder) + "\n\n") f_out.write(list_candidates[0].coolchic_encoder.str_complexity() + "\n") + print(list_candidates[0].coolchic_encoder.pretty_string() + "\n\n") + # Use warm-up to find the best initialization among the list # of candidates parameters. frame_encoder = warmup( @@ -502,7 +506,7 @@ def load_video_encoder(load_path: str) -> VideoEncoder: """ print(f"Loading a video encoder from {load_path}") - raw_data = torch.load(load_path, map_location="cpu") + raw_data = torch.load(load_path, map_location="cpu", weights_only=False) # Calling the VideoEncoder constructor automatically reload the # original frames. diff --git a/coolchic/enc/io/format/data_type.py b/coolchic/enc/io/format/data_type.py new file mode 100644 index 00000000..32c36c3b --- /dev/null +++ b/coolchic/enc/io/format/data_type.py @@ -0,0 +1,14 @@ +# Software Name: Cool-Chic +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange +# SPDX-License-Identifier: BSD 3-Clause "New" +# +# This software is distributed under the BSD-3-Clause license. +# +# Authors: see CONTRIBUTORS.md + + +from typing import Literal + + +FRAME_DATA_TYPE = Literal["rgb", "yuv420", "yuv444"] +POSSIBLE_BITDEPTH = Literal[8, 9, 10, 11, 12, 13, 14, 15, 16] diff --git a/coolchic/enc/io/format/png.py b/coolchic/enc/io/format/png.py new file mode 100644 index 00000000..d7332046 --- /dev/null +++ b/coolchic/enc/io/format/png.py @@ -0,0 +1,50 @@ +# Software Name: Cool-Chic +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange +# SPDX-License-Identifier: BSD 3-Clause "New" +# +# This software is distributed under the BSD-3-Clause license. +# +# Authors: see CONTRIBUTORS.md + + +import os +from typing import Tuple + +from einops import rearrange +import torch +from enc.io.format.data_type import POSSIBLE_BITDEPTH +from PIL import Image +from torch import Tensor +from torchvision.transforms.functional import to_pil_image, to_tensor + + +def read_png(file_path: str) -> Tuple[Tensor, POSSIBLE_BITDEPTH]: + """Read a PNG file + + Args: + file_path: Path of the png file to read. + + Returns: + Image data [1, 3, H, W] in [0., 1.] and its bitdepth. + """ + assert os.path.isfile(file_path), f"No file found at {file_path}." + + data = to_tensor(Image.open(file_path)) + data = rearrange(data, "c h w -> 1 c h w") + + # Bitdepth is always 8 when we read PNG through PIL? + bitdepth = 8 + + return data, bitdepth + + +@torch.no_grad() +def write_png(data: Tensor, file_path: str) -> None: + """Save an image x into a PNG file. + + Args: + x: Image to be saved + file_path: Where to save the PNG files + """ + data = rearrange(data, "1 c h w -> c h w", c=3) + to_pil_image(data).save(file_path) diff --git a/coolchic/enc/io/format/ppm.py b/coolchic/enc/io/format/ppm.py new file mode 100644 index 00000000..88e236e7 --- /dev/null +++ b/coolchic/enc/io/format/ppm.py @@ -0,0 +1,205 @@ +# Software Name: Cool-Chic +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange +# SPDX-License-Identifier: BSD 3-Clause "New" +# +# This software is distributed under the BSD-3-Clause license. +# +# Authors: see CONTRIBUTORS.md + + +import math +import os +from typing import Tuple + +import numpy as np +import torch +from torch import Tensor + +from enc.io.format.data_type import POSSIBLE_BITDEPTH + + +def _skip_one_byte(data: bytearray) -> bytearray: + """Skip one byte in a byte array and return the array + with its first byte removed. + + Args: + data: Input byte array. Length is N + + Returns: + bytearray: Output byte array. Length is N - 1 + """ + return data[1:] + + +def _read_int_until_blank(data: bytearray) -> Tuple[int, bytearray]: + """Parse an ASCII int until running into one of the space characters. + Also return the input byte array where the bytes corresponding to + both the int and the space character are skipped. + + 132\ncds# --> Return the value 123 and a byte array containing cds# + + + As defined by the ANSI standard C isspace(s) returns True for + Horizontal tab (HT), line feed (LF), vertical tabulation (VT), + form feed (FF), carriage return (CR) and white space. + + Note: this function may fail if the bytes collected up to the space + character do not represent ascii numbers. + + Args: + data: Input data. + + Returns: + Parsed int + input data where we've removed the bytes corresponding to + the int value and the space character. + """ + # As defined by the ANSI standard C isspace(s) + # ASCII code for Horizontal tab (HT), line feed (LF), + # vertical tabulation (VT), form feed (FF), carriage return (CR) + # and white space + _BLANKS_ASCII = [9, 10, 11, 12, 13, 32] + + ptr_end = 0 + while data[ptr_end] not in _BLANKS_ASCII: + ptr_end += 1 + + value = int(data[:ptr_end].decode("utf-8")) + data = data[ptr_end:] + return value, data + + +def _16bits_byte_swap(data: Tensor) -> Tensor: + """Invert the bytes composing a 2-byte value. The actual data type + of the tensor is not important but it must contains value in + [0, 2 ** 16 - 1] + + For instance: + + 1111 1111 0000 0010 ==> 0000 0010 1111 1111 + \______/ \______/ \______/ \______/ + MSB LSB LSB MSB + + Args: + data: Tensor to be swapped. + + Returns: + Swapped tensor. + """ + + msb = data // 2**8 + lsb = data % 2**8 + swapped_data = lsb * 2**8 + msb + return swapped_data + + +def read_ppm(file_path: str) -> Tuple[Tensor, POSSIBLE_BITDEPTH]: + """Read a `PPM file `_, + and return a torch tensor [1, 3, H, W] containing the data. + The returned tensor is in [0., 1.] so its bitdepth is also returned. + + .. attention:: + + We don't filter out comments inside PPM files... + + Args: + file_path: Path of the ppm file to read. + + Returns: + Image data [1, 3, H, W] in [0., 1.] and its bitdepth. + """ + assert os.path.isfile(file_path), f"No file found at {file_path}." + + data = open(file_path, "rb").read() + magic_number = data[:2].decode("utf-8") + data = data[2:] + assert magic_number == "P6", ( + "Invalid file format. PPM file should start with P6. " f"Found {magic_number}." + ) + + # Parse the header + width, data = _read_int_until_blank(_skip_one_byte(data)) + height, data = _read_int_until_blank(_skip_one_byte(data)) + max_val, data = _read_int_until_blank(_skip_one_byte(data)) + data = _skip_one_byte(data) + + n_bytes_per_val = 1 if max_val <= 255 else 2 + bitdepth = int(math.log2(max_val + 1)) + + raw_value = torch.from_numpy( + np.frombuffer( + data, + count=3 * width * height, + dtype=np.uint8 if n_bytes_per_val == 1 else np.uint16, + ) + ) + + # Re-arrange the value from R1 B1 G1 R2 B2 G2 to an usual [B, C, H, W] array + img = torch.empty( + (1, 3, height, width), + dtype=torch.float32, + ) + + for i in range(3): + img[:, i, :, :] = raw_value[i::3].view(1, height, width) + + # In a PPM file 2-byte value (e.g. 257) is represented as + # 1111 1111 0000 0010 + # \______/ \______/ + # MSB LSB + # We want to invert these two bytes here to have an usual binary value + # 0000 0010 1111 1111 + if n_bytes_per_val == 2: + img = _16bits_byte_swap(img) + + # Normalize in [0. 1.] + img = img / (2**bitdepth - 1) + + return img, bitdepth + + +@torch.no_grad() +def write_ppm( + data: Tensor, bitdepth: POSSIBLE_BITDEPTH, file_path: str, norm: bool = True +) -> None: + """Save an image x into a PPM file. + + Args: + data: Image to be saved + bitdepth: Bitdepth, should be in + ``[8, 9, 10, 11, 12, 13, 14, 15, 16]``. + file_path: Where to save the PPM files + bitdepth: Bitdepth of the file. Defaults to 8. + norm: True to multiply the data by 2 ** bitdepth - 1. Defaults to True. + """ + # Remove all first dimensions of size 1 + c, h, w = data.size()[-3:] + data = data.view((c, h, w)) + + max_val = 2**data.bitdepth - 1 + n_bytes_per_val = 1 if max_val <= 255 else 2 + header = f"P6\n{w} {h}\n{max_val}\n" + + if norm: + data = torch.round(data * (2**bitdepth - 1)) + + # In a PPM file 2-byte value (e.g. 257) is represented as + # 1111 1111 0000 0010 + # \______/ \______/ + # MSB LSB + # We want to invert these two bytes here to have an usual binary value + # 0000 0010 1111 1111 + if n_bytes_per_val == 2: + data = _16bits_byte_swap(data) + + # Format data as expected by the PPM file. + flat_data = torch.empty( + (c * h * w), dtype=torch.uint8 if max_val <= 255 else torch.uint16 + ) + for i in range(c): + flat_data[i::3] = data[i, :, :].flatten() + + # Write once the header as a string then the data as binary bytes + with open(file_path, "w") as f_out: + f_out.write(header) + with open(file_path, "ab") as f_out: + f_out.write(np.memmap.tobytes(flat_data.numpy())) diff --git a/coolchic/enc/io/format/yuv.py b/coolchic/enc/io/format/yuv.py new file mode 100644 index 00000000..15d17ad7 --- /dev/null +++ b/coolchic/enc/io/format/yuv.py @@ -0,0 +1,311 @@ +# Software Name: Cool-Chic +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange +# SPDX-License-Identifier: BSD 3-Clause "New" +# +# This software is distributed under the BSD-3-Clause license. +# +# Authors: see CONTRIBUTORS.md + + +import os +from typing import TypedDict, Union + +import numpy as np +import torch +import torch.nn.functional as F +from enc.io.format.data_type import FRAME_DATA_TYPE, POSSIBLE_BITDEPTH +from enc.utils.misc import POSSIBLE_DEVICE +from torch import Tensor + + +class DictTensorYUV(TypedDict): + """``TypedDict`` representing a YUV420 frame.. + + .. hint:: + + ``torch.jit`` requires I/O of modules to be either ``Tensor``, ``List`` + or ``Dict``. So we don't use a python dataclass here and rely on + ``TypedDict`` instead. + + Args: + y (Tensor): Tensor with shape :math:`([B, 1, H, W])`. + u (Tensor): Tensor with shape :math:`([B, 1, \\frac{H}{2}, \\frac{W}{2}])`. + v (Tensor): Tensor with shape :math:`([B, 1, \\frac{H}{2}, \\frac{W}{2}])`. + """ + + y: Tensor + u: Tensor + v: Tensor + + +def read_yuv( + file_path: str, + frame_idx: int, + frame_data_type: FRAME_DATA_TYPE, + bit_depth: POSSIBLE_BITDEPTH, +) -> Union[DictTensorYUV, Tensor]: + """From a file_path /a/b/c.yuv, read the desired frame_index + and return a dictionary of tensor containing the YUV values: + + .. code:: none + + { + 'Y': [1, 1, H, W], + 'U': [1, 1, H / S, W / S], + 'V': [1, 1, H / S, W / S], + } + + ``S`` is either 1 (444 sampling) or 2 (420). The YUV values are in [0., 1.] + + Args: + file_path: Absolute path of the video to load + frame_idx: Index of the frame to load, starting at 0. + frame_data_type: chroma sampling (420,444) + bit depth: Number of bits per component (8 or 10 bits). + + Returns: + For 420, return a dict of tensors with YUV values of shape [1, 1, H, W]. + For 444 return a [1, 3, H, W] tensor. + """ + + # Parse height and width from the file_path + w, h = [ + int(tmp_str) + for tmp_str in os.path.basename(file_path).split(".")[0].split("_")[1].split("x") + ] + + if frame_data_type == "yuv420": + w_uv, h_uv = [int(x / 2) for x in [w, h]] + else: + w_uv, h_uv = w, h + + # Switch between 8 bit file and 10 bit + byte_per_value = 1 if bit_depth == 8 else 2 + + # We only handle YUV420 for now + n_val_y = h * w + n_val_uv = h_uv * w_uv + n_val_per_frame = n_val_y + 2 * n_val_uv + + n_bytes_y = n_val_y * byte_per_value + n_bytes_uv = n_val_uv * byte_per_value + n_bytes_per_frame = n_bytes_y + 2 * n_bytes_uv + + # Read the required frame and put it in a 1d tensor + raw_video = torch.tensor( + np.memmap( + file_path, + mode="r", + shape=n_val_per_frame, + offset=n_bytes_per_frame * frame_idx, + dtype=np.uint16 if bit_depth == 10 else np.uint8, + ).astype(np.float32) + ) + + # Read the different values from raw video and store them inside y, u and v + ptr = 0 + y = raw_video[ptr : ptr + n_val_y].view(1, 1, h, w) + ptr += n_val_y + u = raw_video[ptr : ptr + n_val_uv].view(1, 1, h_uv, w_uv) + ptr += n_val_uv + v = raw_video[ptr : ptr + n_val_uv].view(1, 1, h_uv, w_uv) + + # PyTorch expect data in [0., 1.]; normalize by either 255 or 1023 + norm_factor = 2**bit_depth - 1 + + if frame_data_type == "yuv420": + video = DictTensorYUV(y=y / norm_factor, u=u / norm_factor, v=v / norm_factor) + else: + video = torch.cat([y, u, v], dim=1) / norm_factor + + return video + + +@torch.no_grad() +def write_yuv( + data: Union[Tensor, DictTensorYUV], + bitdepth: POSSIBLE_BITDEPTH, + frame_data_type: FRAME_DATA_TYPE, + file_path: str, + norm: bool = True, +) -> None: + """Store a YUV frame as a YUV file named file_path. They are appended to the + end of the file_path If norm is True: the video data is expected to be in + [0., 1.] so we multiply it by 255. Otherwise we let it as is. + + Args: + data: Data to save + bitdepth: Bitdepth, should be in``[8, 9, 10, 11, 12, 13, 14, 15, 16]``. + frame_data_type: Data type, either ``"yuv420"`` or ``"yuv444"``. + file_path: Absolute path of the file where the YUV is saved. + norm: True to multiply the data by 2 ** bitdepth - 1. + Defaults to True. + """ + assert frame_data_type in ["yuv420", "yuv444"], ( + f"Found incorrect datatype in write_yuv() function: {frame_data_type}. " + 'Data type should be "yuv420" or "yuv444".' + ) + + if frame_data_type == "yuv420": + raw_data = torch.cat([channels.flatten() for _, channels in data.items()]) + else: + raw_data = data.flatten() + + if norm: + raw_data = raw_data * (2**bitdepth - 1) + + dtype = np.uint16 if bitdepth == 10 else np.uint8 + + # Round the values and cast them to uint8 or uint16 tensor + raw_data = torch.round(raw_data).cpu().numpy().astype(dtype) + + # # We need to add a p avec yuv444 otherwise YUView thinks its "YUV444 8-bit packed" + # file_path = ( + # f"{file_path}_{w}x{h}_{DUMMY_FRAMERATE}fps_{frame_data_type}p_{bitdepth}b.yuv" + # ) + + # Write this to the desired file_path + np.memmap.tofile(raw_data, file_path) + + +def rgb2yuv(rgb: Tensor) -> Tensor: + """Convert a 4D RGB tensor [1, 3, H, W] into a 4D YUV444 tensor [1, 3, H, W]. + The RGB and YUV values are in the range [0, 255] + + Args: + rgb: 4D RGB tensor to convert in [0. 255.] + + Returns: + The resulting YUV444 tensor in [0. 255.] + """ + assert ( + len(rgb.size()) == 4 + ), f"rgb2yuv input must be a 4D tensor [B, 3, H, W]. Data size: {rgb.size()}" + assert ( + rgb.size()[1] == 3 + ), f"rgb2yuv input must have 3 channels. Data size: {rgb.size()}" + + # Split the [1, 3, H, W] into 3 [1, 1, H, W] tensors + r, g, b = rgb.split(1, dim=1) + + # Compute the different channels + y = torch.round(0.299 * r + 0.587 * g + 0.114 * b) + u = torch.round(-0.1687 * r - 0.3313 * g + 0.5 * b + +128) + v = torch.round(0.5 * r - 0.4187 * g - 0.0813 * b + 128) + + # Concatenate them into the resulting yuv 4D tensor. + yuv = torch.cat((y, u, v), dim=1) + return yuv + + +def yuv2rgb(yuv: Tensor): + """Convert a 4D YUV tensor [1, 3, H, W] into a 4D RGB tensor [1, 3, H, W]. + The RGB and YUV values are in the range [0, 255] + + Args: + rgb: 4D YUV444 tensor to convert in [0. 255.] + + Returns: + The resulting RGB tensor in [0. 255.] + """ + assert ( + len(yuv.size()) == 4 + ), f"yuv2rgb input must be a 4D tensor [B, 3, H, W]. Data size: {yuv.size()}" + assert ( + yuv.size()[1] == 3 + ), f"yuv2rgb input must have 3 channels. Data size: {yuv.size()}" + + y, u, v = yuv.split(1, dim=1) + r = ( + 1.0 * y + + -0.000007154783816076815 * u + + 1.4019975662231445 * v + - 179.45477266423404 + ) + g = 1.0 * y + -0.3441331386566162 * u + -0.7141380310058594 * v + 135.45870971679688 + b = ( + 1.0 * y + + 1.7720025777816772 * u + + 0.00001542569043522235 * v + - 226.8183044444304 + ) + rgb = torch.cat((r, g, b), dim=1) + return rgb + + +def yuv_dict_clamp(yuv: DictTensorYUV, min_val: float, max_val: float) -> DictTensorYUV: + """Clamp the y, u & v tensor. + + Args: + yuv: The data to clamp + min_val: Minimum value for the clamp + max_val: Maximum value for the clamp + + Returns: + The clamped data + + """ + clamped_yuv = DictTensorYUV( + y=yuv.get("y").clamp(min_val, max_val), + u=yuv.get("u").clamp(min_val, max_val), + v=yuv.get("v").clamp(min_val, max_val), + ) + return clamped_yuv + + +def yuv_dict_to_device(yuv: DictTensorYUV, device: POSSIBLE_DEVICE) -> DictTensorYUV: + """Send a ``DictTensor`` to a device. + + Args: + yuv: Data to be sent to a device. + device: The requested device + + Returns: + Data on the appropriate device. + """ + return DictTensorYUV( + y=yuv.get("y").to(device), u=yuv.get("u").to(device), v=yuv.get("v").to(device) + ) + + +def convert_444_to_420(yuv444: Tensor) -> DictTensorYUV: + """From a 4D YUV 444 tensor :math:`(B, 3, H, W)`, return a + ``DictTensorYUV``. The U and V tensors are down sampled using a nearest + neighbor downsampling. + + Args: + yuv444: YUV444 data :math:`(B, 3, H, W)` + + Returns: + YUV420 dictionary of 4D tensors + """ + assert yuv444.dim() == 4, f"Number of dimension should be 5, found {yuv444.dim()}" + + b, c, h, w = yuv444.size() + assert c == 3, f"Number of channel should be 3, found {c}" + + # No need to downsample y channel but it should remain a 5D tensor + y = yuv444[:, 0, :, :].view(b, 1, h, w) + + # Downsample U and V channels together + uv = F.interpolate(yuv444[:, 1:3, :, :], scale_factor=(0.5, 0.5), mode="nearest") + u, v = uv.split(1, dim=1) + + yuv420 = DictTensorYUV(y=y, u=u, v=v) + return yuv420 + + +def convert_420_to_444(yuv420: DictTensorYUV) -> Tensor: + """Convert a DictTensorYUV to a 4D tensor:math:`(B, 3, H, W)`. + The U and V tensors are up sampled using a nearest neighbor upsampling + + Args: + yuv420: YUV420 dictionary of 4D tensor + + Returns: + YUV444 Tensor :math:`(B, 3, H, W)` + """ + u = F.interpolate(yuv420.get("u"), scale_factor=(2, 2)) + v = F.interpolate(yuv420.get("v"), scale_factor=(2, 2)) + yuv444 = torch.cat((yuv420.get("y"), u, v), dim=1) + return yuv444 diff --git a/coolchic/enc/io/io.py b/coolchic/enc/io/io.py new file mode 100644 index 00000000..c46ff0b0 --- /dev/null +++ b/coolchic/enc/io/io.py @@ -0,0 +1,39 @@ +from enc.utils.codingstructure import FrameData +from enc.io.format.data_type import FRAME_DATA_TYPE, POSSIBLE_BITDEPTH +from enc.io.format.ppm import read_ppm +from enc.io.format.yuv import read_yuv +from enc.io.format.png import read_png + + +def load_frame_data_from_file(file_path: str, idx_display_order: int) -> FrameData: + """Load the idx_display_order-th frame from a .yuv file or .png file. For the latter, + idx_display_order must be equal to 0 as there is only one frame in a png. + + Args: + file_path (str): Absolute path of the file from which the frame is loaded. + idx_display_order (int): Index of the frame in display order + + Returns: + FrameData: The loaded frame, wrapped as a FrameData object. + """ + POSSIBLE_EXT = [".yuv", ".png", ".ppm"] + assert file_path[-4:] in POSSIBLE_EXT, ( + "The function load_frame_data_from_file() expects a file ending with " + f"{POSSIBLE_EXT}. Found {file_path}" + ) + + if file_path.endswith(".yuv"): + # ! We only consider yuv420 and 444 planar + bitdepth: POSSIBLE_BITDEPTH = 8 if "_8b" in file_path else 10 + frame_data_type: FRAME_DATA_TYPE = "yuv420" if "420" in file_path else "yuv444" + data = read_yuv(file_path, idx_display_order, frame_data_type, bitdepth) + + elif file_path.endswith(".png"): + frame_data_type: FRAME_DATA_TYPE = "rgb" + data, bitdepth = read_png(file_path) + + elif file_path.endswith(".ppm"): + frame_data_type: FRAME_DATA_TYPE = "rgb" + data, bitdepth = read_ppm(file_path) + + return FrameData(bitdepth, frame_data_type, data) diff --git a/coolchic/enc/utils/presets.py b/coolchic/enc/training/presets.py similarity index 99% rename from coolchic/enc/utils/presets.py rename to coolchic/enc/training/presets.py index 653f3ea5..8e818eb5 100644 --- a/coolchic/enc/utils/presets.py +++ b/coolchic/enc/training/presets.py @@ -398,3 +398,4 @@ def __init__(self, start_lr: float = 1e-2, n_itr_per_phase: int = 100000): "c3x": PresetC3x, "debug": PresetDebug, } + diff --git a/coolchic/enc/training/quantizemodel.py b/coolchic/enc/training/quantizemodel.py index a009cc99..6aa58a9e 100644 --- a/coolchic/enc/training/quantizemodel.py +++ b/coolchic/enc/training/quantizemodel.py @@ -51,9 +51,9 @@ def _quantize_parameters( sent_param = torch.round(v / current_q_step) if sent_param.abs().max() > MAX_AC_MAX_VAL: - print( - f"Sent param {k} exceed MAX_AC_MAX_VAL! Q step {current_q_step} too small." - ) + #print( + # f"Sent param {k} exceed MAX_AC_MAX_VAL! Q step {current_q_step} too small." + #) return None q_param[k] = sent_param * current_q_step @@ -187,7 +187,7 @@ def quantize_model( # to obtain the sent latent. current_sent_param = (parameter_value / current_q_step.get(weight_or_bias)).view(-1) - if parameter_name.endswith(weight_or_bias): + if weight_or_bias in parameter_name: sent_param.append(current_sent_param) # Integer, sent parameters diff --git a/coolchic/enc/training/train.py b/coolchic/enc/training/train.py index ce1b0345..5c06b6d7 100644 --- a/coolchic/enc/training/train.py +++ b/coolchic/enc/training/train.py @@ -23,7 +23,7 @@ from enc.training.loss import loss_function from enc.training.test import test from enc.utils.codingstructure import Frame -from enc.utils.presets import MODULE_TO_OPTIMIZE +from enc.training.presets import MODULE_TO_OPTIMIZE # Custom scheduling function for the soft rounding temperature and the noise parameter @@ -306,7 +306,7 @@ def train( "patience": (patience - cnt + cnt_record) // frequency_validation, "q_type": f"{quantizer_type:12s}", "sr_temp": f"{cur_softround_temperature:.5f}", - "n_type": f"{quantizer_noise_type:20s}", + "n_type": f"{quantizer_noise_type:12s}", "noise": f"{cur_noise_parameter:.2f}", "record": log_new_record, } diff --git a/coolchic/enc/utils/codingstructure.py b/coolchic/enc/utils/codingstructure.py index 81d983a6..d649d923 100644 --- a/coolchic/enc/utils/codingstructure.py +++ b/coolchic/enc/utils/codingstructure.py @@ -10,14 +10,16 @@ import math from dataclasses import dataclass, field -from typing import List, Literal, Optional, Tuple, TypedDict, Union +from typing import List, Literal, Optional, Tuple, Union -import torch -import torch.nn.functional as F from torch import Tensor from enc.utils.misc import POSSIBLE_DEVICE +from enc.io.format.data_type import FRAME_DATA_TYPE, POSSIBLE_BITDEPTH +from enc.io.format.yuv import DictTensorYUV, convert_420_to_444, yuv_dict_to_device + + # The different frame types: # - I frames have no reference (intra) # - P frames have 1 single (past) reference @@ -44,90 +46,6 @@ # between two I-frames. First GOP P-period is 1, while second P period is 4 # which is the distance of the P-frame prediction. # -FRAME_DATA_TYPE = Literal["rgb", "yuv420", "yuv444"] -POSSIBLE_BITDEPTH = Literal[8, 10] - - -class DictTensorYUV(TypedDict): - """``TypedDict`` representing a YUV420 frame.. - - .. hint:: - - ``torch.jit`` requires I/O of modules to be either ``Tensor``, ``List`` - or ``Dict``. So we don't use a python dataclass here and rely on - ``TypedDict`` instead. - - Args: - y (Tensor): :math:`([B, 1, H, W])`. - u (Tensor): :math:`([B, 1, \\frac{H}{2}, \\frac{W}{2}])`. - v (Tensor): :math:`([B, 1, \\frac{H}{2}, \\frac{W}{2}])`. - """ - - y: Tensor - u: Tensor - v: Tensor - - -def yuv_dict_to_device(yuv: DictTensorYUV, device: POSSIBLE_DEVICE) -> DictTensorYUV: - """Send a ``DictTensor`` to a device. - - Args: - yuv: Data to be sent to a device. - device: The requested device - - Returns: - Data on the appropriate device. - """ - return DictTensorYUV( - y=yuv.get("y").to(device), u=yuv.get("u").to(device), v=yuv.get("v").to(device) - ) - - -# ============================== YUV upsampling ============================= # -def convert_444_to_420(yuv444: Tensor) -> DictTensorYUV: - """From a 4D YUV 444 tensor :math:`(B, 3, H, W)`, return a - ``DictTensorYUV``. The U and V tensors are down sampled using a nearest - neighbor downsampling. - - Args: - yuv444: YUV444 data :math:`(B, 3, H, W)` - - Returns: - YUV420 dictionary of 4D tensors - """ - assert yuv444.dim() == 4, f"Number of dimension should be 5, found {yuv444.dim()}" - - b, c, h, w = yuv444.size() - assert c == 3, f"Number of channel should be 3, found {c}" - - # No need to downsample y channel but it should remain a 5D tensor - y = yuv444[:, 0, :, :].view(b, 1, h, w) - - # Downsample U and V channels together - uv = F.interpolate(yuv444[:, 1:3, :, :], scale_factor=(0.5, 0.5), mode="nearest") - u, v = uv.split(1, dim=1) - - yuv420 = DictTensorYUV(y=y, u=u, v=v) - return yuv420 - - -def convert_420_to_444(yuv420: DictTensorYUV) -> Tensor: - """Convert a DictTensorYUV to a 4D tensor:math:`(B, 3, H, W)`. - The U and V tensors are up sampled using a nearest neighbor upsampling - - Args: - yuv420: YUV420 dictionary of 4D tensor - - Returns: - YUV444 Tensor :math:`(B, 3, H, W)` - """ - u = F.interpolate(yuv420.get("u"), scale_factor=(2, 2)) - v = F.interpolate(yuv420.get("v"), scale_factor=(2, 2)) - yuv444 = torch.cat((yuv420.get("y"), u, v), dim=1) - return yuv444 - - -# ============================== YUV upsampling ============================= # @dataclass @@ -136,7 +54,8 @@ class FrameData: a few additional information about its size, bitdepth of color space. Args: - bitdepth (POSSIBLE_BITDEPTH): Bitdepth, either ``"8"`` or ``"10"``. + bitdepth (POSSIBLE_BITDEPTH): Bitdepth, should be in + ``[8, 9, 10, 11, 12, 13, 14, 15, 16]``. frame_data_type (FRAME_DATA_TYPE): Data type, either ``"rgb"``, ``"yuv420"``, ``"yuv444"``. data (Union[Tensor, DictTensorYUV]): The actual RGB or YUV data @@ -172,6 +91,15 @@ def to_device(self, device: POSSIBLE_DEVICE) -> None: elif self.frame_data_type == "yuv420": self.data = yuv_dict_to_device(self.data, device) + def to_string(self) -> str: + """Pretty string describing the frame data.""" + s = "Frame data information:\n" + s += "-----------------------\n" + s += f"{'Resolution (H, W)':<26}: {self.img_size[0]}, {self.img_size[1]}\n" + s += f"{'Bitdepth':<26}: {self.bitdepth}\n" + s += f"{'Data type':<26}: {self.frame_data_type}" + + return s @dataclass class Frame: diff --git a/coolchic/enc/utils/manager.py b/coolchic/enc/utils/manager.py index 49daaac9..0190d9fc 100644 --- a/coolchic/enc/utils/manager.py +++ b/coolchic/enc/utils/manager.py @@ -7,7 +7,7 @@ # Authors: see CONTRIBUTORS.md from dataclasses import dataclass, field, fields -from enc.utils.presets import AVAILABLE_PRESETS, Preset +from enc.training.presets import AVAILABLE_PRESETS, Preset @dataclass diff --git a/coolchic/enc/utils/misc.py b/coolchic/enc/utils/misc.py index f40fd555..93c536aa 100644 --- a/coolchic/enc/utils/misc.py +++ b/coolchic/enc/utils/misc.py @@ -154,13 +154,13 @@ def get_q_step_from_parameter_name( Optional[float]: The quantization step associated to the parameter. Return None if nothing is found. """ - if parameter_name.endswith(".weight"): + if ".weight" in parameter_name: current_q_step = q_step.get("weight") - elif parameter_name.endswith(".bias"): + elif ".bias" in parameter_name: current_q_step = q_step.get("bias") else: print( - 'Parameter name should end with ".weight" or ".bias" ' + 'Parameter name should include ".weight" or ".bias" ' f"Found: {parameter_name}" ) current_q_step = None @@ -195,13 +195,13 @@ def measure_expgolomb_rate( # to obtain the sent latent. current_sent_param = (parameter_value / current_q_step).view(-1) - if parameter_name.endswith(".weight"): + if ".weight" in parameter_name: sent_param["weight"].append(current_sent_param) - elif parameter_name.endswith(".bias"): + elif ".bias" in parameter_name: sent_param["bias"].append(current_sent_param) else: print( - 'Parameter name should end with ".weight" or ".bias" ' + 'Parameter name should include ".weight" or ".bias" ' f"Found: {parameter_name}" ) return rate_param diff --git a/coolchic/enc/utils/parsecli.py b/coolchic/enc/utils/parsecli.py new file mode 100644 index 00000000..fd072121 --- /dev/null +++ b/coolchic/enc/utils/parsecli.py @@ -0,0 +1,174 @@ +# Software Name: Cool-Chic +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange +# SPDX-License-Identifier: BSD 3-Clause "New" +# +# This software is distributed under the BSD-3-Clause license. +# +# Authors: see CONTRIBUTORS.md + + +import argparse +import os +from typing import Any, Dict, List + + +# ----- Arguments related to Cool-chic parameters +def _parse_synthesis_layers(layers_synthesis: str) -> List[str]: + """The layers of the synthesis are presented in as a coma-separated string. + This simply splits up the different substrings and return them. + + Args: + layers_synthesis (str): Command line argument for the synthesis. + + Returns: + List[str]: List of string where the i-th element described the i-th + synthesis layer + """ + parsed_layer_synth = [x for x in layers_synthesis.split(",") if x != ""] + + assert parsed_layer_synth, ( + "Synthesis should have at least one layer, found nothing. \n" + f"--layers_synthesis={layers_synthesis} does not work!\n" + "Try something like 32-1-linear-relu,X-1-linear-none," + "X-3-residual-relu,X-3-residual-none" + ) + + return parsed_layer_synth + + +def _parse_arm_archi(arm: str) -> Dict[str, int]: + """The arm is described as ,. + Split up this string to return the value as a dict. + + Args: + arm (str): Command line argument for the ARM. + + Returns: + Dict[str, int]: The ARM architecture + """ + assert len(arm.split(",")) == 2, f"--arm format should be X,Y." f" Found {arm}" + + dim_arm, n_hidden_layers_arm = [int(x) for x in arm.split(",")] + arm_param = {"dim_arm": dim_arm, "n_hidden_layers_arm": n_hidden_layers_arm} + return arm_param + + +def _parse_n_ft_per_res(n_ft_per_res: str) -> List[int]: + """The number of feature per resolution is a coma-separated string. + This simply splits up the different substrings and return them. + + Args: + n_ft_per_res (str): Something like "1,1,1,1,1,1,1" for 7 latent grids + with different resolution and 1 feature each. + + Returns: + List[int]: The i-th element is the number of features for the i-th + latent, i.e. the latent of a resolution (H / 2^i, W / 2^i). + """ + + n_ft_per_res = [int(x) for x in n_ft_per_res.split(",") if x != ""] + assert set(n_ft_per_res) == { + 1 + }, f"--n_ft_per_res should only contains 1. Found {n_ft_per_res}" + return n_ft_per_res + + +def get_coolchic_param_from_args(args: argparse.Namespace) -> Dict[str, Any]: + + layers_synthesis = _parse_synthesis_layers(getattr(args, "layers_synthesis")) + n_ft_per_res = _parse_n_ft_per_res(getattr(args, "n_ft_per_res")) + + coolchic_param = { + "layers_synthesis": layers_synthesis, + "n_ft_per_res": n_ft_per_res, + "ups_k_size": getattr(args, "ups_k_size"), + "ups_preconcat_k_size": getattr(args, "ups_preconcat_k_size"), + } + + # Add ARM parameters + coolchic_param.update(_parse_arm_archi(getattr(args, "arm"))) + + return coolchic_param + + +# ----- Arguments related to the coding structure +def _is_image(file_path: str) -> bool: + """Return True is file extension is an image extension ie JPEG, PNG or PPM. + + Args: + file_path (str): Path of the file. + + Returns: + bool: True is file is an "image". + """ + + possible_file_extension = ["png", "jpeg", "jpg", "ppm"] + + for ext in possible_file_extension: + if file_path.endswith(f".{ext}"): + return True + + if file_path.endswith(f".{ext.capitalize()}"): + return True + + return False + +def get_coding_structure_from_args(args: argparse.Namespace) -> Dict[str, Any]: + """Perform some check on the argparse object used to collect the command + line parameters. Return a dictionary ready to be plugged into the + ``CodingStructure`` constructor. + + Args: + args (argparse.Namespace): Command-line argument parser. + + Returns: + Dict[str, Any]: Dictionary ready to be plugged into the ``CodingStructure`` + constructor. + """ + intra_period = args.intra_period + p_period = args.p_period + + assert intra_period >= 0 and intra_period <= 255, ( + f"Intra period should be in [0, 255]. Found {intra_period}" + ) + + assert p_period >= 0 and p_period <= 255, ( + f"P period should be in [0, 255]. Found {p_period}" + ) + + if _is_image(args.input): + assert intra_period == 0 and p_period == 0, ( + f"Encoding a PNG, JPEG or PPM image {args.input} must be done with" + "intra_period = 0 and p_period = 0. Found intra_period = " + f"{args.intra_period} and p_period = {args.p_period}" + ) + + coding_structure_config = { + "intra_period": intra_period, + "p_period": p_period, + "seq_name": os.path.basename(args.input).split(".")[0], + } + return coding_structure_config + + +# ----- Arguments related to the frame encoder manager i.e. training preset etc. +def get_manager_from_args(args: argparse.Namespace) -> Dict[str, Any]: + """Perform some check on the argparse object used to collect the command + line parameters. Return a dictionary ready to be plugged into the + ``FrameEncoderManager`` constructor. + + Args: + args (argparse.Namespace): Command-line argument parser. + + Returns: + Dict[str, Any]: Dictionary ready to be plugged into the + ``FrameEncoderManager`` constructor. + """ + frame_encoder_manager = { + "preset_name": args.recipe, + "start_lr": args.start_lr, + "lmbda": args.lmbda, + "n_loops": args.n_train_loops, + "n_itr": args.n_itr, + } + return frame_encoder_manager diff --git a/coolchic/enc/utils/yuv.py b/coolchic/enc/utils/yuv.py deleted file mode 100644 index ba29dc51..00000000 --- a/coolchic/enc/utils/yuv.py +++ /dev/null @@ -1,258 +0,0 @@ -# Software Name: Cool-Chic -# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange -# SPDX-License-Identifier: BSD 3-Clause "New" -# -# This software is distributed under the BSD-3-Clause license. -# -# Authors: see CONTRIBUTORS.md - - -import os - -import numpy as np -import torch -from einops import rearrange -from enc.utils.codingstructure import ( - FRAME_DATA_TYPE, - POSSIBLE_BITDEPTH, - DictTensorYUV, - FrameData, -) -from PIL import Image -from torch import Tensor -from torchvision.transforms.functional import to_tensor - - -def yuv_dict_clamp(yuv: DictTensorYUV, min_val: float, max_val: float) -> DictTensorYUV: - """Clamp the y, u & v tensor. - - Args: - yuv (DictTensorYUV): The data to clamp - min_val (float): Minimum value for the clamp - max_val (float): Maximum value for the clamp - - Returns: - DictTensorYUV: The clamped data - - """ - clamped_yuv = DictTensorYUV( - y=yuv.get("y").clamp(min_val, max_val), - u=yuv.get("u").clamp(min_val, max_val), - v=yuv.get("v").clamp(min_val, max_val), - ) - return clamped_yuv - - -def load_frame_data_from_file(filename: str, idx_display_order: int) -> FrameData: - """Load the idx_display_order-th frame from a .yuv file or .png file. For the latter, - idx_display_order must be equal to 0 as there is only one frame in a png. - - Args: - filename (str): Absolute path of the file from which the frame is loaded. - idx_display_order (int): Index of the frame in display order - - Returns: - FrameData: The loaded frame, wrapped as a FrameData object. - """ - - if filename.endswith(".yuv"): - # ! We only consider yuv420 and 444 planar - bitdepth: POSSIBLE_BITDEPTH = 8 if "_8b" in filename else 10 - frame_data_type: FRAME_DATA_TYPE = "yuv420" if "420" in filename else "yuv444" - data = read_yuv(filename, idx_display_order, frame_data_type, bitdepth) - - elif filename.endswith(".png"): - bitdepth: POSSIBLE_BITDEPTH = 8 - frame_data_type: FRAME_DATA_TYPE = "rgb" - data = to_tensor(Image.open(filename)) - data = rearrange(data, "c h w -> 1 c h w") - - return FrameData(bitdepth, frame_data_type, data) - - -def read_yuv(filename: str, frame_idx: int, frame_data_type: FRAME_DATA_TYPE, bit_depth: POSSIBLE_BITDEPTH) -> DictTensorYUV: - """From a filename /a/b/c.yuv, read the desired frame_index - and return a dictionary of tensor containing the YUV values: - { - 'Y': [1, 1, H, W], - 'U': [1, 1, H / S, W / S], - 'V': [1, 1, H / S, W / S], - } - S is either 1 (444 sampling) or 2 (420) - The YUV values are in [0., 1.] - - Args: - filename (str): Absolute path of the video to load - frame_idx (int): Index of the frame to load, starting at 0. - bit depth (int):number of bits per component (8 or 10 bits). - frame_data_type chroma sampling (420,444): - - Returns: - DictTensorYUV: The YUV values (see format above) for 420. - pytorch tensor for 444 sampling format (consistent with rgb representation) - """ - - # Parse height and width from the filename - w, h = [ - int(tmp_str) - for tmp_str in os.path.basename(filename).split(".")[0].split("_")[1].split("x") - ] - - if frame_data_type == "yuv420": - w_uv, h_uv = [int(x / 2) for x in [w, h]] - else: - w_uv, h_uv = w, h - - # Switch between 8 bit file and 10 bit - byte_per_value = 1 if bit_depth == 8 else 2 - - # We only handle YUV420 for now - n_val_y = h * w - n_val_uv = h_uv * w_uv - n_val_per_frame = n_val_y + 2 * n_val_uv - - n_bytes_y = n_val_y * byte_per_value - n_bytes_uv = n_val_uv * byte_per_value - n_bytes_per_frame = n_bytes_y + 2 * n_bytes_uv - - # Read the required frame and put it in a 1d tensor - raw_video = torch.tensor( - np.memmap( - filename, - mode="r", - shape=n_val_per_frame, - offset=n_bytes_per_frame * frame_idx, - dtype=np.uint16 if bit_depth == 10 else np.uint8, - ).astype(np.float32) - ) - - # Read the different values from raw video and store them inside y, u and v - ptr = 0 - y = raw_video[ptr : ptr + n_val_y].view(1, 1, h, w) - ptr += n_val_y - u = raw_video[ptr : ptr + n_val_uv].view(1, 1, h_uv, w_uv) - ptr += n_val_uv - v = raw_video[ptr : ptr + n_val_uv].view(1, 1, h_uv, w_uv) - - # PyTorch expect data in [0., 1.]; normalize by either 255 or 1023 - norm_factor = 2**bit_depth - 1 - - if frame_data_type == "yuv420": - video = DictTensorYUV(y=y / norm_factor, u=u / norm_factor, v=v / norm_factor) - else: - video = torch.cat([y, u, v], dim=1) / norm_factor - - return video - - -def write_yuv(data: FrameData, filename: str, norm: bool = True) -> None: - """Store a YUV frame as a YUV file named filename. All parameters of the YUV - file (resolution, chroma subsampling, bitdepth) are contained in the FrameData - object alongside the actual data. They are appended to the end of the filename - If norm is True: the video data is expected to be in [0., 1.] so we - multiply it by 255. Otherwise we let it as is. - - Args: - data (FrameData): Data to save - filename (str): Absolute path of the file where the YUV is saved. - norm (bool): True to multiply the data by 2 ** bitdepth - 1. - """ - assert data.frame_data_type in ["yuv420", "yuv444"], ( - "Found incorrect datatype in " - f'write_yuv() function: {data.frame_data_type}. Data type should be "yuv420" or "yuv444".' - ) - - # Append .yuv at the end of the file to make sure it is present - if not (filename[-4:] == ".yuv"): - filename += ".yuv" - # From here, there is no .yuv at the end of filename - filename = filename[:-4] - - # Append spatial dimension to the filename, dummy framerate - # and bit depth - DUMMY_FRAMERATE = 1 - h, w = data.img_size - # We need to add a p avec yuv444 otherwise YUView thinks its "YUV444 8-bit packed" - filename = f"{filename}_{w}x{h}_{DUMMY_FRAMERATE}fps_{data.frame_data_type}p_{data.bitdepth}b.yuv" - - # Concatenate **all** channels into a 2D tensor [1.5 * H * W] - if data.frame_data_type == "yuv420": - raw_data = torch.cat([channels.flatten() for _, channels in data.data.items()]) - elif data.frame_data_type == "yuv444": - raw_data = data.data.flatten() - - if norm: - raw_data = raw_data * (2**data.bitdepth - 1) - - dtype = np.uint16 if data.bitdepth == 10 else np.uint8 - - # Round the values and cast them to uint8 or uint16 tensor - raw_data = torch.round(raw_data).cpu().numpy().astype(dtype) - - # Write this to the desired filename - np.memmap.tofile(raw_data, filename) - - -def rgb2yuv(rgb: Tensor) -> Tensor: - """Convert a 4D RGB tensor [1, 3, H, W] into a 4D YUV444 tensor [1, 3, H, W]. - The RGB and YUV values are in the range [0, 255] - - Args: - rgb (Tensor): 4D RGB tensor to convert in [0. 255.] - - Returns: - Tensor: the resulting YUV444 tensor in [0. 255.] - """ - assert ( - len(rgb.size()) == 4 - ), f"rgb2yuv input must be a 4D tensor [B, 3, H, W]. Data size: {rgb.size()}" - assert ( - rgb.size()[1] == 3 - ), f"rgb2yuv input must have 3 channels. Data size: {rgb.size()}" - - # Split the [1, 3, H, W] into 3 [1, 1, H, W] tensors - r, g, b = rgb.split(1, dim=1) - - # Compute the different channels - y = torch.round(0.299 * r + 0.587 * g + 0.114 * b) - u = torch.round(-0.1687 * r - 0.3313 * g + 0.5 * b + +128) - v = torch.round(0.5 * r - 0.4187 * g - 0.0813 * b + 128) - - # Concatenate them into the resulting yuv 4D tensor. - yuv = torch.cat((y, u, v), dim=1) - return yuv - - -def yuv2rgb(yuv: Tensor): - """Convert a 4D YUV tensor [1, 3, H, W] into a 4D RGB tensor [1, 3, H, W]. - The RGB and YUV values are in the range [0, 255] - - Args: - rgb (Tensor): 4D YUV444 tensor to convert in [0. 255.] - - Returns: - Tensor: the resulting RGB tensor in [0. 255.] - """ - assert ( - len(yuv.size()) == 4 - ), f"yuv2rgb input must be a 4D tensor [B, 3, H, W]. Data size: {yuv.size()}" - assert ( - yuv.size()[1] == 3 - ), f"yuv2rgb input must have 3 channels. Data size: {yuv.size()}" - - y, u, v = yuv.split(1, dim=1) - r = ( - 1.0 * y - + -0.000007154783816076815 * u - + 1.4019975662231445 * v - - 179.45477266423404 - ) - g = 1.0 * y + -0.3441331386566162 * u + -0.7141380310058594 * v + 135.45870971679688 - b = ( - 1.0 * y - + 1.7720025777816772 * u - + 0.00001542569043522235 * v - - 226.8183044444304 - ) - rgb = torch.cat((r, g, b), dim=1) - return rgb diff --git a/coolchic/enc/visu/__init__.py b/coolchic/enc/visu/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/coolchic/enc/visu/console.py b/coolchic/enc/visu/console.py new file mode 100644 index 00000000..535ddf73 --- /dev/null +++ b/coolchic/enc/visu/console.py @@ -0,0 +1,258 @@ +# Software Name: Cool-Chic +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange +# SPDX-License-Identifier: BSD 3-Clause "New" +# +# This software is distributed under the BSD-3-Clause license. +# +# Authors: see CONTRIBUTORS.md + +from torch import nn + + +def pretty_string_ups(upsampling: nn.Module, header: str) -> str: + """Get a nice string ready to be printed which displays the different layers + of the upsampling step. + + Something like: + + ..code :: + + header + + +-------------+ + y0 -> | 8x8 TConv2d | -----+ + +-------------+ | + v + +------------+ +-----+ +-------------+ + y1 -> | 7x7 Conv2d | -> | cat | -> | 8x8 TConv2d | -----+ + +------------+ +-----+ +-------------+ | + v + +------------+ +-----+ +-------------+ + y2 ------------------------------> | 7x7 Conv2d | -> | cat | -> | 8x8 TConv2d | -> dense + +------------+ +-----+ +-------------+ + Args: + upsampling (nn.Module): The upsampling module to print + header (str): A string to append before the layers + + Returns: + str: A nice string presenting the upsampling. + """ + + lines = [] + + assert upsampling.n_ups_kernel == upsampling.n_ups_preconcat_kernel, ( + "Textual representation of the upsampling module excepts to have the" + "same number of upsampling and pre-concatenation kernels. Found" + f"n_ups_kernel = {upsampling.n_ups_kernel} and n_ups_preconcat_kernel =" + f" {upsampling.n_ups_preconcat_kernel}." + ) + + # Useful thing to print + cat_box = [ + "+-----+", + "| cat |", + "+-----+", + ] + # ! Warning no space before short_arrow here! + no_space_short_arrow = "->" + short_arrow = " -> " + + shorten_names = { + "ParametrizedUpsamplingSeparableSymmetricConv2d": "Conv2d", + "ParametrizedUpsamplingSeparableSymmetricConvTranspose2d": "TConv2d", + } + + lateral_offset = 0 + + for idx_ups in range(upsampling.n_ups_kernel + 1): + latent_name = f"y{idx_ups:<1} " + mid = f"{latent_name}{'-' * lateral_offset}{no_space_short_arrow} " + hori_border = ["", ""] + + for i in range(len(hori_border)): + hori_border[i] += " " * len(mid) + + # First (smallest) latent does not have a pre-concatenation filter + if idx_ups != 0: + conv_lay = upsampling.conv2ds[i] + layer_name = shorten_names.get( + type(conv_lay).__name__, type(conv_lay).__name__ + ) + k_size = conv_lay.target_k_size + inside_box = f" {k_size}x{k_size} {layer_name} " + mid += f"|{inside_box}|{short_arrow}" + + # Add a concatenation box + for i in range(len(hori_border)): + hori_border[i] += f"+{'-' * len(inside_box)}+{' ' * len(short_arrow)}" + + mid += cat_box[1] + short_arrow + hori_border[0] += cat_box[0] + hori_border[1] += cat_box[2] + + for i in range(len(hori_border)): + hori_border[i] += f"{' ' * len(short_arrow)}" + + lateral_offset = len(mid) - len(latent_name) - len(no_space_short_arrow) - 1 + + tconv_lay = upsampling.conv_transpose2ds[i] + layer_name = shorten_names.get( + type(tconv_lay).__name__, type(tconv_lay).__name__ + ) + k_size = tconv_lay.target_k_size + inside_box = f" {k_size}x{k_size} {layer_name} " + mid += f"|{inside_box}|" + + for i in range(len(hori_border)): + hori_border[i] += f"+{'-' * len(inside_box)}+" + + concat_arrow = ( + " " # Skip one space as in long arrow + # Account for the long arrow and half of the cat box. + + "-" * (len(no_space_short_arrow) + len(cat_box[0]) // 2) + ) + + # Last line is a bit specific + if idx_ups == upsampling.n_ups_kernel: + mid += f"{short_arrow}dense" + lines += [hori_border[0], mid, hori_border[1]] + + else: + mid += concat_arrow + "+" + hori_border[1] += " " * len(concat_arrow) + "|" + last_line = " " * (len(mid) - 1) + "v" + lines += [hori_border[0], mid, hori_border[1], last_line] + + return header + "\n".join(lines) + + +def pretty_string_nn( + layers: nn.Sequential, header: str, input_str: str, output_str: str +) -> str: + """Get a nice string ready to be printed which displays the different layers + of a neural network. + + Something like: + + ..code :: + + header + +--------------------------+ + | | + | v + | +---------------+ +-----+ +------+ +---------------+ + input_str ---> | Linear 8 -> 8 | -> | + | -> | ReLU | ---> | Linear 8 -> 2 | ---> output_str + +---------------+ +-----+ +------+ +---------------+ + + Args: + layers (nn.Sequential): The successive layer of the neural network. + header (str): A string to append before the layers + input_str (str): A string describing the input of the NN. + output_str (str): A string describing the output of the NN. + + Returns: + str: A nice string presenting the neural network architecture. + """ + + lines = [ + "", # For residual connexions + "", # For residual connexions + "", # For residual connexions + "", # For blocks in themselves + "", # For blocks in themselves + "", # For blocks in themselves + ] + + # Name of the layer appears here + idx_mid_block = len(lines) - 2 + + # Horizontal borders above and below the block name + idx_top_block = len(lines) - 3 + idx_bot_block = len(lines) - 1 + top_bot_blocks = [idx_bot_block, idx_top_block] + + # First three lines are for residual connexions + res_blocks = list(range(3)) + + lines[idx_mid_block] = input_str + for idx in top_bot_blocks + res_blocks: + lines[idx] += " " * len(lines[idx_mid_block]) + + # Useful thing to print + plus_box = [ + "+-----+", + "| + |", + "+-----+", + ] + short_arrow = " -> " + long_arrow = " ---> " + + shorten_names = {"ArmLinear": "Linear", "SynthesisConv2d": "Conv2d"} + + # Print one layer after the other + for lay in layers: + is_non_linearity = ( + isinstance(lay, nn.ReLU) + or isinstance(lay, nn.Identity) + or isinstance(lay, nn.LeakyReLU) + ) + + arrow_str = short_arrow if is_non_linearity else long_arrow + + layer_name = shorten_names.get(type(lay).__name__, type(lay).__name__) + if layer_name == "Identity": + continue + + inside_box = f" {layer_name} " + if not is_non_linearity: + inside_box += f"{lay.in_channels} -> {lay.out_channels} " + + if hasattr(lay, "kernel_size"): + inside_box = f" {lay.kernel_size}x{lay.kernel_size}{inside_box}" + + lines[idx_mid_block] += f"{arrow_str}|{inside_box}|" + + for idx in top_bot_blocks: + lines[idx] += f"{' ' * len(arrow_str)}+{'-' * len(inside_box)}+" + + is_residual = False + if hasattr(lay, "residual"): + is_residual = lay.residual + + if is_residual: + half_arrow_len = len(arrow_str) // 2 + for idx in res_blocks: + lines[idx] += " " * half_arrow_len + + res_arrow_len = ( + half_arrow_len + - 1 # The remaining half of the input arrow, minus 1 + + len(inside_box) + + 2 # The entire box + 2 for the edges + + len(short_arrow) # The short arrow from the layer box to the add box + + len(plus_box[0]) // 2 # Half of the plus box. + ) + lines[0] += f"+{'-' * res_arrow_len}+" + lines[1] += f"|{' ' * res_arrow_len}|" + lines[2] += f"|{' ' * res_arrow_len}v" + + for idx in res_blocks: + lines[idx] += " " * (len(plus_box[0]) // 2) + + else: + for idx in res_blocks: + lines[idx] += f"{' ' * len(arrow_str)} {' ' * len(inside_box)} " + + if is_residual: + arrow_str = short_arrow + lines[idx_mid_block] += f"{arrow_str}{plus_box[1]}" + for idx in top_bot_blocks: + lines[idx] += f"{' ' * len(arrow_str)}{plus_box[0]}" + + lines[3] = list(lines[3]) + lines[3][-(res_arrow_len + half_arrow_len + 2)] = "|" + lines[3] = "".join(lines[3]) + + lines[idx_mid_block] += long_arrow + output_str + + return header + "\n".join(lines) diff --git a/coolchic/encode.py b/coolchic/encode.py index 11bf19d2..0962fa67 100644 --- a/coolchic/encode.py +++ b/coolchic/encode.py @@ -1,27 +1,31 @@ - # Software Name: Cool-Chic +# Software Name: Cool-Chic # SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange # SPDX-License-Identifier: BSD 3-Clause "New" # # This software is distributed under the BSD-3-Clause license. # -# Authors: Theo Ladune -# Pierrick Philippe +# Authors: see CONTRIBUTORS.md + import os -import sys -import torch import subprocess -import configargparse +import sys +import configargparse +import torch from enc.component.coolchic import CoolChicEncoderParameter from enc.component.video import ( - VideoEncoder, FrameEncoderManager, + VideoEncoder, load_video_encoder, ) from enc.utils.codingstructure import CodingStructure -from enc.utils.misc import get_best_device - +from enc.utils.misc import TrainingExitCode, get_best_device +from enc.utils.parsecli import ( + get_coding_structure_from_args, + get_coolchic_param_from_args, + get_manager_from_args, +) """ Use this file to train i.e. encode a GOP i.e. something which starts with one @@ -30,7 +34,6 @@ """ if __name__ == "__main__": - # =========================== Parse arguments =========================== # # By increasing priority order, the arguments work as follows: # @@ -47,20 +50,20 @@ parser = configargparse.ArgumentParser() # -------- These arguments are not in the configuration files parser.add( - "-i", "--input", + "-i", + "--input", help="Path of the input image. Either .png (RGB444) or .yuv (YUV420)", type=str, ) parser.add( - "-o", "--output", + "-o", + "--output", help="Path of the compressed bitstream. If empty, no bitstream is written", type=str, default="", ) - parser.add( - "--workdir", help="Path of the working_directory", type=str, default="." - ) + parser.add("--workdir", help="Path of the working_directory", type=str, default=".") parser.add("--lmbda", help="Rate constraint", type=float, default=1e-3) parser.add( "--job_duration_min", @@ -70,13 +73,9 @@ ) # -------- Configuration files - parser.add( - "--enc_cfg", is_config_file=True, help="Encoder configuration file" - ) + parser.add("--enc_cfg", is_config_file=True, help="Encoder configuration file") - parser.add( - "--dec_cfg", is_config_file=True, help="Decoder configuration file" - ) + parser.add("--dec_cfg", is_config_file=True, help="Decoder configuration file") # -------- These arguments are in the configuration files @@ -94,18 +93,14 @@ default=0, ) - parser.add( - "--start_lr", help="Initial learning rate", type=float, default=1e-2 - ) + parser.add("--start_lr", help="Initial learning rate", type=float, default=1e-2) parser.add( "--n_itr", help="Maximum number of iterations per phase", type=int, default=int(1e4), ) - parser.add( - "--n_train_loops", help="Number of training loops", type=int, default=1 - ) + parser.add("--n_train_loops", help="Number of training loops", type=int, default=1) parser.add( "--recipe", help='Recipe type. Either "c3x" or "debug".', @@ -118,19 +113,27 @@ "--layers_synthesis", type=str, default="40-1-linear-relu,X-1-linear-none,X-3-residual-relu,X-3-residual-none", - help="Syntax example for the synthesis:" - " 12-1-linear-relu,12-1-residual-relu,X-1-linear-relu,X-3-residual-none" - "This is a 4 layers synthesis. Now the output layer (computing the final RGB" - "values) must be specified i.e. a 12,12 should now be called a 12,12,3. Each layer" - "is described using the following syntax:" - "---. " - " is the number of output features. If set to X, this is replaced by the" - "number of required output features i.e. 3 for a RGB or YUV frame." - " is the spatial dimension of the kernel. Use 1 to mimic an MLP." - " is either 'linear' for a standard conv or 'residual' for a residual" - " block i.e. layer(x) = x + conv(x). Can be'none' for no" - " non-linearity, 'relu' for a ReLU" + help="Syntax example: " + " " + " 12-1-linear-relu,12-1-residual-relu,X-1-linear-relu,X-3-residual-none " + " " + " This is a 4 layers synthesis. Each layer is separated by comas and is " + "described using the following syntax: " + " " + " ---. " + " " + " is the number of output features. If set to X, this is replaced by the " + "number of required output features i.e. 3 for a RGB or YUV frame. " + " " + " is the spatial dimension of the kernel. Use 1 to mimic an MLP. " + " " + " is either 'linear' for a standard conv or 'residual' for a convolution " + "with a residual connexion block i.e. layer(x) = x + conv(x). " + " " + " Can be 'none' for no non-linearity, 'relu' for a ReLU " + "non-linearity. ", ) + parser.add( "--arm", type=str, @@ -141,30 +144,37 @@ ) parser.add( - "--n_ft_per_res", - type=str, - default="1,1,1,1,1,1,1", - help="Number of feature for each latent resolution. e.g. --n_ft_per_res=1,2,2,2,3,3,3" - " for 7 latent grids with variable resolutions.", + "--ups_k_size", + type=int, + default=8, + help="Upsampling kernel size for the transposed convolutions. " + "Must be even and >= 4.", ) parser.add( - "--upsampling_kernel_size", - help="upsampling kernel size (โ‰ฅ4 and multiple of 2)", + "--ups_preconcat_k_size", type=int, - default=8, + default=7, + help="Upsampling kernel size for the pre-concatenation convolutions. " + "Must be odd.", ) + parser.add( - "--static_upsampling_kernel", - help="Use this flag to **not** learn the upsampling kernel", - action="store_true", - default=False, + "--n_ft_per_res", + type=str, + default="1,1,1,1,1,1,1", + help="Number of feature for each latent resolution. e.g. " + " --n_ft_per_res_residue=1,1,1,1,1,1,1 " + " for 7 latent grids with variable resolutions. " + " Parameterize the residue decoder.", ) args = parser.parse_args() print(args) print("----------") - print(parser.format_values()) # useful for logging where different settings came from + print( + parser.format_values() + ) # useful for logging where different settings came from # =========================== Parse arguments =========================== # # =========================== Parse arguments =========================== # @@ -175,27 +185,25 @@ video_encoder = load_video_encoder(path_video_encoder) else: - start_print = ( - '\n\n' - '*----------------------------------------------------------------------------------------------------------*\n' - '| |\n' - '| |\n' - '| ,gggg, |\n' + "\n\n" + "*----------------------------------------------------------------------------------------------------------*\n" + "| |\n" + "| |\n" + "| ,gggg, |\n" '| ,88"""Y8b, ,dPYb, ,dPYb, |\n' - '| d8" `Y8 IP\'`Yb IP\'`Yb |\n' - '| d8\' 8b d8 I8 8I I8 8I gg |\n' - '| ,8I "Y88P\' I8 8\' I8 8\' "" |\n' - '| I8\' ,ggggg, ,ggggg, I8 dP aaaaaaaa ,gggg, I8 dPgg, gg ,gggg, |\n' + "| d8\" `Y8 IP'`Yb IP'`Yb |\n" + "| d8' 8b d8 I8 8I I8 8I gg |\n" + "| ,8I \"Y88P' I8 8' I8 8' \"\" |\n" + "| I8' ,ggggg, ,ggggg, I8 dP aaaaaaaa ,gggg, I8 dPgg, gg ,gggg, |\n" '| d8 dP" "Y8ggg dP" "Y8ggg I8dP """""""" dP" "Yb I8dP" "8I 88 dP" "Yb |\n' - '| Y8, i8\' ,8I i8\' ,8I I8P i8\' I8P I8 88 i8\' |\n' - '| `Yba,,_____, ,d8, ,d8\' ,d8, ,d8\' ,d8b,_ ,d8,_ _,d8 I8,_,88,_,d8,_ _ |\n' + "| Y8, i8' ,8I i8' ,8I I8P i8' I8P I8 88 i8' |\n" + "| `Yba,,_____, ,d8, ,d8' ,d8, ,d8' ,d8b,_ ,d8,_ _,d8 I8,_,88,_,d8,_ _ |\n" '| `"Y8888888 P"Y8888P" P"Y8888P" 8P\'"Y88 P""Y8888PP88P `Y88P""Y8P""Y8888PP |\n' - '| |\n' - '| |\n' - '| version 3.3 ยฉ 2023-2024 Orange |\n' - '*----------------------------------------------------------------------------------------------------------*\n' - + "| |\n" + "| |\n" + "| version 3.4, Nov. 2024 ยฉ 2023-2024 Orange |\n" + "*----------------------------------------------------------------------------------------------------------*\n" ) print(start_print) @@ -207,72 +215,19 @@ f_out.write(str(args)) f_out.write("\n") f_out.write("----------\n") - f_out.write(parser.format_values()) # useful for logging where different settings came from - - # ----- Create coding configuration - assert args.intra_period >= 0 and args.intra_period <= 255, ( - f"Intra period should be " f" in [0, 255]. Found {args.intra_period}" - ) - - assert args.p_period >= 0 and args.p_period <= 255, ( - f"P period should be " f" in [0, 255]. Found {args.p_period}" - ) - - is_image = ( - args.input.endswith(".png") - or args.input.endswith(".PNG") - or args.input.endswith(".jpeg") - or args.input.endswith(".JPEG") - or args.input.endswith(".jpg") - or args.input.endswith(".JPG") - ) - - if is_image: - assert args.intra_period == 0 and args.p_period == 0, ( - f"Encoding a PNG or JPEG image {args.input} must be done with " - "intra_period = 0 and p_period = 0. Found intra_period = " - f"{args.intra_period} and p_period = {args.p_period}" - ) - - coding_config = CodingStructure( - intra_period=args.intra_period, - p_period=args.p_period, - seq_name=os.path.basename(args.input).split(".")[0], - ) - - # Parse arguments - layers_synthesis = [x for x in args.layers_synthesis.split(",") if x != ""] - n_ft_per_res = [int(x) for x in args.n_ft_per_res.split(",") if x != ""] - - assert set(n_ft_per_res) == {1}, ( - f"--n_ft_per_res should only contains 1. Found {args.n_ft_per_res}" - ) - - assert len(args.arm.split(",")) == 2, ( - f"--arm format should be X,Y." f" Found {args.arm}" - ) - - dim_arm, n_hidden_layers_arm = [int(x) for x in args.arm.split(",")] + f_out.write( + parser.format_values() + ) # useful for logging where different settings came from + # ----- Parse arguments & construct video encoder + coding_structure = CodingStructure(**get_coding_structure_from_args(args)) coolchic_encoder_parameter = CoolChicEncoderParameter( - layers_synthesis=layers_synthesis, - dim_arm=dim_arm, - n_hidden_layers_arm=n_hidden_layers_arm, - n_ft_per_res=n_ft_per_res, - upsampling_kernel_size=args.upsampling_kernel_size, - static_upsampling_kernel=args.static_upsampling_kernel, - ) - - frame_encoder_manager = FrameEncoderManager( - preset_name=args.recipe, - start_lr=args.start_lr, - lmbda=args.lmbda, - n_loops=args.n_train_loops, - n_itr=args.n_itr, + **get_coolchic_param_from_args(args) ) + frame_encoder_manager = FrameEncoderManager(**get_manager_from_args(args)) video_encoder = VideoEncoder( - coding_structure=coding_config, + coding_structure=coding_structure, shared_coolchic_parameter=coolchic_encoder_parameter, shared_frame_encoder_manager=frame_encoder_manager, ) @@ -281,29 +236,8 @@ device = get_best_device() print(f'{"Device":<20}: {device}') - # # ====================== Torchscript JIT parameters ===================== # - # # From https://github.com/pytorch/pytorch/issues/52286 - # # This is no longer the case with the with torch.jit.fuser - # # ! This gives a significant (+25 %) speed up - # torch._C._jit_set_profiling_executor(False) - # torch._C._jit_set_texpr_fuser_enabled(False) - # torch._C._jit_set_profiling_mode(False) - - # torch.set_float32_matmul_precision("high") - # # ====================== Torchscript JIT parameters ===================== # - - if device == "cpu": - # the number of cores is adjusted wrt to the slurm variable if exists - n_cores = os.getenv("SLURM_JOB_CPUS_PER_NODE") - # otherwise use the machine cpu count - if n_cores is None: - n_cores = os.cpu_count() - - n_cores = int(n_cores) - print(f'{"CPU cores":<20}: {n_cores}') - - elif device == "cuda:0": - # ! This one makes the training way faster! + # This makes the training faster + if device == "cuda:0": torch.backends.cudnn.benchmark = True print(f"\n{video_encoder.coding_structure.pretty_string()}\n") @@ -319,10 +253,10 @@ video_encoder.save(video_encoder_savepath) # Bitstream - if args.output != "": + if args.output != "" and exit_code == TrainingExitCode.END: from enc.bitstream.encode import encode_video + # video_encoder = load_video_encoder(video_encoder_savepath) encode_video(video_encoder, args.output, hls_sig_blksize=16) sys.exit(exit_code.value) - diff --git a/docs/source/_templates/partials/webfonts.html b/docs/source/_templates/partials/webfonts.html new file mode 100644 index 00000000..de8d9b38 --- /dev/null +++ b/docs/source/_templates/partials/webfonts.html @@ -0,0 +1,11 @@ + + + + + \ No newline at end of file diff --git a/docs/source/assets/clic20-pro-valid/all_complexity.png b/docs/source/assets/clic20-pro-valid/all_complexity.png new file mode 100644 index 00000000..099c3886 Binary files /dev/null and b/docs/source/assets/clic20-pro-valid/all_complexity.png differ diff --git a/docs/source/assets/clic20-pro-valid/concat_img.png b/docs/source/assets/clic20-pro-valid/concat_img.png new file mode 100644 index 00000000..288959c3 Binary files /dev/null and b/docs/source/assets/clic20-pro-valid/concat_img.png differ diff --git a/docs/source/assets/clic20-pro-valid/perf_complexity.png b/docs/source/assets/clic20-pro-valid/perf_complexity.png index 4ab785d1..8c423198 100644 Binary files a/docs/source/assets/clic20-pro-valid/perf_complexity.png and b/docs/source/assets/clic20-pro-valid/perf_complexity.png differ diff --git a/docs/source/assets/clic20-pro-valid/perf_decoding_time.png b/docs/source/assets/clic20-pro-valid/perf_decoding_time.png index c5dd3ead..9a35289c 100644 Binary files a/docs/source/assets/clic20-pro-valid/perf_decoding_time.png and b/docs/source/assets/clic20-pro-valid/perf_decoding_time.png differ diff --git a/docs/source/assets/clic20-pro-valid/rd.png b/docs/source/assets/clic20-pro-valid/rd.png index 4b0e2e3e..e5501047 100644 Binary files a/docs/source/assets/clic20-pro-valid/rd.png and b/docs/source/assets/clic20-pro-valid/rd.png differ diff --git a/docs/source/assets/concat_img.sh b/docs/source/assets/concat_img.sh new file mode 100755 index 00000000..7784b886 --- /dev/null +++ b/docs/source/assets/concat_img.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +complexity_height=700 +final_width=1000 + +for dataset in kodak clic20-pro-valid +do + # Concatenate the images horizontally and resize them if they don't have the same height + convert +append $dataset/perf_complexity.png $dataset/perf_decoding_time.png -resize x$complexity_height $dataset/all_complexity.png + convert -append $dataset/rd.png $dataset/all_complexity.png -resize $final_width $dataset/concat_img.png +done + +dataset=jvet +for class in B C D E F BCDEF +do + # Concatenate the images horizontally and resize them if they don't have the same height + convert +append $dataset/perf_complexity_class$class.png $dataset/perf_decoding_time_class$class.png -resize x$complexity_height $dataset/all_complexity_class$class.png + convert -append $dataset/rd_class$class.png $dataset/all_complexity_class$class.png -resize $final_width $dataset/concat_img_class$class.png +done diff --git a/docs/source/assets/coolchic-logo-light.png b/docs/source/assets/coolchic-logo-light.png index 708b49e5..697a9a49 100644 Binary files a/docs/source/assets/coolchic-logo-light.png and b/docs/source/assets/coolchic-logo-light.png differ diff --git a/docs/source/assets/favicon.pdf b/docs/source/assets/favicon.pdf deleted file mode 100644 index f6cdd17b..00000000 Binary files a/docs/source/assets/favicon.pdf and /dev/null differ diff --git a/docs/source/assets/jvet/all_complexity_classB.png b/docs/source/assets/jvet/all_complexity_classB.png new file mode 100644 index 00000000..4f4b4f12 Binary files /dev/null and b/docs/source/assets/jvet/all_complexity_classB.png differ diff --git a/docs/source/assets/jvet/all_complexity_classBCDEF.png b/docs/source/assets/jvet/all_complexity_classBCDEF.png new file mode 100644 index 00000000..be26ebc7 Binary files /dev/null and b/docs/source/assets/jvet/all_complexity_classBCDEF.png differ diff --git a/docs/source/assets/jvet/all_complexity_classC.png b/docs/source/assets/jvet/all_complexity_classC.png new file mode 100644 index 00000000..3347084d Binary files /dev/null and b/docs/source/assets/jvet/all_complexity_classC.png differ diff --git a/docs/source/assets/jvet/all_complexity_classD.png b/docs/source/assets/jvet/all_complexity_classD.png new file mode 100644 index 00000000..954c1cd8 Binary files /dev/null and b/docs/source/assets/jvet/all_complexity_classD.png differ diff --git a/docs/source/assets/jvet/all_complexity_classE.png b/docs/source/assets/jvet/all_complexity_classE.png new file mode 100644 index 00000000..ce2570fa Binary files /dev/null and b/docs/source/assets/jvet/all_complexity_classE.png differ diff --git a/docs/source/assets/jvet/all_complexity_classF.png b/docs/source/assets/jvet/all_complexity_classF.png new file mode 100644 index 00000000..18491daa Binary files /dev/null and b/docs/source/assets/jvet/all_complexity_classF.png differ diff --git a/docs/source/assets/jvet/concat_img_classB.png b/docs/source/assets/jvet/concat_img_classB.png new file mode 100644 index 00000000..fbdd1654 Binary files /dev/null and b/docs/source/assets/jvet/concat_img_classB.png differ diff --git a/docs/source/assets/jvet/concat_img_classBCDEF.png b/docs/source/assets/jvet/concat_img_classBCDEF.png new file mode 100644 index 00000000..fb8d6c49 Binary files /dev/null and b/docs/source/assets/jvet/concat_img_classBCDEF.png differ diff --git a/docs/source/assets/jvet/concat_img_classC.png b/docs/source/assets/jvet/concat_img_classC.png new file mode 100644 index 00000000..d74cca1e Binary files /dev/null and b/docs/source/assets/jvet/concat_img_classC.png differ diff --git a/docs/source/assets/jvet/concat_img_classD.png b/docs/source/assets/jvet/concat_img_classD.png new file mode 100644 index 00000000..882efea9 Binary files /dev/null and b/docs/source/assets/jvet/concat_img_classD.png differ diff --git a/docs/source/assets/jvet/concat_img_classE.png b/docs/source/assets/jvet/concat_img_classE.png new file mode 100644 index 00000000..20935e56 Binary files /dev/null and b/docs/source/assets/jvet/concat_img_classE.png differ diff --git a/docs/source/assets/jvet/concat_img_classF.png b/docs/source/assets/jvet/concat_img_classF.png new file mode 100644 index 00000000..534516db Binary files /dev/null and b/docs/source/assets/jvet/concat_img_classF.png differ diff --git a/docs/source/assets/jvet/perf_complexity_classB.png b/docs/source/assets/jvet/perf_complexity_classB.png index 6be340b4..45b763d1 100644 Binary files a/docs/source/assets/jvet/perf_complexity_classB.png and b/docs/source/assets/jvet/perf_complexity_classB.png differ diff --git a/docs/source/assets/jvet/perf_complexity_classBCDEF.png b/docs/source/assets/jvet/perf_complexity_classBCDEF.png index d1ba6a33..8a900543 100644 Binary files a/docs/source/assets/jvet/perf_complexity_classBCDEF.png and b/docs/source/assets/jvet/perf_complexity_classBCDEF.png differ diff --git a/docs/source/assets/jvet/perf_complexity_classC.png b/docs/source/assets/jvet/perf_complexity_classC.png index 81c0fd52..ae52b65d 100644 Binary files a/docs/source/assets/jvet/perf_complexity_classC.png and b/docs/source/assets/jvet/perf_complexity_classC.png differ diff --git a/docs/source/assets/jvet/perf_complexity_classD.png b/docs/source/assets/jvet/perf_complexity_classD.png index 790b997d..1dc70366 100644 Binary files a/docs/source/assets/jvet/perf_complexity_classD.png and b/docs/source/assets/jvet/perf_complexity_classD.png differ diff --git a/docs/source/assets/jvet/perf_complexity_classE.png b/docs/source/assets/jvet/perf_complexity_classE.png index 4e9d6b92..0df7e289 100644 Binary files a/docs/source/assets/jvet/perf_complexity_classE.png and b/docs/source/assets/jvet/perf_complexity_classE.png differ diff --git a/docs/source/assets/jvet/perf_complexity_classF.png b/docs/source/assets/jvet/perf_complexity_classF.png index 4c0943a7..2c198731 100644 Binary files a/docs/source/assets/jvet/perf_complexity_classF.png and b/docs/source/assets/jvet/perf_complexity_classF.png differ diff --git a/docs/source/assets/jvet/perf_decoding_time_classB.png b/docs/source/assets/jvet/perf_decoding_time_classB.png index a0063b91..87d595de 100644 Binary files a/docs/source/assets/jvet/perf_decoding_time_classB.png and b/docs/source/assets/jvet/perf_decoding_time_classB.png differ diff --git a/docs/source/assets/jvet/perf_decoding_time_classBCDEF.png b/docs/source/assets/jvet/perf_decoding_time_classBCDEF.png index 6fd0daeb..003e3d6c 100644 Binary files a/docs/source/assets/jvet/perf_decoding_time_classBCDEF.png and b/docs/source/assets/jvet/perf_decoding_time_classBCDEF.png differ diff --git a/docs/source/assets/jvet/perf_decoding_time_classC.png b/docs/source/assets/jvet/perf_decoding_time_classC.png index 14699d8e..d50eb4f4 100644 Binary files a/docs/source/assets/jvet/perf_decoding_time_classC.png and b/docs/source/assets/jvet/perf_decoding_time_classC.png differ diff --git a/docs/source/assets/jvet/perf_decoding_time_classD.png b/docs/source/assets/jvet/perf_decoding_time_classD.png index d6365075..b0253a86 100644 Binary files a/docs/source/assets/jvet/perf_decoding_time_classD.png and b/docs/source/assets/jvet/perf_decoding_time_classD.png differ diff --git a/docs/source/assets/jvet/perf_decoding_time_classE.png b/docs/source/assets/jvet/perf_decoding_time_classE.png index 37992f72..1c95da44 100644 Binary files a/docs/source/assets/jvet/perf_decoding_time_classE.png and b/docs/source/assets/jvet/perf_decoding_time_classE.png differ diff --git a/docs/source/assets/jvet/perf_decoding_time_classF.png b/docs/source/assets/jvet/perf_decoding_time_classF.png index 7b3ce8b7..e4215494 100644 Binary files a/docs/source/assets/jvet/perf_decoding_time_classF.png and b/docs/source/assets/jvet/perf_decoding_time_classF.png differ diff --git a/docs/source/assets/jvet/rd_classB.png b/docs/source/assets/jvet/rd_classB.png index 94bac994..87106368 100644 Binary files a/docs/source/assets/jvet/rd_classB.png and b/docs/source/assets/jvet/rd_classB.png differ diff --git a/docs/source/assets/jvet/rd_classBCDEF.png b/docs/source/assets/jvet/rd_classBCDEF.png index 911b2725..2e53db0f 100644 Binary files a/docs/source/assets/jvet/rd_classBCDEF.png and b/docs/source/assets/jvet/rd_classBCDEF.png differ diff --git a/docs/source/assets/jvet/rd_classC.png b/docs/source/assets/jvet/rd_classC.png index 6097c9b7..e8e74172 100644 Binary files a/docs/source/assets/jvet/rd_classC.png and b/docs/source/assets/jvet/rd_classC.png differ diff --git a/docs/source/assets/jvet/rd_classD.png b/docs/source/assets/jvet/rd_classD.png index 5f244448..36bdb3ca 100644 Binary files a/docs/source/assets/jvet/rd_classD.png and b/docs/source/assets/jvet/rd_classD.png differ diff --git a/docs/source/assets/jvet/rd_classE.png b/docs/source/assets/jvet/rd_classE.png index 9eded54c..8ac86b3e 100644 Binary files a/docs/source/assets/jvet/rd_classE.png and b/docs/source/assets/jvet/rd_classE.png differ diff --git a/docs/source/assets/jvet/rd_classF.png b/docs/source/assets/jvet/rd_classF.png index dbc51d86..6e432665 100644 Binary files a/docs/source/assets/jvet/rd_classF.png and b/docs/source/assets/jvet/rd_classF.png differ diff --git a/docs/source/assets/kodak/all_complexity.png b/docs/source/assets/kodak/all_complexity.png new file mode 100644 index 00000000..58af51e8 Binary files /dev/null and b/docs/source/assets/kodak/all_complexity.png differ diff --git a/docs/source/assets/kodak/concat_img.png b/docs/source/assets/kodak/concat_img.png new file mode 100644 index 00000000..b71e96ad Binary files /dev/null and b/docs/source/assets/kodak/concat_img.png differ diff --git a/docs/source/assets/kodak/perf_complexity.png b/docs/source/assets/kodak/perf_complexity.png index f9bf3d77..cba351b9 100644 Binary files a/docs/source/assets/kodak/perf_complexity.png and b/docs/source/assets/kodak/perf_complexity.png differ diff --git a/docs/source/assets/kodak/perf_decoding_time.png b/docs/source/assets/kodak/perf_decoding_time.png index 59697e07..8b6a6f34 100644 Binary files a/docs/source/assets/kodak/perf_decoding_time.png and b/docs/source/assets/kodak/perf_decoding_time.png differ diff --git a/docs/source/assets/kodak/rd.png b/docs/source/assets/kodak/rd.png index de7e3d26..508d15bd 100644 Binary files a/docs/source/assets/kodak/rd.png and b/docs/source/assets/kodak/rd.png differ diff --git a/docs/source/assets/logo_concat.png b/docs/source/assets/logo_concat.png new file mode 100644 index 00000000..2beac2fe Binary files /dev/null and b/docs/source/assets/logo_concat.png differ diff --git a/docs/source/assets/rd-image-clic20-validpro.png b/docs/source/assets/rd-image-clic20-validpro.png deleted file mode 100644 index 4b74f68a..00000000 Binary files a/docs/source/assets/rd-image-clic20-validpro.png and /dev/null differ diff --git a/docs/source/assets/rd-image-jvet.png b/docs/source/assets/rd-image-jvet.png deleted file mode 100644 index a9963b0f..00000000 Binary files a/docs/source/assets/rd-image-jvet.png and /dev/null differ diff --git a/docs/source/assets/rd-video-ldp-clic24-validsubset.png b/docs/source/assets/rd-video-ldp-clic24-validsubset.png deleted file mode 100644 index dbeb0e2e..00000000 Binary files a/docs/source/assets/rd-video-ldp-clic24-validsubset.png and /dev/null differ diff --git a/docs/source/assets/rd-video-ra-clic24-validsubset.png b/docs/source/assets/rd-video-ra-clic24-validsubset.png deleted file mode 100644 index 544135ac..00000000 Binary files a/docs/source/assets/rd-video-ra-clic24-validsubset.png and /dev/null differ diff --git a/docs/source/code_documentation/encoder/component/core/index.rst b/docs/source/code_documentation/encoder/component/core/index.rst index 71644ace..9d63a1f9 100644 --- a/docs/source/code_documentation/encoder/component/core/index.rst +++ b/docs/source/code_documentation/encoder/component/core/index.rst @@ -8,7 +8,7 @@ how these modules are related. .. toctree:: - ARM - Quantizer - Upsampling - Synthesis + arm + quantizer + upsampling + synthesis diff --git a/docs/source/code_documentation/encoder/component/core/upsampling.rst b/docs/source/code_documentation/encoder/component/core/upsampling.rst index 1f908985..5e0f192a 100644 --- a/docs/source/code_documentation/encoder/component/core/upsampling.rst +++ b/docs/source/code_documentation/encoder/component/core/upsampling.rst @@ -9,5 +9,8 @@ Upsampling .. autoclass:: Upsampling :members: -.. autoclass:: UpsamplingConvTranspose2d +.. autoclass:: UpsamplingSeparableSymmetricConvTranspose2d + :members: + +.. autoclass:: UpsamplingSeparableSymmetricConv2d :members: diff --git a/docs/source/code_documentation/encoder/component/index.rst b/docs/source/code_documentation/encoder/component/index.rst index 6eca23d0..0a18f68e 100644 --- a/docs/source/code_documentation/encoder/component/index.rst +++ b/docs/source/code_documentation/encoder/component/index.rst @@ -3,25 +3,26 @@ Component Components gather all the modules of the enc. - * ``VideoEncoder`` is used to compress a video *i.e.* one intra frame - followed by zero or more inter frames. It contains one or more - ``FrameEncoder``. + * ``video.py`` implements the class ``VideoEncoder`` which is used to compress a + video *i.e.* one intra frame followed by zero or more inter frames. It + contains one or more instances of ``FrameEncoder``. - * ``FrameEncoder`` is used to compress a frame. It contains one - ``CoolChicEncoder``. + * ``frame.py`` implements the class ``FrameEncoder`` is used to compress a + frame. To do so, it relies on one ``CoolChicEncoder``. - * ``CoolChicEncoder`` is the main coding engine of the codec. It is composed - of latent grids, an auto-regressive module, an upsampling and a synthesis. + * ``coolchic.py`` implements the class ``CoolChicEncoder``, which is the main + coding engine of the codec. It is composed of latent grids, an + auto-regressive module, an upsampling and a synthesis. - * ``Core`` contains all the modules stated above and required by a + * ``core/`` contains all the modules stated above and required by a ``CoolChicEncoder`` .. toctree:: - VideoEncoder