Cool-chic (pronounced /kul ʃik/ as in French 🥖🧀🍷) is
-is a low-complexity neural image codec based on overfitting. It offers image coding
-performance competitive with *H.266/VVC for 2000 multiplications* per decoded
+a low-complexity neural image codec based on overfitting. It offers image coding
+performance competitive with **H.266/VVC for 1000 multiplications** per decoded
pixel.
+
+
+
+
+#### 🏆 **Coding performance**: Cool-chic compresses images as well as H.266/VVC 🏆
+#### 🚀 **Fast CPU-only decoder**: Decode a 1280x720 image in 100 ms on CPU with our decoder written in C 🚀
+#### 🔥 **Fixed-point decoder**: Fixed-point arithmetic at the decoder for bit-exact results on different hardware 🔥
+#### 🖼️ **I/O format**: Encode PNG, PPM and YUV files with a bit depth of 8 to 16 bits 🖼️
+
+
+
#
-### Current & future features
-
-- Coding performance
- - โ On par with VVC for image coding
- - โ Upcoming improved Cool-chic video
-- I/O format
- - โ PPM for 8-bit RGB images, yuv420 8-bit and 10-bit
- - โ yuv444
- - โ Additional output precisions (12, 14 and 16-bit)
- - โ Output PNG instead of PPM for the decoded images
-- Decoder
- - โ Fast C implementation
- - โ Integer computation for the ARM
- - โ Complete integerization
- - โ Decrease memory footprint & faster decoding
-
-### Latest release: 🚀 __Cool-chic 3.3: An even faster decoder!__ 🚀
-
-- Make the **CPU-only decoder** even faster.
- - Decode a 720p image in **100 ms**, **2x faster** than Cool-chic 3.2
- - Full **integerization** of the decoder for replicability
- - Reduce decoder **memory footprint**
- - **Optimized** implementation of 3x3 convolutions & fusion of successive 1x1 convolutions
+
+
+
+
+
+_Decoding times are obtained on a single CPU core of an AMD EPYC 7282 16-Core Processor_
+
+_PSNR is computed in the RGB domain for Kodak and CLIC20, and in the YUV420 domain for JVET_
+
### Kodak
-
-
-
-
+
### CLIC20 Pro Valid
-
-
-
-
+
### JVET Class B
-
-
-
-
+
+
+
# Thanks
Special thanks go to Hyunjik Kim, Matthias Bauer, Lucas Theis, Jonathan Richard Schwarz and Emilien Dupont for their great work enhancing Cool-chic: [_C3: High-performance and low-complexity neural compression from a single image or video_, Kim et al.](https://arxiv.org/abs/2312.02753)
@@ -154,7 +204,6 @@ Special thanks go to Hyunjik Kim, Matthias Bauer, Lucas Theis, Jonathan Richard
-
#
diff --git a/cfg/dec/hop.cfg b/cfg/dec/hop.cfg
index 7a0731f2..70b39f38 100644
--- a/cfg/dec/hop.cfg
+++ b/cfg/dec/hop.cfg
@@ -1,6 +1,5 @@
-arm = 24,2
-layers_synthesis = 40-1-linear-relu,3-1-linear-none,3-3-residual-relu,3-3-residual-none
+arm = 16,2
+layers_synthesis = 48-1-linear-relu,X-1-linear-none,X-3-residual-relu,X-3-residual-none
n_ft_per_res = 1,1,1,1,1,1,1
-upsampling_kernel_size = 8
-static_upsampling_kernel = False
-
+ups_k_size = 8
+ups_preconcat_k_size = 7
\ No newline at end of file
diff --git a/cfg/dec/lop.cfg b/cfg/dec/lop.cfg
index 1252eaa5..adca16e9 100644
--- a/cfg/dec/lop.cfg
+++ b/cfg/dec/lop.cfg
@@ -1,5 +1,5 @@
arm = 8,2
-layers_synthesis = 16-1-linear-relu,3-1-linear-none,3-3-residual-relu,3-3-residual-none
+layers_synthesis = 16-1-linear-relu,X-1-linear-none,X-3-residual-relu,X-3-residual-none
n_ft_per_res = 1,1,1,1,1,1,1
-upsampling_kernel_size = 4
-static_upsampling_kernel = False
+ups_k_size = 8
+ups_preconcat_k_size = 7
diff --git a/cfg/dec/mop.cfg b/cfg/dec/mop.cfg
index ff6e61de..a124b0cd 100644
--- a/cfg/dec/mop.cfg
+++ b/cfg/dec/mop.cfg
@@ -1,5 +1,5 @@
arm = 16,2
-layers_synthesis = 16-1-linear-relu,3-1-linear-none,3-3-residual-relu,3-3-residual-none
+layers_synthesis = 16-1-linear-relu,X-1-linear-none,X-3-residual-relu,X-3-residual-none
n_ft_per_res = 1,1,1,1,1,1,1
-upsampling_kernel_size = 4
-static_upsampling_kernel = False
+ups_k_size = 8
+ups_preconcat_k_size = 7
\ No newline at end of file
diff --git a/cfg/dec/vlop.cfg b/cfg/dec/vlop.cfg
index 616e5919..b473de55 100644
--- a/cfg/dec/vlop.cfg
+++ b/cfg/dec/vlop.cfg
@@ -1,5 +1,5 @@
arm = 8,1
-layers_synthesis = 8-1-linear-relu,3-1-linear-none,3-3-residual-none
+layers_synthesis = 8-1-linear-relu,X-1-linear-none,X-3-residual-none
n_ft_per_res = 1,1,1,1,1,1,1
-upsampling_kernel_size = 4
-static_upsampling_kernel = False
+ups_k_size = 8
+ups_preconcat_k_size = 7
\ No newline at end of file
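The `layers_synthesis` entries in the configs above follow an `out-ks-mode-activation` pattern, which matches what `print_syn` reports from the decoder side. The following is a minimal C++ sketch of how such a string could be parsed and how many synthesis weights it implies, using the `n_in_ft * ks*ks * n_out_ft` count from `read_syn`. The struct, the function names, and the interpretation of `X` as "number of output channels" are assumptions made for illustration; they are not taken from the Cool-chic sources.

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical description of one synthesis layer (names chosen for this sketch).
struct SynLayerSpec {
    int n_out_ft;   // number of output features
    int ks;         // kernel size
    bool residual;  // "residual" vs "linear"
    bool relu;      // "relu" vs "none"
};

// Split a string on a single-character separator.
std::vector<std::string> split(const std::string &s, char sep) {
    std::vector<std::string> out;
    std::stringstream ss(s);
    std::string item;
    while (std::getline(ss, item, sep))
        out.push_back(item);
    return out;
}

// Parse e.g. "48-1-linear-relu,X-1-linear-none".
// Assumption: "X" stands for the number of output channels (3 for RGB).
std::vector<SynLayerSpec> parse_layers_synthesis(const std::string &cfg, int n_out_channels) {
    assert(n_out_channels > 0);
    std::vector<SynLayerSpec> layers;
    for (const std::string &entry : split(cfg, ',')) {
        std::vector<std::string> f = split(entry, '-');
        if (f.size() != 4)
            continue; // skip malformed entries in this sketch
        SynLayerSpec l;
        l.n_out_ft = (f[0] == "X") ? n_out_channels : std::stoi(f[0]);
        l.ks = std::stoi(f[1]);
        l.residual = (f[2] == "residual");
        l.relu = (f[3] == "relu");
        layers.push_back(l);
    }
    return layers;
}

// Total weight count for one branch: n_in_ft * ks*ks * n_out_ft per layer,
// where each layer's output features feed the next layer.
int count_syn_weights(const std::vector<SynLayerSpec> &layers, int n_in_ft) {
    int total = 0;
    for (const SynLayerSpec &l : layers) {
        total += n_in_ft * l.ks * l.ks * l.n_out_ft;
        n_in_ft = l.n_out_ft;
    }
    return total;
}
```

Under these assumptions, the `hop.cfg` string with 7 input latent resolutions and 3 output channels gives 7·48 + 48·3 + 3·9·3 + 3·9·3 weights per branch.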
diff --git a/coolchic/cpp/CMakeLists.txt b/coolchic/cpp/CMakeLists.txt
new file mode 100644
index 00000000..0d4e2e05
--- /dev/null
+++ b/coolchic/cpp/CMakeLists.txt
@@ -0,0 +1,57 @@
+
+add_executable(ccdec)
+
+set(CMAKE_INSTALL_PREFIX ${PROJECT_SOURCE_DIR})
+
+if(WIN32)
+
+ message(STATUS "[ERROR] Cool-chic decoder not yet implemented for Windows...")
+
+# Check Apple first, then UNIX (Apple + Linux) so that if we enter the UNIX if
+# it means that we're on Linux.
+elseif(APPLE)
+
+ if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
+
+ # Changes when compiling for arm64 Apple Mac:
+ # - Remove all *_avx2.cpp and *_avx512.cpp files
+ # - Remove the -mfma from the compilation options
+ # - Remove all the target_link_options... what is this for??
+ #
+ # It only compiles using g++/gcc, not clang which defaults to
+ # an older version apparently?
+ # cmake -DCMAKE_C_COMPILER=/opt/homebrew/bin/gcc-13 -DCMAKE_CXX_COMPILER=/opt/homebrew/bin/g++-13 ..
+
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -g -Wall -Winline")
+
+ target_sources(ccdec PRIVATE ccdecapi.cpp cc-bitstream.cpp cc-contexts.cpp arm_cpu.cpp syn_cpu.cpp BitStream.cpp TDecBinCoderCABAC.cpp Contexts.cpp)
+
+ else()
+
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -g -mfma -Winline")
+
+ # For now, we compile the *_avx2.cpp files, but they are
+ # excluded from ccdec.cpp using quick & dirty #ifdef __APPLE__
+ target_sources(ccdec PRIVATE ccdecapi.cpp cc-bitstream.cpp cc-contexts.cpp arm_cpu.cpp arm_avx2.cpp ups_cpu.cpp ups_avx2.cpp syn_cpu.cpp syn_avx2.cpp BitStream.cpp TDecBinCoderCABAC.cpp Contexts.cpp)
+
+ set_source_files_properties(arm_avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx2")
+ set_source_files_properties(ups_avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx2")
+ set_source_files_properties(syn_avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx2")
+
+ endif()
+
+elseif(UNIX)
+
+ message(STATUS "Architecture: Linux")
+
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -g -mfma -Wall -Winline -DCCDEC_EXE -DCCDECAPI_AVX2_OPTIONAL")
+
+ target_sources(ccdec PRIVATE ccdecapi.cpp cc-bitstream.cpp cc-contexts.cpp cc-frame-decoder.cpp frame-memory.cpp arm_cpu.cpp arm_avx2.cpp ups_cpu.cpp ups_avx2.cpp syn_cpu.cpp syn_avx2.cpp BitStream.cpp TDecBinCoderCABAC.cpp Contexts.cpp)
+ set(CMAKE_EXE_LINKER_FLAGS "-static")
+
+ set_source_files_properties(arm_avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx2")
+ set_source_files_properties(ups_avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx2")
+ set_source_files_properties(syn_avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx2")
+
+endif()
+
diff --git a/coolchic/cpp/cc-bitstream.cpp b/coolchic/cpp/cc-bitstream.cpp
index aceae136..108a39e5 100644
--- a/coolchic/cpp/cc-bitstream.cpp
+++ b/coolchic/cpp/cc-bitstream.cpp
@@ -43,7 +43,7 @@ int read_bitdepth(FILE *bs)
return (raw < 16) ? 8 : 10;
}
-bool cc_bs::open(std::string filename)
+bool cc_bs::open(std::string filename, int verbosity)
{
m_filename = filename;
m_f = fopen(filename.c_str(), "rb");
@@ -52,10 +52,10 @@ bool cc_bs::open(std::string filename)
printf("Cannot open bitstream file %s for reading.\n", filename.c_str());
return false;
}
- return read_gop_header();
+ return read_gop_header(verbosity);
}
-bool cc_bs::read_gop_header()
+bool cc_bs::read_gop_header(int verbosity)
{
int raw;
@@ -63,18 +63,22 @@ bool cc_bs::read_gop_header()
m_gop_header.img_h = read_int_2(m_f);
m_gop_header.img_w = read_int_2(m_f);
raw = read_int_1(m_f);
- m_gop_header.bitdepth = (raw>>4) == 0 ? 8 : 10;
+ //m_gop_header.bitdepth = (raw>>4) == 0 ? 8 : 10;
+ m_gop_header.bitdepth = (raw>>4) + 8; // !!! ASSUMING 8, 9..16 as 1st indices.
m_gop_header.frame_data_type = raw&0xF;
m_gop_header.intra_period = read_int_1(m_f);
m_gop_header.p_period = read_int_1(m_f);
- printf("GOP HEADER:\n");
- printf(" n_bytes_header: %d\n", m_gop_header.n_bytes_header);
- printf(" img_size: %d,%d\n", m_gop_header.img_h, m_gop_header.img_w);
- printf(" frame_data_type: %d\n", m_gop_header.frame_data_type);
- printf(" bitdepth: %d\n", m_gop_header.bitdepth);
- printf(" intra_period: %d\n", m_gop_header.intra_period);
- printf(" p_period: %d\n", m_gop_header.p_period);
+ if (verbosity >= 2)
+ {
+ printf("GOP HEADER:\n");
+ printf(" n_bytes_header: %d\n", m_gop_header.n_bytes_header);
+ printf(" img_size: %d,%d\n", m_gop_header.img_h, m_gop_header.img_w);
+ printf(" frame_data_type: %d\n", m_gop_header.frame_data_type);
+ printf(" bitdepth: %d\n", m_gop_header.bitdepth);
+ printf(" intra_period: %d\n", m_gop_header.intra_period);
+ printf(" p_period: %d\n", m_gop_header.p_period);
+ }
return true;
}
@@ -129,11 +133,11 @@ void print_lqi(char const *name, struct cc_bs_layer_quant_info &lqi)
void print_syn(struct cc_bs_syn_layer &syn)
{
- printf("%d-%d-%s-%s",
+ printf("out%d-ks%d-%s-%s",
syn.n_out_ft, syn.ks, syn.residual ? "residual" : "linear", syn.relu ? "relu" : "none");
}
-bool read_frame_header(FILE *bs, struct cc_bs_frame_header &frame_header)
+bool read_frame_header(FILE *bs, struct cc_bs_frame_header &frame_header, int verbosity)
{
int raw;
@@ -142,10 +146,15 @@ bool read_frame_header(FILE *bs, struct cc_bs_frame_header &frame_header)
raw = read_int_1(bs);
frame_header.dim_arm = 8*(raw>>4);
frame_header.n_hidden_layers_arm = raw&0xF;
+
+ raw = read_int_1(bs);
+ frame_header.n_ups_kernel = raw>>4;
+ frame_header.ups_k_size = raw&0xF;
raw = read_int_1(bs);
- frame_header.upsampling_kern_size = raw>>1;
- frame_header.static_upsampling_kernel = raw&1;
+ frame_header.n_ups_preconcat_kernel = raw>>4;
+ frame_header.ups_preconcat_k_size = raw&0xF;
+ frame_header.n_syn_branches = read_int_1(bs);
frame_header.n_syn_layers = read_int_1(bs);
frame_header.layers_synthesis.resize(frame_header.n_syn_layers);
for (int i = 0; i < frame_header.n_syn_layers; i++)
@@ -181,39 +190,45 @@ bool read_frame_header(FILE *bs, struct cc_bs_frame_header &frame_header)
for (int i = 0; i < frame_header.latent_n_2d_grid; i++)
frame_header.n_bytes_per_latent[i] = read_int_3(bs);
- printf("FRAME HEADER\n");
- printf(" n_bytes_header: %d\n", frame_header.n_bytes_header);
- printf(" latent_n_resolutions: %d\n", frame_header.n_latent_n_resolutions);
- printf(" latent_n_2d_grid: %d\n", frame_header.latent_n_2d_grid);
- printf(" n_bytes_per_latent:");
- for (int i = 0; i < frame_header.latent_n_2d_grid; i++)
- printf(" %d", frame_header.n_bytes_per_latent[i]);
- printf("\n");
- printf(" n_ft_per_latent:");
- for (int i = 0; i < frame_header.n_latent_n_resolutions; i++)
- printf(" %d", frame_header.n_ft_per_latent[i]);
- printf("\n");
- printf(" n_hidden_layers_arm: %d\n", frame_header.n_hidden_layers_arm);
- printf(" dim_arm: %d\n", frame_header.dim_arm);
- printf(" upsampling_kernel_size: %d\n", frame_header.upsampling_kern_size);
- printf(" static_upsampling_kernel: %d\n", frame_header.static_upsampling_kernel);
- printf(" flow_gain: %d\n", frame_header.flow_gain);
- printf(" layers_synthesis:");
- for (int i = 0; i < frame_header.n_syn_layers; i++)
+ if (verbosity >= 2)
{
- printf(" ");
- print_syn(frame_header.layers_synthesis[i]);
+ printf("FRAME HEADER\n");
+ printf(" n_bytes_header: %d\n", frame_header.n_bytes_header);
+ printf(" display_index: %d\n", frame_header.display_index);
+ printf(" dim_arm: %d\n", frame_header.dim_arm);
+ printf(" n_hidden_layers_arm: %d\n", frame_header.n_hidden_layers_arm);
+ printf(" n_ups_kernel=%d, ups_ks=%d\n", frame_header.n_ups_kernel, frame_header.ups_k_size);
+ printf(" n_ups_preconcat_kernel=%d, ups_preconcat_ks=%d\n", frame_header.n_ups_preconcat_kernel, frame_header.ups_preconcat_k_size);
+ printf(" n_syn_branches: %d\n", frame_header.n_syn_branches);
+ printf(" layers_synthesis:");
+ for (int i = 0; i < frame_header.n_syn_layers; i++)
+ {
+ printf(" ");
+ print_syn(frame_header.layers_synthesis[i]);
+ }
+ printf("\n");
+
+ printf(" flow_gain: %d\n", frame_header.flow_gain);
+ printf(" ac_max_val_nn: %d\n", frame_header.ac_max_val_nn);
+ printf(" ac_max_val_latent: %d\n", frame_header.ac_max_val_latent);
+ printf(" hls_sig_blksize: %d\n", frame_header.hls_sig_blksize);
+
+ print_lqi("arm", frame_header.arm_lqi);
+ print_lqi("ups", frame_header.ups_lqi);
+ print_lqi("syn", frame_header.syn_lqi);
+
+ printf(" latent_n_resolutions: %d\n", frame_header.n_latent_n_resolutions);
+ printf(" latent_n_2d_grid: %d\n", frame_header.latent_n_2d_grid);
+
+ printf(" n_ft_per_latent:");
+ for (int i = 0; i < frame_header.n_latent_n_resolutions; i++)
+ printf(" %d", frame_header.n_ft_per_latent[i]);
+ printf("\n");
+ printf(" n_bytes_per_latent:");
+ for (int i = 0; i < frame_header.latent_n_2d_grid; i++)
+ printf(" %d", frame_header.n_bytes_per_latent[i]);
+ printf("\n");
}
- printf("\n");
-
- print_lqi("arm", frame_header.arm_lqi);
- print_lqi("ups", frame_header.ups_lqi);
- print_lqi("syn", frame_header.syn_lqi);
- printf(" ac_max_val_nn: %d\n", frame_header.ac_max_val_nn);
- printf(" ac_max_val_latent: %d\n", frame_header.ac_max_val_latent);
- printf(" hls_sig_blksize: %d\n", frame_header.hls_sig_blksize);
- printf(" display_index: %d\n", frame_header.display_index);
-
return true;
}
@@ -231,10 +246,10 @@ std::vector get_coded(FILE *bs, int n_bytes)
return result;
}
-struct cc_bs_frame *cc_bs::decode_frame()
+struct cc_bs_frame *cc_bs::decode_frame(int verbosity)
{
cc_bs_frame *result = new cc_bs_frame();
- if (!read_frame_header(m_f, result->m_frame_header))
+ if (!read_frame_header(m_f, result->m_frame_header, verbosity))
{
delete result;
return NULL;
@@ -244,16 +259,14 @@ struct cc_bs_frame *cc_bs::decode_frame()
result->m_arm_weights_hevc = get_coded(m_f, frame_header.arm_lqi.n_bytes_nn_weight);
result->m_arm_biases_hevc = get_coded(m_f, frame_header.arm_lqi.n_bytes_nn_bias);
result->m_ups_weights_hevc = get_coded(m_f, frame_header.ups_lqi.n_bytes_nn_weight);
-
- std::vector dummy_output_bias_ups(frame_header.ups_lqi.n_bytes_nn_bias);
- dummy_output_bias_ups = get_coded(m_f, frame_header.ups_lqi.n_bytes_nn_bias);
-
+ result->m_ups_biases_hevc = get_coded(m_f, frame_header.ups_lqi.n_bytes_nn_bias);
result->m_syn_weights_hevc = get_coded(m_f, frame_header.syn_lqi.n_bytes_nn_weight);
result->m_syn_biases_hevc = get_coded(m_f, frame_header.syn_lqi.n_bytes_nn_bias);
- //printf("arm coded w%d b%d bytes w:%02x %02x %02x\n", result->m_arm_weights_hevc.size(), result->m_arm_biases_hevc.size(), result->m_arm_weights_hevc[0], result->m_arm_weights_hevc[1], result->m_arm_weights_hevc[2]);
- //printf("ups coded w%d bytes w %02x %02x %02x\n", result->m_ups_weights_hevc.size(), result->m_ups_weights_hevc[0], result->m_ups_weights_hevc[1], result->m_ups_weights_hevc[2]);
- //printf("syn coded w%d b%d bytes w %02x %02x %02x\n", result->m_syn_weights_hevc.size(), result->m_syn_biases_hevc.size(), result->m_syn_weights_hevc[0], result->m_syn_weights_hevc[1], result->m_syn_weights_hevc[2]);
+ //printf("arm coded w%ld b%ld bytes w:%02x %02x %02x\n", result->m_arm_weights_hevc.size(), result->m_arm_biases_hevc.size(), result->m_arm_weights_hevc[0], result->m_arm_weights_hevc[1], result->m_arm_weights_hevc[2]);
+ //printf("ups coded w%ld b%ld bytes w:%02x %02x %02x\n", result->m_ups_weights_hevc.size(), result->m_ups_biases_hevc.size(), result->m_ups_weights_hevc[0], result->m_ups_weights_hevc[1], result->m_ups_weights_hevc[2]);
+ //printf("syn coded w%ld b%ld bytes w %02x %02x %02x\n", result->m_syn_weights_hevc.size(), result->m_syn_biases_hevc.size(), result->m_syn_weights_hevc[0], result->m_syn_weights_hevc[1], result->m_syn_weights_hevc[2]);
+ //fflush(stdout);
for (int i = 0; i < frame_header.n_latent_n_resolutions; i++)
result->m_latents_hevc.emplace_back(get_coded(m_f, frame_header.n_bytes_per_latent[i]));
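Several header fields in the hunks above share a single byte: the GOP header stores the bit depth in the high nibble (offset by 8) and `frame_data_type` in the low nibble, and the frame header packs `n_ups_kernel` / `ups_k_size` the same way. The C++ sketch below illustrates that unpacking. The helper names and the `pack_gop_byte` inverse are hypothetical and only meant to make the layout concrete; the encoder side is not shown in this diff.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper types for this sketch (not part of cc-bitstream.h).
struct NibblePair {
    int hi; // upper 4 bits
    int lo; // lower 4 bits
};

// Split one byte into its high and low nibbles, as done with raw>>4 and raw&0xF.
inline NibblePair split_nibbles(uint8_t raw) {
    return { raw >> 4, raw & 0xF };
}

struct GopByte {
    int bitdepth;
    int frame_data_type;
};

// Mirrors the new GOP header logic: bitdepth = (raw>>4) + 8, type = raw&0xF.
inline GopByte unpack_gop_byte(uint8_t raw) {
    NibblePair n = split_nibbles(raw);
    return { n.hi + 8, n.lo };
}

// Assumed inverse of unpack_gop_byte, for illustration only.
inline uint8_t pack_gop_byte(int bitdepth, int frame_data_type) {
    assert(bitdepth >= 8 && bitdepth <= 16);
    assert(frame_data_type >= 0 && frame_data_type <= 15);
    return static_cast<uint8_t>(((bitdepth - 8) << 4) | frame_data_type);
}
```

With this layout, a raw byte of `0x21` decodes to a bit depth of 10 and a `frame_data_type` of 1, whereas the previous `(raw>>4) == 0 ? 8 : 10` rule could only distinguish 8-bit from 10-bit content.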
diff --git a/coolchic/cpp/cc-bitstream.h b/coolchic/cpp/cc-bitstream.h
index 82de2b86..712a2ac5 100644
--- a/coolchic/cpp/cc-bitstream.h
+++ b/coolchic/cpp/cc-bitstream.h
@@ -50,9 +50,14 @@ struct cc_bs_frame_header
std::vector n_ft_per_latent;
int n_hidden_layers_arm;
int dim_arm;
- int upsampling_kern_size;
- int static_upsampling_kernel;
+
+ int n_ups_kernel;
+ int ups_k_size;
+ int n_ups_preconcat_kernel;
+ int ups_preconcat_k_size;
+
int flow_gain;
+ int n_syn_branches;
int n_syn_layers;
std::vector layers_synthesis;
struct cc_bs_layer_quant_info arm_lqi;
@@ -66,16 +71,14 @@ struct cc_bs_frame_header
int hls_sig_blksize;
};
-//bool read_gop_header(FILE *bitstream, struct bitstream_gop_header &gop_header);
-//bool read_frame_header(FILE *bitstream, struct bitstream_frame_header &frame_header);
-
struct cc_bs_frame {
struct cc_bs_frame_header m_frame_header;
std::vector m_arm_weights_hevc;
std::vector m_arm_biases_hevc;
- std::vector m_ups_weights_hevc;
- std::vector m_syn_weights_hevc;
- std::vector m_syn_biases_hevc;
+ std::vector m_ups_weights_hevc; // both lb and hb
+ std::vector m_ups_biases_hevc; // both lb and hb
+ std::vector m_syn_weights_hevc; // all branches
+ std::vector m_syn_biases_hevc; // all branches
std::vector> m_latents_hevc;
};
@@ -83,10 +86,10 @@ class cc_bs {
public:
cc_bs() { m_f = NULL; }
~cc_bs() { if (m_f != NULL) fclose(m_f); }
- bool open(std::string filename);
- struct cc_bs_frame *decode_frame();
+ bool open(std::string filename, int verbosity = 0);
+ struct cc_bs_frame *decode_frame(int verbosity = 0);
private:
- bool read_gop_header();
+ bool read_gop_header(int verbosity = 0);
public:
std::string m_filename;
diff --git a/coolchic/cpp/cc-frame-decoder.cpp b/coolchic/cpp/cc-frame-decoder.cpp
index d19bc353..5190d2f5 100644
--- a/coolchic/cpp/cc-frame-decoder.cpp
+++ b/coolchic/cpp/cc-frame-decoder.cpp
@@ -1,4 +1,8 @@
+#ifdef CCDECAPI_AVX2_OPTIONAL
+#define CCDECAPI_AVX2
+#endif
+
#include "common.h"
#include "TDecBinCoderCABAC.h"
#include "cc-contexts.h"
@@ -13,11 +17,13 @@
#endif
#include "arm_cpu.h"
+#include "ups_cpu.h"
#include "syn_cpu.h"
extern float time_arm_seconds;
extern float time_ups_seconds;
extern float time_syn_seconds;
+extern float time_blend_seconds;
int const Q_STEP_ARM_WEIGHT_SHIFT[] = {
8, // 1.0/(1<<8),
@@ -162,6 +168,11 @@ void decode_weights_qi(int32_t *result, TDecBinCABAC *cabac, int count, int n_we
// want val << precision >> q_step_shift
if (precision > q_step_shift)
tmp <<= (precision-q_step_shift);
+ else if (precision < q_step_shift)
+ {
+ printf("decoding weights: qstepshift %d > precision %d\n", q_step_shift, precision);
+ exit(1);
+ }
result[i] = tmp;
}
@@ -173,61 +184,18 @@ void decode_weights_qi(weights_biases &result, TDecBinCABAC *cabac, int count, i
decode_weights_qi(result.data, cabac, count, n_weights, q_step_shift, precision);
}
-// applying 4 small kernels, ks here is upsampling_kernel/2
-//void custom_conv_ups_zxz1_cpu(int ks, int32_t *kw, int h_in, int w_in, int stride_in, int32_t *in, int32_t *out, int h_target, int w_target, int stride_out)
-void custom_conv_ups_zxz1_cpu(int ks, int32_t *kw, int h_in, int w_in, int stride_in, int32_t *in, int stride_out, int32_t *out, int ups_mul_precision)
-{
- int const kstride = 1;
- int offs0 = 0;
-
- for (int y = 0; y < h_in-ks+1; y += kstride, offs0 += stride_in)
- {
- for (int vf = 0; vf < 2; vf++, out += stride_out-(w_in-ks+1)*2) // vertical filter choice on this line.
- {
- int offs = offs0;
- for (int x = 0; x < w_in-ks+1; x += kstride, offs += kstride)
- {
- for (int hf = 0; hf < 2; hf++) // horizontal filter choice at this point
- {
- int32_t sum = 0;
- int32_t *k = kw+(vf*2+hf)*(ks*ks);
- int offs2 = offs;
- for (int yy = 0; yy < ks; yy++, offs2 += stride_in-ks)
- for (int xx = 0; xx < ks; xx++)
- {
- sum += in[offs2++]*(*k++);
- }
- *out++ = sum >> ups_mul_precision;
- }
- }
- }
- }
-}
-
-// upsamplingx2: applies 4 smaller kernels sized [ks/2][ks/2]
-// first_ups implies we are processing ARM output (.8) rather than upsampling output (.12).
-void custom_upsample_4(int ks, int32_t *weightsx4xpose, int h_in, int w_in, int stride_in, int32_t *in, int stride_target, int32_t *out, bool first_ups)
+// We take a kernel size ks, but only read (ks+1)/2 weights. These weights are mirrored to produce the full kernel.
+void decode_upsweights_qi(weights_biases &result, TDecBinCABAC *cabac, int count, int ks, int q_step_shift, int precision)
{
-#ifdef CCDECAPI_AVX2
- if (ks == 8)
- {
- if (first_ups)
- ups_4x4x4_fromarm_avx2(ks/2, weightsx4xpose, h_in, w_in, stride_in, in, stride_target, out);
- else
- ups_4x4x4_fromups_avx2(ks/2, weightsx4xpose, h_in, w_in, stride_in, in, stride_target, out);
- }
- else if (ks == 4)
+ int nw = (ks+1)/2; // number of weights to read.
+ result.update_to(ks); // we allocate to allow mirroring.
+ decode_weights_qi(result.data, cabac, count, nw, q_step_shift, precision); // read the 1st half.
+ // mirror last half
+ for (int i = 0; i < nw/2*2; i++)
{
- if (first_ups)
- ups_4x2x2_fromarm_avx2(ks/2, weightsx4xpose, h_in, w_in, stride_in, in, stride_target, out);
- else
- ups_4x2x2_fromups_avx2(ks/2, weightsx4xpose, h_in, w_in, stride_in, in, stride_target, out);
- }
- else
-#endif
- {
- custom_conv_ups_zxz1_cpu(ks/2, weightsx4xpose, h_in, w_in, stride_in, in, stride_target, out, first_ups ? ARM_PRECISION : UPS_PRECISION);
+ result.data[ks-1-i] = result.data[i];
}
+ fflush(stdout);
}
void cc_frame_decoder::read_arm(struct cc_bs_frame *frame_symbols)
@@ -287,149 +255,58 @@ void cc_frame_decoder::read_ups(struct cc_bs_frame *frame_symbols)
{
struct cc_bs_frame_header &frame_header = frame_symbols->m_frame_header;
- int const ups_ks = frame_header.upsampling_kern_size;
- int32_t bucket[ups_ks*ups_ks]; // we transpose from this.
-
- if (frame_header.static_upsampling_kernel != 0)
+ int n_ups = frame_header.n_ups_kernel;
+ int n_ups_preconcat = frame_header.n_ups_preconcat_kernel;
+ if (n_ups != m_ups_n)
{
-#if 0 // generate text.
- float bilinear_kernel[4 * 4] = {
- 0.0625, 0.1875, 0.1875, 0.0625,
- 0.1875, 0.5625, 0.5625, 0.1875,
- 0.1875, 0.5625, 0.5625, 0.1875,
- 0.0625, 0.1875, 0.1875, 0.0625
- };
-
- float bicubic_kernel[8 * 8] = {
- 0.0012359619 , 0.0037078857 ,-0.0092010498 ,-0.0308990479 ,-0.0308990479 ,-0.0092010498 , 0.0037078857 , 0.0012359619,
- 0.0037078857 , 0.0111236572 ,-0.0276031494 ,-0.0926971436 ,-0.0926971436 ,-0.0276031494 , 0.0111236572 , 0.0037078857,
- -0.0092010498 ,-0.0276031494 , 0.0684967041 , 0.2300262451 , 0.2300262451 , 0.0684967041 ,-0.0276031494 ,-0.0092010498,
- -0.0308990479 ,-0.0926971436 , 0.2300262451 , 0.7724761963 , 0.7724761963 , 0.2300262451 ,-0.0926971436 ,-0.0308990479,
- -0.0308990479 ,-0.0926971436 , 0.2300262451 , 0.7724761963 , 0.7724761963 , 0.2300262451 ,-0.0926971436 ,-0.0308990479,
- -0.0092010498 ,-0.0276031494 , 0.0684967041 , 0.2300262451 , 0.2300262451 , 0.0684967041 ,-0.0276031494 ,-0.0092010498,
- 0.0037078857 , 0.0111236572 ,-0.0276031494 ,-0.0926971436 ,-0.0926971436 ,-0.0276031494 , 0.0111236572 , 0.0037078857,
- 0.0012359619 , 0.0037078857 ,-0.0092010498 ,-0.0308990479 ,-0.0308990479 ,-0.0092010498 , 0.0037078857 , 0.0012359619
- };
- // generate text.
- printf("int32_t bilinear_kernel[4 * 4] = {\n");
- for (int ky = 0; ky < 4; ky++)
- {
- printf(" ");
- for (int kx = 0; kx < 4; kx++)
- printf("%d%s", (int32_t)(bilinear_kernel[ky*4+kx]*(1< &bs_fifo_weights = bs_weights.getFifo();
-
- TDecBinCABAC cabac_weights;
+ delete[] m_upsw_preconcat;
+ m_upsw_preconcat = new weights_biases[n_ups_preconcat];
+ m_ups_n_preconcat = n_ups_preconcat;
+ }
- bs_fifo_weights = frame_symbols->m_ups_weights_hevc;
- cabac_weights.init(&bs_weights);
- cabac_weights.start();
- struct cc_bs_layer_quant_info &ups_lqi = frame_header.ups_lqi;
+ InputBitstream bs_weights;
+ std::vector &bs_fifo_weights = bs_weights.getFifo();
- int q_step_w_shift = Q_STEP_UPS_SHIFT[ups_lqi.q_step_index_nn_weight];
+ TDecBinCABAC cabac_weights;
- decode_weights_qi(&bucket[0], &cabac_weights, ups_lqi.scale_index_nn_weight, frame_header.upsampling_kern_size*frame_header.upsampling_kern_size, q_step_w_shift, UPS_PRECISION);
- }
+ bs_fifo_weights = frame_symbols->m_ups_weights_hevc;
+ cabac_weights.init(&bs_weights);
+ cabac_weights.start();
+ struct cc_bs_layer_quant_info &ups_lqi = frame_header.ups_lqi;
- // transpose.
- int32_t bucket_t[ups_ks*ups_ks];
+ int q_step_w_shift = Q_STEP_UPS_SHIFT[ups_lqi.q_step_index_nn_weight];
- int idx = 0;
- for (int y = 0; y < ups_ks; y++)
+ // read ups layers
+ for (int lidx = 0; lidx < m_ups_n; lidx++)
{
- for (int x = 0; x < ups_ks; x++, idx++)
- {
- bucket_t[idx] = bucket[(ups_ks-1-y)*ups_ks+(ups_ks-1-x)];
- }
+ decode_upsweights_qi(m_upsw[lidx], &cabac_weights, ups_lqi.scale_index_nn_weight, frame_header.ups_k_size, q_step_w_shift, UPS_PRECISION);
}
-
- // extract 4 smaller filters. 8x8 -> 4x 4x4
- // f0 operates at origin, y=0, x=0
- // f1 operates at y=0, x=1
- // f2 operates at y=1, x=0
- // f3 operates at y=1, x=1
- m_upsw_t_4x4.update_to(ups_ks*ups_ks);
- int ups_ks2 = ups_ks/2;
-
- for (int f = 0; f < 4; f++)
+ // read ups_preconcat layers.
+ for (int lidx = 0; lidx < m_ups_n_preconcat; lidx++)
{
- int fbase_y = 1-f/2;
- int fbase_x = 1-f%2;
- for (int y = 0; y < ups_ks/2; y++)
- {
- for (int x = 0; x < ups_ks/2; x++)
- {
- m_upsw_t_4x4.data[f*ups_ks2*ups_ks2 + y*ups_ks2 + x] = bucket_t[(fbase_y+2*y)*ups_ks+fbase_x+2*x];
- }
- }
+ decode_upsweights_qi(m_upsw_preconcat[lidx], &cabac_weights, ups_lqi.scale_index_nn_weight, frame_header.ups_preconcat_k_size, q_step_w_shift, UPS_PRECISION);
}
-
}
void cc_frame_decoder::read_syn(struct cc_bs_frame *frame_symbols)
{
struct cc_bs_frame_header &frame_header = frame_symbols->m_frame_header;
- if (frame_header.n_syn_layers != m_syn_n_layers)
+ if (frame_header.n_syn_layers != m_syn_n_layers || frame_header.n_syn_branches != m_syn_n_branches)
{
- m_syn_n_layers = frame_header.n_syn_layers;
delete [] m_synw;
delete [] m_synb;
- m_synw = new weights_biases[m_syn_n_layers];
- m_synb = new weights_biases[m_syn_n_layers];
+ m_syn_n_branches = frame_header.n_syn_branches;
+ m_syn_n_layers = frame_header.n_syn_layers;
+ m_syn_blends.update_to(m_syn_n_layers);
+ m_synw = new weights_biases[m_syn_n_branches*m_syn_n_layers];
+ m_synb = new weights_biases[m_syn_n_branches*m_syn_n_layers];
}
InputBitstream bs_weights;
@@ -451,38 +328,39 @@ void cc_frame_decoder::read_syn(struct cc_bs_frame *frame_symbols)
int q_step_w_shift = Q_STEP_SYN_WEIGHT_SHIFT[syn_lqi.q_step_index_nn_weight];
int q_step_b_shift = Q_STEP_SYN_BIAS_SHIFT[syn_lqi.q_step_index_nn_bias];
- int n_in_ft = frame_header.n_latent_n_resolutions; // !!! features per layer
- for (int idx = 0; idx < frame_header.n_syn_layers; idx++)
+ // blend values.
+ if (m_syn_n_branches > 1)
{
- struct cc_bs_syn_layer &syn = frame_header.layers_synthesis[idx];
- int n_weights = n_in_ft * syn.ks*syn.ks * syn.n_out_ft;
- int n_biases = syn.n_out_ft;
+ decode_weights_qi(m_syn_blends, &cabac_weights, syn_lqi.scale_index_nn_weight, m_syn_n_branches, q_step_w_shift, SYN_WEIGHT_PRECISION);
+ }
- decode_weights_qi(m_synw[idx], &cabac_weights, syn_lqi.scale_index_nn_weight, n_weights, q_step_w_shift, SYN_WEIGHT_PRECISION);
- decode_weights_qi(m_synb[idx], &cabac_biases, syn_lqi.scale_index_nn_bias, n_biases, q_step_b_shift, SYN_WEIGHT_PRECISION*2);
+ // layer weights.
+ for (int bidx = 0; bidx < frame_header.n_syn_branches; bidx++)
+ {
+ int n_in_ft = frame_header.n_latent_n_resolutions; // !!! features per layer
+ for (int lidx = 0; lidx < frame_header.n_syn_layers; lidx++)
+ {
+ struct cc_bs_syn_layer &syn = frame_header.layers_synthesis[lidx];
+ int n_weights = n_in_ft * syn.ks*syn.ks * syn.n_out_ft;
+ int n_biases = syn.n_out_ft;
- n_in_ft = syn.n_out_ft;
+ decode_weights_qi(m_synw[get_syn_idx(bidx, lidx)], &cabac_weights, syn_lqi.scale_index_nn_weight, n_weights, q_step_w_shift, SYN_WEIGHT_PRECISION);
+ decode_weights_qi(m_synb[get_syn_idx(bidx, lidx)], &cabac_biases, syn_lqi.scale_index_nn_bias, n_biases, q_step_b_shift, SYN_WEIGHT_PRECISION*2);
+
+ n_in_ft = syn.n_out_ft;
+ }
}
}
// check to see if the leading couple of syns are compatible with being fused
// together.
-// 8, 16, 40 are checked. 8 is marginal (base mem allocation is 7, 8 is not much more).
+// Memory increase avoidance is so important we use cpu-mode in avx2 if it's
+// not otherwise optimized.
bool cc_frame_decoder::can_fuse(struct cc_bs_frame *frame_symbols)
{
auto &synthesis = frame_symbols->m_frame_header.layers_synthesis;
bool fused = synthesis[0].ks == 1 && synthesis[1].ks == 1;
- if (!fused)
- return false;
-
- // Check for compatible numbers.
- int n_syn_in = frame_symbols->m_frame_header.n_latent_n_resolutions;
- int n_hidden = synthesis[0].n_out_ft;
- int n_syn_out = synthesis[1].n_out_ft;
- fused = (n_syn_in == 7)
- && (n_hidden == 8 || n_hidden == 16 || n_hidden == 40)
- && (n_syn_out == 3 || n_syn_out == 6 || n_syn_out == 9);
return fused;
}
@@ -490,16 +368,13 @@ void cc_frame_decoder::check_allocations(struct cc_bs_frame *frame_symbols)
{
struct cc_bs_frame_header &frame_header = frame_symbols->m_frame_header;
- // we allocate a max-sized (ignoring padding) buffer, suitable for holding
- // upsample final output and synthesis outputs during syn processing.
int const latent_n_resolutions = frame_header.n_latent_n_resolutions;
int n_max_planes = latent_n_resolutions;
m_arm_pad = 4; // by how much the arm frame is padded.
- m_ups_pad = 8/2+1; // kernel size 8 upsampling.
- m_max_pad = std::max(m_arm_pad, m_ups_pad); // during arm, ups and syn. // !!! we use 5 for max ups (5) (ks=8)/2+1, and arm (4) (cxt_row_col)
+ m_ups_pad = (std::max(frame_header.ups_k_size, frame_header.ups_preconcat_k_size)+1)/2;
- bool need_ups_2 = false; // we only need this if we cannot do in-place convolution. Ie, ks 1 or 3 are in-place, otherwise not.
+ int syn_max_ks = 1;
for (int syn_idx = 0; syn_idx < frame_header.n_syn_layers; syn_idx++)
{
if (syn_idx == 0 && can_fuse(frame_symbols))
@@ -510,24 +385,27 @@ void cc_frame_decoder::check_allocations(struct cc_bs_frame *frame_symbols)
int n_pad = frame_header.layers_synthesis[syn_idx].ks/2;
if (n_pad > m_max_pad)
m_max_pad = n_pad;
- if (frame_header.layers_synthesis[syn_idx].ks != 1 && frame_header.layers_synthesis[syn_idx].ks != 3)
- need_ups_2 = true;
+ syn_max_ks = std::max(syn_max_ks, frame_header.layers_synthesis[syn_idx].ks);
}
- printf("MAX PLANES SET TO %d; MAX PAD SET TO %d\n", n_max_planes, m_max_pad);
- m_ups_1.update_to(m_gop_header.img_h, m_gop_header.img_w, n_max_planes, m_max_pad);
- // we do not (normally!) need m_ups_2.
- if (need_ups_2)
- {
- printf("ups_2 for syn allocated\n");
- m_ups_2.update_to(m_gop_header.img_h, m_gop_header.img_w, n_max_planes, m_max_pad);
- }
+ int syn_pad = (syn_max_ks+1)/2;
+ m_max_pad = std::max(std::max(m_arm_pad, m_ups_pad), syn_pad); // during arm, ups and syn.
+
+ if (m_verbosity >= 3)
+ printf("MAX PLANES SET TO %d; MAX PAD SET TO %d\n", n_max_planes, m_max_pad);
+
+ // !!! allocate all for the moment for bisyn.
+ m_syn_1.update_to(m_gop_header.img_h, m_gop_header.img_w, m_max_pad, n_max_planes);
+ m_syn_2.update_to(m_gop_header.img_h, m_gop_header.img_w, m_max_pad, n_max_planes);
+ m_syn_3.update_to(m_gop_header.img_h, m_gop_header.img_w, m_max_pad, n_max_planes);
+ m_syn_tmp.update_to(m_gop_header.img_h, m_gop_header.img_w, m_max_pad, n_max_planes);
// pyramid sizes.
m_zero_layer.resize(frame_header.n_latent_n_resolutions);
m_h_pyramid.resize(frame_header.n_latent_n_resolutions);
m_w_pyramid.resize(frame_header.n_latent_n_resolutions);
- // pyramid storage -- enough pad for arm and ups. no highest res layer.
+
+ // pyramid storage -- enough pad for arm and ups. includes highest res layer.
m_plane_pyramid.resize(frame_header.n_latent_n_resolutions);
for (int layer_number = 0, h_grid = m_gop_header.img_h, w_grid = m_gop_header.img_w;
@@ -536,15 +414,52 @@ void cc_frame_decoder::check_allocations(struct cc_bs_frame *frame_symbols)
{
m_h_pyramid[layer_number] = h_grid;
m_w_pyramid[layer_number] = w_grid;
-
- // we do not allocate the full-size.
- if (layer_number == 0)
- m_plane_pyramid[layer_number].update_to(0, 0, 0, 0); // not used.
- else
- m_plane_pyramid[layer_number].update_to(h_grid, w_grid, 1, m_max_pad);
+ m_plane_pyramid[layer_number].update_to(h_grid, w_grid, m_max_pad, 1);
}
+
+ // We need a couple of extra buffers for pyramid upsampling and refinement. It's not in-place.
+ // 1/4 sized for internal refinement target. (Full-sized goes directly to syn-land.)
+ m_ups_h2w2.update_to(m_h_pyramid[1], m_w_pyramid[1], m_ups_pad, 1);
+ // full-height, full-width for actual upsample.
+ m_ups_hw.update_to(m_h_pyramid[0], m_w_pyramid[0], m_ups_pad, 1);
}
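Review note: the allocation loop above derives each pyramid level by halving the previous height and width with round-up, `(h_grid+1)/2`. The stand-alone sketch below reproduces only that arithmetic. `pyramid_dims` is a hypothetical helper written for illustration and is not part of the decoder.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper (not in the decoder): list the (height, width) of each
// latent resolution, halving with round-up exactly as the loop above does.
std::vector<std::pair<int, int>> pyramid_dims(int img_h, int img_w, int n_resolutions)
{
    assert(img_h > 0 && img_w > 0);
    std::vector<std::pair<int, int>> dims;
    if (n_resolutions <= 0)
        return dims; // nothing to describe
    dims.reserve(n_resolutions);
    for (int layer = 0, h = img_h, w = img_w; layer < n_resolutions;
         layer++, h = (h + 1) / 2, w = (w + 1) / 2)
    {
        dims.emplace_back(h, w); // level 0 is full resolution
    }
    return dims;
}
```

For a 1280x720 image with four resolutions this yields 720x1280, 360x640, 180x320 and 90x160 (height x width). Odd sizes round up, so a height of 5 becomes 3 and then 2.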
+void ups_refine_cpu(int ks_param, int32_t *kw, frame_memory &in, frame_memory &out, int ups_src_precision, frame_memory &tmp)
+{
+ if (ks_param == 7)
+ ups_refine_ks7_cpu(ks_param, kw, in, out, ups_src_precision, tmp);
+ else
+ ups_refine_ksX_cpu(ks_param, kw, in, out, ups_src_precision, tmp);
+}
+
+void ups_upsample_cpu(int ksx2, int32_t *kw, frame_memory &in, frame_memory &out, int out_plane, int ups_src_precision, frame_memory &tmp)
+{
+ if (ksx2 == 8)
+ ups_upsample_ks8_cpu(ksx2, kw, in, out, out_plane, ups_src_precision, tmp);
+ else
+ ups_upsample_ksX_cpu(ksx2, kw, in, out, out_plane, ups_src_precision, tmp);
+}
+
+#ifdef CCDECAPI_AVX2
+void ups_refine_avx2(int ks_param, int32_t *kw, frame_memory &in, frame_memory &out, int ups_src_precision, frame_memory &tmp)
+{
+ if (ks_param == 7)
+ ups_refine_ks7_avx2(ks_param, kw, in, out, ups_src_precision, tmp);
+ else
+ ups_refine_ksX_avx2(ks_param, kw, in, out, ups_src_precision, tmp);
+}
+
+void ups_upsample_avx2(int ksx2, int32_t *kw, frame_memory &in, frame_memory &out, int out_plane, int ups_src_precision, frame_memory &tmp)
+{
+ if (ksx2 == 8 && ups_src_precision == ARM_PRECISION)
+ ups_upsample_ks8_ARMPREC_avx2(ksx2, kw, in, out, out_plane, ups_src_precision, tmp);
+ else if (ksx2 == 8 && ups_src_precision == UPS_PRECISION)
+ ups_upsample_ks8_UPSPREC_avx2(ksx2, kw, in, out, out_plane, ups_src_precision, tmp);
+ else
+ ups_upsample_ksX_avx2(ksx2, kw, in, out, out_plane, ups_src_precision, tmp);
+}
+#endif
+
void cc_frame_decoder::run_arm(struct cc_bs_frame *frame_symbols)
{
struct cc_bs_frame_header &frame_header = frame_symbols->m_frame_header;
@@ -556,7 +471,8 @@ void cc_frame_decoder::run_arm(struct cc_bs_frame *frame_symbols)
int const h_grid = m_h_pyramid[layer_number];
int const w_grid = m_w_pyramid[layer_number];
- printf("starting layer %d\n", layer_number);
+ if (m_verbosity >= 3)
+ printf("arm:starting layer %d\n", layer_number);
fflush(stdout);
auto &bytes = frame_symbols->m_latents_hevc[layer_number];
@@ -568,9 +484,9 @@ void cc_frame_decoder::run_arm(struct cc_bs_frame *frame_symbols)
}
m_zero_layer[layer_number] = false;
- frame_memory *dest = layer_number == 0 ? &m_ups_1 : &m_plane_pyramid[layer_number];
+ //frame_memory *dest = layer_number == 0 ? &m_syn_1 : &m_plane_pyramid[layer_number];
+ frame_memory *dest = &m_plane_pyramid[layer_number];
dest->zero_pad(0, m_arm_pad);
- //dest->print_ranges("armin", 1, ARM_PRECISION);
// BAC decoding:
InputBitstream bsBAC;
@@ -587,7 +503,7 @@ void cc_frame_decoder::run_arm(struct cc_bs_frame *frame_symbols)
const auto time_arm_start = std::chrono::steady_clock::now();
#ifdef CCDECAPI_AVX2
- if (frame_header.dim_arm == 8)
+ if (m_use_avx2 && frame_header.dim_arm == 8)
custom_conv_11_int32_avx2_8_X_X(
m_mlpw_t, m_mlpb,
&m_mlpwOUT, &m_mlpbOUT,
@@ -596,7 +512,7 @@ void cc_frame_decoder::run_arm(struct cc_bs_frame *frame_symbols)
h_grid, w_grid, (dest->stride-w_grid)/2,
bac_context
);
- else if (frame_header.dim_arm == 16)
+ else if (m_use_avx2 && frame_header.dim_arm == 16)
custom_conv_11_int32_avx2_16_X_X(
m_mlpw_t, m_mlpb,
&m_mlpwOUT, &m_mlpbOUT,
@@ -605,7 +521,7 @@ void cc_frame_decoder::run_arm(struct cc_bs_frame *frame_symbols)
h_grid, w_grid, (dest->stride-w_grid)/2,
bac_context
);
- else if (frame_header.dim_arm == 24)
+ else if (m_use_avx2 && frame_header.dim_arm == 24)
custom_conv_11_int32_avx2_24_X_X(
m_mlpw_t, m_mlpb,
&m_mlpwOUT, &m_mlpbOUT,
@@ -614,7 +530,7 @@ void cc_frame_decoder::run_arm(struct cc_bs_frame *frame_symbols)
h_grid, w_grid, (dest->stride-w_grid)/2,
bac_context
);
- else if (frame_header.dim_arm == 32)
+ else if (m_use_avx2 && frame_header.dim_arm == 32)
custom_conv_11_int32_avx2_32_X_X(
m_mlpw_t, m_mlpb,
&m_mlpwOUT, &m_mlpbOUT,
@@ -636,46 +552,30 @@ void cc_frame_decoder::run_arm(struct cc_bs_frame *frame_symbols)
bac_context
);
- //dest->print_ranges("armout", 1, ARM_PRECISION);
-
-#if UPS_PRECISION > ARM_PRECISION
- if (layer_number == 0)
- {
- // in-place precision change for high-resolution layer.
- // upsampling is not used here, which normally changes the .8 to .12.
- // we do it by hand.
- int32_t *src = dest->plane_origin(0);
- int const lshift = UPS_PRECISION-ARM_PRECISION;
- for (int y = 0; y < h_grid; y++, src += dest->stride-w_grid)
- {
- for (int x = 0; x < w_grid; x++)
- {
- *src++ <<= lshift;
- }
- }
-#else
- what system is this
-#endif
- }
-
const auto time_arm_done = std::chrono::steady_clock::now();
const std::chrono::duration arm_elapsed = (time_arm_done-time_arm_start);
time_arm_seconds += (float)arm_elapsed.count();
} // layer.
- printf("arm done!\n");
+ if (m_verbosity >= 100)
+ {
+ printf("ARM OUTPUTS\n");
+ for (int p = 0; p < frame_header.n_latent_n_resolutions; p++)
+ {
+ printf("ARMPYRAMID %d ", p);
+ m_plane_pyramid[p].print_start(0, "ARM", -1, ARM_PRECISION);
+ }
+ }
}
void cc_frame_decoder::run_ups(struct cc_bs_frame *frame_symbols)
{
struct cc_bs_frame_header &frame_header = frame_symbols->m_frame_header;
- // NEW UPSAMPLE
const auto time_ups_start = std::chrono::steady_clock::now();
- int ups_ks = frame_header.upsampling_kern_size;
-
+ // full-res down to lowest-res. refine & upsample each as necessary to full res.
for (int layer_number = 0, h_grid = m_gop_header.img_h, w_grid = m_gop_header.img_w;
layer_number < frame_header.n_latent_n_resolutions;
layer_number++, h_grid = (h_grid+1)/2, w_grid = (w_grid+1)/2)
@@ -683,38 +583,93 @@ void cc_frame_decoder::run_ups(struct cc_bs_frame *frame_symbols)
if (m_zero_layer[layer_number])
{
// no need to upsample. just zero the final content.
- m_ups_1.zero_plane_content(layer_number);
+ m_syn_1.zero_plane_content(layer_number);
continue;
}
+ // layer_number 0: hb_layer is (nlatents-1)%nhblayers.
+ // n_resolution-2 is number of hb layers max.
+ int preconcat_layer = (frame_header.n_latent_n_resolutions-2-layer_number)%m_ups_n_preconcat;
if (layer_number == 0)
{
- // full res. if non-null, already present in m_ups_layers[0].
+ // full res. just a refinement directly to m_syn_1, no upsampling.
+#ifdef CCDECAPI_AVX2
+ if (m_use_avx2)
+ {
+ ups_refine_avx2(frame_header.ups_preconcat_k_size, m_upsw_preconcat[preconcat_layer].data, m_plane_pyramid[0], m_syn_1, ARM_PRECISION, m_ups_hw);
+ }
+#endif
+#if defined(CCDECAPI_AVX2_OPTIONAL)
+ else
+#endif
+#if defined(CCDECAPI_AVX2_OPTIONAL) || defined(CCDECAPI_CPU)
+ {
+ ups_refine_cpu(frame_header.ups_preconcat_k_size, m_upsw_preconcat[preconcat_layer].data, m_plane_pyramid[0], m_syn_1, ARM_PRECISION, m_ups_hw);
+ }
+#endif
continue;
}
- // continually scale up to the final resolution. We upsample through lower layer numbers,
- // which have already been done until the final full res destination in m_ups_1.
- for (int target_layer = layer_number-1; target_layer >= 0; target_layer--)
+ frame_memory *ups_src = NULL;
+ int ups_prec = 0;
+
+ if (layer_number == frame_header.n_latent_n_resolutions-1)
+ {
+ // just upsample, no refinement.
+ ups_src = &m_plane_pyramid[layer_number];
+ ups_prec = ARM_PRECISION;
+ }
+ else
{
- frame_memory *dest;
- int dest_plane;
+ // refine, then upsample.
+#ifdef CCDECAPI_AVX2
+ if (m_use_avx2)
+ {
+ ups_refine_avx2(frame_header.ups_preconcat_k_size, m_upsw_preconcat[preconcat_layer].data, m_plane_pyramid[layer_number], m_ups_h2w2, ARM_PRECISION, m_ups_hw);
+ }
+#endif
+#if defined(CCDECAPI_AVX2_OPTIONAL)
+ else
+#endif
+#if defined(CCDECAPI_AVX2_OPTIONAL) || defined(CCDECAPI_CPU)
+ {
+ ups_refine_cpu(frame_header.ups_preconcat_k_size, m_upsw_preconcat[preconcat_layer].data, m_plane_pyramid[layer_number], m_ups_h2w2, ARM_PRECISION, m_ups_hw);
+ }
+#endif
+ ups_src = &m_ups_h2w2;
+ ups_prec = UPS_PRECISION;
+ }
+
+ for (int target_layer = layer_number-1; target_layer >= 0; ups_src = &m_plane_pyramid[target_layer], target_layer--, ups_prec = UPS_PRECISION)
+ {
+ // upsample layer index to use.
+ int ups_layer = (frame_header.n_latent_n_resolutions-2-target_layer)%m_ups_n;
+ // upsample, either to next pyramid level up or, instead of to [0], to m_syn_1[layer_number].
+ frame_memory *ups_dst = NULL;
+ int dst_plane;
if (target_layer == 0)
{
- dest = &m_ups_1;
- dest_plane = layer_number;
+ ups_dst = &m_syn_1;
+ dst_plane = layer_number;
}
else
{
- dest = &m_plane_pyramid[target_layer];
- dest_plane = 0;
+ ups_dst = &m_plane_pyramid[target_layer];
+ dst_plane = 0;
}
- frame_memory *lo_res = &m_plane_pyramid[target_layer+1];
- int pad = ups_ks/2/2;
- lo_res->custom_pad_replicate_plane_in_place_i(0, pad);
- //printf("upslayer%dtarget%d ", layer_number, target_layer); lo_res->print_ranges("ups_in", 1, target_layer == layer_number-1 ? ARM_PRECISION : UPS_PRECISION);
- // our output is not the direct origin -- we need an extra crop to point at pad start.
- custom_upsample_4(ups_ks, m_upsw_t_4x4.data, m_h_pyramid[target_layer+1]+2*pad, m_w_pyramid[target_layer+1]+2*pad, lo_res->stride,
- lo_res->pad_origin(pad), dest->stride, dest->pad_origin(dest_plane, 1), target_layer == layer_number-1);
+#ifdef CCDECAPI_AVX2
+ if (m_use_avx2)
+ {
+ ups_upsample_avx2(frame_header.ups_k_size, m_upsw[ups_layer].data, *ups_src, *ups_dst, dst_plane, ups_prec, m_ups_hw);
+ }
+#endif
+#if defined(CCDECAPI_AVX2_OPTIONAL)
+ else
+#endif
+#if defined(CCDECAPI_AVX2_OPTIONAL) || defined(CCDECAPI_CPU)
+ {
+ ups_upsample_cpu(frame_header.ups_k_size, m_upsw[ups_layer].data, *ups_src, *ups_dst, dst_plane, ups_prec, m_ups_hw);
+ }
+#endif
}
}
@@ -723,24 +678,162 @@ void cc_frame_decoder::run_ups(struct cc_bs_frame *frame_symbols)
time_ups_seconds += hand_ups_elapsed_seconds.count();
}
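Review note: `run_ups` picks its filter set with `(n_latent_n_resolutions-2-layer)%n`, so indexing starts from the lowest-resolution end and wraps when fewer filter sets than stages are available. The sketch below isolates that expression. `ups_filter_index` is a hypothetical name, and the asserted range is an assumption about the values for which the result is a valid array index. In C++ the remainder takes the sign of the dividend, so a layer beyond `n_resolutions-2` would give a non-positive result.

```cpp
#include <cassert>

// Hypothetical stand-alone version of the index arithmetic used in run_ups.
// n_resolutions: number of latent resolutions.
// layer:         pyramid level the filter is applied for (0 = full resolution).
// n_filters:     number of available filter sets (m_ups_n or m_ups_n_preconcat above).
int ups_filter_index(int n_resolutions, int layer, int n_filters)
{
    assert(n_filters > 0);
    // Keeps the dividend non-negative so the result is usable as an index.
    assert(layer >= 0 && layer <= n_resolutions - 2);
    return (n_resolutions - 2 - layer) % n_filters;
}
```

With seven resolutions and two filter sets, level 5 maps to index 0, level 4 to index 1, and level 0 to index 1.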
-// result points to either m_ups_1 or m_ups_2
-frame_memory *cc_frame_decoder::run_syn(struct cc_bs_frame *frame_symbols)
+// blend an output into an accumulator.
+// syn_out = syn_out + syn_in*blend[branch_no]
+frame_memory *cc_frame_decoder::run_syn_blend1(struct cc_bs_frame *frame_symbols, int n_planes, int branch_no, frame_memory *syn_in, frame_memory *syn_out)
+{
+ const auto time_blend_start = std::chrono::steady_clock::now();
+#if 0
+#ifdef CCDECAPI_AVX2
+ if (m_use_avx2)
+ {
+ syn_blend1_avx2(syn_in->h, syn_in->w, syn_in->stride, syn_in->plane_stride, n_planes, syn_in->origin(), m_syn_blends.data[branch_no], syn_out->origin());
+ }
+#endif
+#if defined(CCDECAPI_AVX2_OPTIONAL)
+ else
+#endif
+#if defined(CCDECAPI_AVX2_OPTIONAL) || defined(CCDECAPI_CPU)
+ {
+ syn_blend1(syn_in->h, syn_in->w, syn_in->stride, syn_in->plane_stride, n_planes, syn_in->origin(), m_syn_blends.data[branch_no], syn_out->origin());
+ }
+#endif
+#endif
+
+ if (m_verbosity >= 100)
+ {
+ printf("BLEND1 INPUT\n");
+ for (int p = 0; p < n_planes; p++)
+ {
+ printf("BLENDINPLANE %d ", p);
+ syn_in->print_start(p, "BLENDIN", -1, UPS_PRECISION);
+ }
+ }
+
+ syn_blend1(syn_in->h, syn_in->w, syn_in->stride, syn_in->plane_stride, n_planes, syn_in->origin(), m_syn_blends.data[branch_no], syn_out->origin());
+ const auto time_blend_end = std::chrono::steady_clock::now();
+ const std::chrono::duration hand_blend_elapsed_seconds = time_blend_end - time_blend_start;
+ time_blend_seconds += hand_blend_elapsed_seconds.count();
+ return syn_out;
+}
+
+// blend two outputs to initialize an accumulator.
+// the 'in' is the 2nd syn output, the 'out' is the first syn output.
+// we keep updating 'out'
+// syn_out = syn_out*blend[branch_no-1] + syn_in*blend[branch_no]
+frame_memory *cc_frame_decoder::run_syn_blend2(struct cc_bs_frame *frame_symbols, int n_planes, int branch_no, frame_memory *syn_in, frame_memory *syn_out)
+{
+ const auto time_blend_start = std::chrono::steady_clock::now();
+#if 0
+#ifdef CCDECAPI_AVX2
+ if (m_use_avx2)
+ {
+ syn_blend2_avx2(syn_in->h, syn_in->w, syn_in->stride, syn_in->plane_stride, n_planes, syn_in->origin(), m_syn_blends.data[branch_no], syn_out->origin(), m_syn_blends.data[branch_no-1]);
+ }
+#endif
+#if defined(CCDECAPI_AVX2_OPTIONAL)
+ else
+#endif
+#if defined(CCDECAPI_AVX2_OPTIONAL) || defined(CCDECAPI_CPU)
+ {
+ syn_blend2(syn_in->h, syn_in->w, syn_in->stride, syn_in->plane_stride, n_planes, syn_in->origin(), m_syn_blends.data[branch_no], syn_out->origin(), m_syn_blends.data[branch_no-1]);
+ }
+#endif
+#endif
+
+ if (m_verbosity >= 100)
+ {
+ printf("BLEND2 INPUTS\n");
+ printf("1: w=%d(%f)\n", m_syn_blends.data[branch_no], (m_syn_blends.data[branch_no]/((1<print_start(p, "BLENDIN1", -1, UPS_PRECISION);
+ }
+ printf("2: w=%d(%f)\n", m_syn_blends.data[branch_no-1], (m_syn_blends.data[branch_no-1]/((1<print_start(p, "BLENDIN2", -1, UPS_PRECISION);
+ }
+ }
+
+ syn_blend2(syn_in->h, syn_in->w, syn_in->stride, syn_in->plane_stride, n_planes, syn_in->origin(), m_syn_blends.data[branch_no], syn_out->origin(), m_syn_blends.data[branch_no-1]);
+ const auto time_blend_end = std::chrono::steady_clock::now();
+ const std::chrono::duration hand_blend_elapsed_seconds = time_blend_end - time_blend_start;
+ time_blend_seconds += hand_blend_elapsed_seconds.count();
+ return syn_out;
+}
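Review note: the comments on the two blend functions give the formulas `syn_out = syn_out + syn_in*blend[branch_no]` and `syn_out = syn_out*blend[branch_no-1] + syn_in*blend[branch_no]`. The sketch below shows one way such a blend can be done with integer weights. It is not the decoder's `syn_blend1`/`syn_blend2`: the names, the flat-vector layout, the 8-bit `kBlendPrecision` and the truncating division are all assumptions made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical weight precision: a weight of 1.0 is stored as 1 << kBlendPrecision.
constexpr int kBlendPrecision = 8;

// out = out + in*w   (accumulate one more branch)
void blend1_sketch(std::vector<int32_t> &out, const std::vector<int32_t> &in, int32_t w_in)
{
    assert(out.size() == in.size());
    for (std::size_t i = 0; i < out.size(); i++)
    {
        // 64-bit product avoids overflow; division truncates toward zero.
        int64_t scaled = (int64_t)in[i] * w_in / (1 << kBlendPrecision);
        out[i] += (int32_t)scaled;
    }
}

// out = out*w_out + in*w_in   (initialize the accumulator from two branches)
void blend2_sketch(std::vector<int32_t> &out, const std::vector<int32_t> &in, int32_t w_out, int32_t w_in)
{
    assert(out.size() == in.size());
    for (std::size_t i = 0; i < out.size(); i++)
    {
        int64_t acc = (int64_t)out[i] * w_out + (int64_t)in[i] * w_in;
        out[i] = (int32_t)(acc / (1 << kBlendPrecision));
    }
}
```

With both weights at 128 (0.5 under the assumed precision), `blend2_sketch` averages the two inputs. A later `blend1_sketch` call with weight 256 adds the new branch unscaled.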
+
+// operate synth layers starting from syn_in as input.
+// syn_out non-NULL implies the first syn layer will EXPLICITLY target syn_out, thereafter remaining in syn_out if possible (alternating syn_out, syn_tmp if necessary)
+// return value is either syn_out or syn_tmp.
+// syn_out NULL implies syn processing will remain in syn_in as much as possible (alternating syn_in, syn_tmp if necessary)
+// return value is either syn_in or syn_tmp.
+frame_memory *cc_frame_decoder::run_syn_branch(struct cc_bs_frame *frame_symbols, int branch_no, frame_memory *syn_in, frame_memory *syn_out, frame_memory *syn_tmp)
{
struct cc_bs_frame_header &frame_header = frame_symbols->m_frame_header;
- // NEW SYNTHESIS: IN-PLACE CONV
int n_syn_in = frame_header.n_latent_n_resolutions;
- frame_memory *syn_in_i = &m_ups_1;
- frame_memory *syn_out_i = &m_ups_2; // if cannot do syn in-place.
+ frame_memory *syn_in_i = NULL; // internal in/out
+ frame_memory *syn_out_i = NULL; // internal in/out
+ if (syn_out == syn_in)
+ syn_out = NULL;
+
const auto time_syn_start = std::chrono::steady_clock::now();
for (int syn_idx = 0; syn_idx < frame_header.n_syn_layers; syn_idx++)
{
- int n_syn_out = frame_header.layers_synthesis[syn_idx].n_out_ft;
- printf("SYN%d: k=%d #in=%d #out=%d", syn_idx, frame_header.layers_synthesis[syn_idx].ks, n_syn_in, n_syn_out);
auto time_this_syn_start = std::chrono::steady_clock::now();
+ int n_syn_out = frame_header.layers_synthesis[syn_idx].n_out_ft;
+ int syn_wb_idx = get_syn_idx(branch_no, syn_idx);
+ if (m_verbosity >= 3)
+ printf("SYN%d: k=%d #in=%d #out=%d", syn_idx, frame_header.layers_synthesis[syn_idx].ks, n_syn_in, n_syn_out);
+
+ bool possible_in_place;
+#ifdef CCDECAPI_AVX2
+ if (m_use_avx2)
+ {
+ possible_in_place = frame_header.layers_synthesis[syn_idx].ks == 3
+ || (syn_idx == 0 && can_fuse(frame_symbols))
+ || (frame_header.layers_synthesis[syn_idx].ks == 1 && n_syn_out <= 4);
+ }
+#endif
+#if defined(CCDECAPI_AVX2_OPTIONAL)
+ else
+#endif
+#if defined(CCDECAPI_AVX2_OPTIONAL) || defined(CCDECAPI_CPU)
+ {
+ possible_in_place = true;
+ }
+#endif
+
+ // forcing output to not in-place for first layer?
+ bool in_place = possible_in_place;
+ if (syn_idx == 0 && syn_out != NULL)
+ {
+ in_place = false; // explicit targeting of non-input for output. We will switch between syn_out and syn_tmp afterwards as necessary.
+ }
+
+ if (syn_idx == 0)
+ {
+ syn_in_i = syn_in;
+ syn_out_i = syn_out != NULL ? syn_out : (in_place ? syn_in_i : syn_tmp);
+ }
+ else
+ {
+ if (syn_out != NULL)
+ {
+ // syn_in_i is out or tmp. determine syn_out_i
+ syn_out_i = in_place ? syn_in_i : (syn_in_i == syn_out ? syn_tmp : syn_out);
+ }
+ else
+ {
+ // syn_in_i is in or tmp. determine syn_out_i
+ syn_out_i = in_place ? syn_in_i : (syn_in_i == syn_in ? syn_tmp : syn_in);
+ }
+ }
- //syn_in_i->print_ranges("syn_in", n_syn_in, SYN_MUL_PRECISION);
if (frame_header.layers_synthesis[syn_idx].ks > 1)
{
// pad syn_in, and direct output to syn_out.
@@ -751,81 +844,88 @@ frame_memory *cc_frame_decoder::run_syn(struct cc_bs_frame *frame_symbols)
int pad_origin_idx = origin_idx-pad*syn_in_i->stride-pad;
for (int l = 0; l < n_syn_in; l++)
syn_in_i->custom_pad_replicate_plane_in_place_i(l, pad);
- printf(" padded ");
+ if (m_verbosity >= 3)
+ printf(" padded ");
#ifdef CCDECAPI_AVX2
- bool in_place = false;
- if (frame_header.layers_synthesis[syn_idx].ks == 3)
+ if (m_use_avx2)
{
- // optimized 3x3 with line buffer.
- in_place = true;
- m_syn3x3_linebuffer.update_to(2*n_syn_out*m_gop_header.img_w); // two lines for now, could eventually be 1 line plus a few pixels.
- if (n_syn_in == 3 && n_syn_out == 3)
- custom_conv_ks3_in3_out3_lb_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_idx].data, m_synb[syn_idx].data,
- h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, NULL,
- m_syn3x3_linebuffer.data,
- !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
- else if (n_syn_in == 6 && n_syn_out == 6)
- custom_conv_ks3_in6_out6_lb_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_idx].data, m_synb[syn_idx].data,
- h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, NULL,
- m_syn3x3_linebuffer.data,
- !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
- else if (n_syn_in == 9 && n_syn_out == 6)
- custom_conv_ks3_in9_out6_lb_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_idx].data, m_synb[syn_idx].data,
- h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, NULL,
- m_syn3x3_linebuffer.data,
- !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
- else if (n_syn_in == 9 && n_syn_out == 9)
- custom_conv_ks3_in9_out9_lb_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_idx].data, m_synb[syn_idx].data,
- h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, NULL,
- m_syn3x3_linebuffer.data,
- !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
+ if (frame_header.layers_synthesis[syn_idx].ks == 3)
+ {
+ // optimized 3x3 with line buffer.
+ // in_place = true;
+ m_syn3x3_linebuffer.update_to(2*n_syn_out*m_gop_header.img_w); // two lines for now, could eventually be 1 line plus a few pixels.
+ if (n_syn_in == 3 && n_syn_out == 3)
+ custom_conv_ks3_in3_out3_lb_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data,
+ h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, syn_out_i->origin(),
+ m_syn3x3_linebuffer.data,
+ !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
+ else if (n_syn_in == 6 && n_syn_out == 6)
+ custom_conv_ks3_in6_out6_lb_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data,
+ h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, syn_out_i->origin(),
+ m_syn3x3_linebuffer.data,
+ !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
+ else if (n_syn_in == 9 && n_syn_out == 6)
+ custom_conv_ks3_in9_out6_lb_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data,
+ h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, syn_out_i->origin(),
+ m_syn3x3_linebuffer.data,
+ !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
+ else if (n_syn_in == 9 && n_syn_out == 9)
+ custom_conv_ks3_in9_out9_lb_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data,
+ h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, syn_out_i->origin(),
+ m_syn3x3_linebuffer.data,
+ !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
+ else
+ custom_conv_ks3_inX_outX_lb_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data,
+ h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, syn_out_i->origin(),
+ m_syn3x3_linebuffer.data,
+ !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
+ }
else
- custom_conv_ks3_inX_outX_lb_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_idx].data, m_synb[syn_idx].data,
- h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, NULL,
- m_syn3x3_linebuffer.data,
+ {
+ custom_conv_ksX_inX_outX_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data,
+ h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, syn_out_i->origin(),
!!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
+ }
}
+#endif
+#if defined(CCDECAPI_AVX2_OPTIONAL)
else
- custom_conv_ksX_inX_outX_avx2(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_idx].data, m_synb[syn_idx].data,
- h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, syn_out_i->origin(),
- !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
-#else
- bool in_place = false;
- if (frame_header.layers_synthesis[syn_idx].ks == 3)
+#endif
+#if defined(CCDECAPI_AVX2_OPTIONAL) || defined(CCDECAPI_CPU)
{
- in_place = true;
- m_syn3x3_linebuffer.update_to(2*n_syn_out*m_gop_header.img_w); // two lines for now, could eventually be 1 line plus a few pixels.
- custom_conv_ks3_inX_outX_lb(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_idx].data, m_synb[syn_idx].data,
- h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, NULL,
- m_syn3x3_linebuffer.data,
- !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
+ if (frame_header.layers_synthesis[syn_idx].ks == 3)
+ {
+ m_syn3x3_linebuffer.update_to(2*n_syn_out*m_gop_header.img_w); // two lines for now, could eventually be 1 line plus a few pixels.
+ custom_conv_ks3_inX_outX_lb(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data,
+ h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, syn_out_i->origin(),
+ m_syn3x3_linebuffer.data,
+ !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
+ }
+ else
+ {
+ custom_conv_ksX_inX_outX(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data,
+ h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, syn_out_i->origin(),
+ !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
+ }
}
- else
- custom_conv_ksX_inX_outX(frame_header.layers_synthesis[syn_idx].ks, m_synw[syn_idx].data, m_synb[syn_idx].data,
- h_padded, w_padded, syn_in_i->stride, syn_in_i->plane_stride, origin_idx-pad_origin_idx, n_syn_in, syn_in_i->raw()+pad_origin_idx, n_syn_out, syn_out_i->origin(),
- !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
#endif
if (!in_place)
{
- // swap our in & out.
- frame_memory *tmp = syn_in_i;
- syn_in_i = syn_out_i;
- syn_out_i = tmp;
+ // swap our in & out. accounting for any forced-first-time non-in-place.
+ // we just update syn_in_i here, syn_out_i is calculated every loop.
+ syn_in_i = syn_out_i;
}
}
else
{
// possible in-place if ATATIME is big enough to buffer all outputs for single spatial pixel.
// assuming max 4 here (avx2) and infinite (cpu)
- bool in_place = n_syn_out <= 4; // !!! hardwired, not too many outputs.
- // possible fusion.
bool fused = syn_idx == 0 && can_fuse(frame_symbols);
if (fused)
{
// compatible!
- in_place = true;
int n_hidden = n_syn_out;
n_syn_out = frame_header.layers_synthesis[syn_idx+1].n_out_ft;
@@ -833,55 +933,82 @@ frame_memory *cc_frame_decoder::run_syn(struct cc_bs_frame *frame_symbols)
int32_t synw_fused[n_syn_in*n_hidden];
int kidx = 0;
for (int kx = 0; kx < n_syn_in; kx++)
+ {
for (int ky = 0; ky < n_hidden; ky++)
{
- synw_fused[kidx++] = m_synw[syn_idx].data[ky*n_syn_in+kx];
+ synw_fused[kidx++] = m_synw[syn_wb_idx].data[ky*n_syn_in+kx];
}
-#ifdef CCDECAPI_AVX2
- if (n_hidden == 40)
+ }
+#if defined(CCDECAPI_AVX2)
+ if (m_use_avx2 && n_syn_in == 7)
{
- if (n_syn_out == 3)
- custom_conv_ks1_in7_hidden40_out3_avx2(1, synw_fused, m_synb[syn_idx].data, m_synw[syn_idx+1].data, m_synb[syn_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_in_i->origin());
- else if (n_syn_out == 6)
- custom_conv_ks1_in7_hidden40_out6_avx2(1, synw_fused, m_synb[syn_idx].data, m_synw[syn_idx+1].data, m_synb[syn_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_in_i->origin());
+ if (n_hidden == 48 && n_syn_out == 3)
+ custom_conv_ks1_in7_hidden48_out3_avx2(1, synw_fused, m_synb[syn_wb_idx].data, m_synw[syn_wb_idx+1].data, m_synb[syn_wb_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_out_i->origin());
+ else if (n_hidden == 40)
+ {
+ if (n_syn_out == 3)
+ custom_conv_ks1_in7_hidden40_out3_avx2(1, synw_fused, m_synb[syn_wb_idx].data, m_synw[syn_wb_idx+1].data, m_synb[syn_wb_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_out_i->origin());
+ else if (n_syn_out == 6)
+ custom_conv_ks1_in7_hidden40_out6_avx2(1, synw_fused, m_synb[syn_wb_idx].data, m_synw[syn_wb_idx+1].data, m_synb[syn_wb_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_out_i->origin());
+ else if (n_syn_out == 9)
+ custom_conv_ks1_in7_hidden40_out9_avx2(1, synw_fused, m_synb[syn_wb_idx].data, m_synw[syn_wb_idx+1].data, m_synb[syn_wb_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_out_i->origin());
+ else
+ custom_conv_ks1_inX_hiddenX_outX(1, synw_fused, m_synb[syn_wb_idx].data, m_synw[syn_wb_idx+1].data, m_synb[syn_wb_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_out_i->origin());
+ }
+ else if (n_hidden == 32 && n_syn_out == 3)
+ custom_conv_ks1_in7_hidden32_out3_avx2(1, synw_fused, m_synb[syn_wb_idx].data, m_synw[syn_wb_idx+1].data, m_synb[syn_wb_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_out_i->origin());
+ else if (n_hidden == 16 && n_syn_out == 3)
+ custom_conv_ks1_in7_hidden16_out3_avx2(1, synw_fused, m_synb[syn_wb_idx].data, m_synw[syn_wb_idx+1].data, m_synb[syn_wb_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_out_i->origin());
+ else if (n_hidden == 8 && n_syn_out == 3)
+ custom_conv_ks1_in7_hidden8_out3_avx2(1, synw_fused, m_synb[syn_wb_idx].data, m_synw[syn_wb_idx+1].data, m_synb[syn_wb_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_out_i->origin());
else
- custom_conv_ks1_in7_hidden40_out9_avx2(1, synw_fused, m_synb[syn_idx].data, m_synw[syn_idx+1].data, m_synb[syn_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_in_i->origin());
+ custom_conv_ks1_inX_hiddenX_outX(1, synw_fused, m_synb[syn_wb_idx].data, m_synw[syn_wb_idx+1].data, m_synb[syn_wb_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_out_i->origin());
}
- else if (n_hidden == 16)
- custom_conv_ks1_in7_hidden16_out3_avx2(1, synw_fused, m_synb[syn_idx].data, m_synw[syn_idx+1].data, m_synb[syn_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_in_i->origin());
+#endif
+#if defined(CCDECAPI_AVX2_OPTIONAL)
else
- custom_conv_ks1_in7_hidden8_out3_avx2(1, synw_fused, m_synb[syn_idx].data, m_synw[syn_idx+1].data, m_synb[syn_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_in_i->origin());
-#else
- custom_conv_ks1_inX_hiddenX_outX(1, synw_fused, m_synb[syn_idx].data, m_synw[syn_idx+1].data, m_synb[syn_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_in_i->origin());
+#endif
+#if defined(CCDECAPI_AVX2_OPTIONAL) || defined(CCDECAPI_CPU)
+ custom_conv_ks1_inX_hiddenX_outX(1, synw_fused, m_synb[syn_wb_idx].data, m_synw[syn_wb_idx+1].data, m_synb[syn_wb_idx+1].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, n_syn_in, n_hidden, syn_in_i->origin(), n_syn_out, syn_out_i->origin());
#endif
}
else
{
// 1x1 kernel but not fused.
#ifdef CCDECAPI_AVX2
- if (n_syn_in == 7 && n_syn_out == 9)
- custom_conv_ks1_in7_out9_avx2(1, m_synw[syn_idx].data, m_synb[syn_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, in_place ? syn_in_i->origin() : syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
- else if (n_syn_in == 9 && n_syn_out == 3)
- custom_conv_ks1_in9_out3_avx2(1, m_synw[syn_idx].data, m_synb[syn_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, in_place ? syn_in_i->origin() : syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
- else if (n_syn_in == 9 && n_syn_out == 6)
- custom_conv_ks1_in9_out6_avx2(1, m_synw[syn_idx].data, m_synb[syn_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, in_place ? syn_in_i->origin() : syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
- else if (n_syn_in == 9 && n_syn_out == 9)
- custom_conv_ks1_in9_out9_avx2(1, m_synw[syn_idx].data, m_synb[syn_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, in_place ? syn_in_i->origin() : syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
- else if (n_syn_in == 7 && n_syn_out == 40)
- custom_conv_ks1_in7_out40_avx2(1, m_synw[syn_idx].data, m_synb[syn_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, in_place ? syn_in_i->origin() : syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
- else if (n_syn_in == 7 && n_syn_out == 16)
- custom_conv_ks1_in7_out16_avx2(1, m_synw[syn_idx].data, m_synb[syn_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, in_place ? syn_in_i->origin() : syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
- else if ( n_syn_out == 3)
- custom_conv_ks1_inX_out3_avx2(1, m_synw[syn_idx].data, m_synb[syn_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, in_place ? syn_in_i->origin() : syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
- else if ( n_syn_out == 6)
- custom_conv_ks1_inX_out6_avx2(1, m_synw[syn_idx].data, m_synb[syn_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, in_place ? syn_in_i->origin() : syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
- else if ( n_syn_out == 9)
- custom_conv_ks1_inX_out9_avx2(1, m_synw[syn_idx].data, m_synb[syn_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, in_place ? syn_in_i->origin() : syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
- else
- custom_conv_ks1_inX_outX_avx2(1, m_synw[syn_idx].data, m_synb[syn_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, in_place ? syn_in_i->origin() : syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
-#else
- in_place = true;
- custom_conv_ks1_inX_outX(1, m_synw[syn_idx].data, m_synb[syn_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, in_place ? syn_in_i->origin() : syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
+ if (m_use_avx2)
+ {
+ if (n_syn_in == 7 && n_syn_out == 9)
+ custom_conv_ks1_in7_out9_avx2(1, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
+ else if (n_syn_in == 9 && n_syn_out == 3)
+ custom_conv_ks1_in9_out3_avx2(1, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
+ else if (n_syn_in == 9 && n_syn_out == 6)
+ custom_conv_ks1_in9_out6_avx2(1, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
+ else if (n_syn_in == 9 && n_syn_out == 9)
+ custom_conv_ks1_in9_out9_avx2(1, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
+ else if (n_syn_in == 7 && n_syn_out == 40)
+ custom_conv_ks1_in7_out40_avx2(1, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
+ else if (n_syn_in == 7 && n_syn_out == 16)
+ custom_conv_ks1_in7_out16_avx2(1, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
+ else if ( n_syn_out == 3)
+ custom_conv_ks1_inX_out3_avx2(1, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
+ else if ( n_syn_out == 6)
+ custom_conv_ks1_inX_out6_avx2(1, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
+ else if ( n_syn_out == 9)
+ custom_conv_ks1_inX_out9_avx2(1, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
+ else
+ custom_conv_ks1_inX_outX_avx2(1, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
+ }
+#endif
+#if defined(CCDECAPI_AVX2_OPTIONAL)
+ else
+#endif
+#if defined(CCDECAPI_AVX2_OPTIONAL) || defined(CCDECAPI_CPU)
+ {
+ // can buffer everything in cpu-mode.
+ // in_place = true;
+ custom_conv_ks1_inX_outX(1, m_synw[syn_wb_idx].data, m_synb[syn_wb_idx].data, m_gop_header.img_h, m_gop_header.img_w, syn_in_i->stride, syn_in_i->plane_stride, 0, n_syn_in, syn_in_i->origin(), n_syn_out, syn_out_i->origin(), !!frame_header.layers_synthesis[syn_idx].residual, !!frame_header.layers_synthesis[syn_idx].relu);
+ }
#endif
}
if (fused)
@@ -891,10 +1018,9 @@ frame_memory *cc_frame_decoder::run_syn(struct cc_bs_frame *frame_symbols)
}
if (!in_place)
{
- // swap our in & out.
- frame_memory *tmp = syn_in_i;
+ // swap our in & out. accounting for any forced-first-time non-in-place.
+ // we just update syn_in_i here, syn_out_i is calculated every loop.
syn_in_i = syn_out_i;
- syn_out_i = tmp;
}
}
@@ -904,9 +1030,9 @@ frame_memory *cc_frame_decoder::run_syn(struct cc_bs_frame *frame_symbols)
const auto time_this_syn_end = std::chrono::steady_clock::now();
const std::chrono::duration this_syn_elapsed_seconds = time_this_syn_end - time_this_syn_start;
- printf(" %g\n", (double)this_syn_elapsed_seconds.count());
+ if (m_verbosity >= 3)
+ printf(" %g\n", (double)this_syn_elapsed_seconds.count());
}
- //syn_in_i->print_ranges("final", n_syn_in, SYN_MUL_PRECISION);
const auto time_syn_end = std::chrono::steady_clock::now();
const std::chrono::duration syn_elapsed_seconds = time_syn_end - time_syn_start;
@@ -915,6 +1041,113 @@ frame_memory *cc_frame_decoder::run_syn(struct cc_bs_frame *frame_symbols)
return syn_in_i;
}
+frame_memory *cc_frame_decoder::run_syn(struct cc_bs_frame *frame_symbols)
+{
+ // multiple synthesiser runs.
+ // we have m_syn_1 containing upsampled output, ready to send to synthesisers.
+ //
+ // ASSUMING ALL SYNTHS POSSIBLE IN-PLACE. IE, ALL KS ARE 1 or 3.
+ // with n branches 1..n:
+ // for branch1:
+ // synth m_syn_1 -> m_syn_2. (ie, 1st syn layer takes 1->2, remaining layers in 2)
+ // for branch 2..n-1:
+ // synth m_syn_1 -> m_syn_3. (ie, 1st syn layer takes 1->3, remaining layers in 3)
+ // blend result m_syn_3 into m_syn_2.
+ // for branch n:
+ // synth m_syn_1 in place. (ie, all layers in 1)
+ // blend result m_syn_1 into m_syn_2.
+ //
+ // result is m_syn_2.
+
+ // IF NOT IN-PLACE (KS > 3)
+ // with n branches 1..n:
+ // for branch1:
+ // synth m_syn_1 -> m_syn_2 or m_syn_x. (ie, 1st syn layer 1->2, remaining layers in 2/x)
+ // for branch 2..n-1:
+ // synth m_syn_1 -> m_syn_3 or m_syn_x. (ie, 1st syn layer 1->3, remaining layers in 3/x)
+ // blend result m_syn_3/x into m_syn_2.
+ // for branch n:
+ // synth m_syn_1 in place or m_syn_x . (ie, all layers in 1)
+ // blend result m_syn_1 into m_syn_2.
+
+ frame_memory *p_syn_1 = &m_syn_1; // initial upsample output -- preserved until last branch.
+ frame_memory *p_syn_2 = &m_syn_2; // initial branch output, blended into by subsequent branches.
+ frame_memory *p_syn_3 = &m_syn_3; // subsequent branch outputs
+ frame_memory *p_syn_tmp = &m_syn_tmp;
+
+ if (m_verbosity >= 100)
+ {
+ struct cc_bs_frame_header &frame_header = frame_symbols->m_frame_header;
+ printf("SYNTHESIS INPUTS\n");
+ for (int p = 0; p < frame_header.n_latent_n_resolutions; p++)
+ {
+ printf("SYNPLANE %d ", p);
+ m_syn_1.print_start(p, "SYNIN", -1, UPS_PRECISION);
+ }
+ }
+
+ frame_memory *result = NULL;
+ for (int b = 0; b < m_syn_n_branches; b++)
+ {
+ if (b == 0 && m_syn_n_branches > 1)
+ {
+ result = run_syn_branch(frame_symbols, b, p_syn_1, p_syn_2, p_syn_tmp);
+ // result is either m_syn_2 or m_syn_tmp.
+ if (result != p_syn_2)
+ {
+ // switch our notion of m_syn_2 and m_syn_tmp.
+ p_syn_tmp = p_syn_2;
+ p_syn_2 = result;
+ }
+ }
+ else if (b < m_syn_n_branches-1)
+ {
+ result = run_syn_branch(frame_symbols, b, p_syn_1, p_syn_3, p_syn_tmp);
+ // result is either m_syn_3 or m_syn_tmp.
+ if (result != p_syn_3)
+ {
+ // switch our notion of m_syn_3 and m_syn_tmp.
+ p_syn_tmp = p_syn_3;
+ p_syn_3 = result;
+ }
+ if (b == 1)
+ run_syn_blend2(frame_symbols, 3, b, p_syn_3, p_syn_2);
+ else
+ run_syn_blend1(frame_symbols, 3, b, p_syn_3, p_syn_2);
+ }
+ else
+ {
+ // NULL for out implies we can run in-place as much as we like, destroying p_syn_1.
+ result = run_syn_branch(frame_symbols, b, p_syn_1, NULL, p_syn_tmp);
+ // result is either m_syn_1 or m_syn_tmp.
+ if (result != p_syn_1)
+ {
+ // switch our notion of m_syn_1 and m_syn_tmp.
+ p_syn_tmp = p_syn_1;
+ p_syn_1 = result;
+ }
+ if (b == 0)
+ {
+ ; // we don't have any blending, single branch.
+ }
+ else if (b == 1)
+ {
+ // 2 branches being blended to create result.
+ run_syn_blend2(frame_symbols, 3, b, p_syn_1, p_syn_2);
+ result = p_syn_2;
+ }
+ else
+ {
+ // subsequent branch being blended into result.
+ run_syn_blend1(frame_symbols, 3, b, p_syn_1, p_syn_2);
+ result = p_syn_2;
+ }
+ }
+ }
+
+ return result;
+}
+
// returned value should be transformed or otherwise copied before calling decode_frame again.
struct frame_memory *cc_frame_decoder::decode_frame(struct cc_bs_frame *frame_symbols)
{
@@ -922,8 +1155,11 @@ struct frame_memory *cc_frame_decoder::decode_frame(struct cc_bs_frame *frame_sy
read_ups(frame_symbols);
read_syn(frame_symbols);
- printf("loaded weights\n");
- fflush(stdout);
+ if (m_verbosity >= 3)
+ {
+ printf("loaded weights\n");
+ fflush(stdout);
+ }
check_allocations(frame_symbols);
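
The buffer schedule described in the comments at the top of the new `run_syn` can be sketched as a small standalone planner. This is only an illustration of the in-place case (all kernel sizes 1 or 3), assuming at least one branch; `BranchStep`, `plan_branches` and `result_buffer` are hypothetical names that do not exist in the decoder, and the `m_syn_tmp` swapping used when a layer cannot run in place is left out.

```cpp
#include <vector>

// Hypothetical record: where one synthesis branch writes its output.
struct BranchStep {
    int branch;         // 0-based branch index
    int out_buffer;     // 1, 2 or 3, standing for m_syn_1, m_syn_2, m_syn_3
    bool blend_into_2;  // true when the branch output is then blended into m_syn_2
};

// Build the schedule for n_branches >= 1, mirroring the three cases in run_syn.
std::vector<BranchStep> plan_branches(int n_branches)
{
    std::vector<BranchStep> plan;
    for (int b = 0; b < n_branches; b++) {
        if (b == 0 && n_branches > 1)
            plan.push_back({b, 2, false});   // first branch: m_syn_1 -> m_syn_2
        else if (b < n_branches - 1)
            plan.push_back({b, 3, true});    // middle branches: m_syn_1 -> m_syn_3, blended into m_syn_2
        else
            plan.push_back({b, 1, b > 0});   // last branch: in place in m_syn_1, blended unless it is the only branch
    }
    return plan;
}

// Buffer holding the final result: m_syn_1 for a single branch, m_syn_2 otherwise.
int result_buffer(int n_branches) { return n_branches > 1 ? 2 : 1; }
```

With three branches the plan reads: branch 0 writes buffer 2, branch 1 writes buffer 3 and is blended into 2, branch 2 overwrites buffer 1 and is blended into 2. The upsampled input in `m_syn_1` is therefore only destroyed by the last branch.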
diff --git a/coolchic/cpp/cc-frame-decoder.h b/coolchic/cpp/cc-frame-decoder.h
index 51368941..9cc719ef 100644
--- a/coolchic/cpp/cc-frame-decoder.h
+++ b/coolchic/cpp/cc-frame-decoder.h
@@ -5,20 +5,52 @@
class cc_frame_decoder
{
public:
- cc_frame_decoder(struct cc_bs_gop_header &gop_header)
+ cc_frame_decoder(struct cc_bs_gop_header &gop_header, int output_bitdepth, int output_chroma_format
+#if defined(CCDECAPI_AVX2_OPTIONAL)
+ , bool use_avx2
+#endif
+ , int verbosity
+ )
: m_gop_header(gop_header),
+ m_output_bitdepth(output_bitdepth),
+ m_output_chroma_format(output_chroma_format),
m_mlpw_t(NULL),
m_mlpb(NULL),
m_mlp_n_hidden_layers_arm(-1),
+ m_ups_n(-1),
+ m_upsw(NULL),
+ m_ups_n_preconcat(-1),
+ m_upsw_preconcat(NULL),
+ m_syn_n_branches(-1),
+ m_syn_n_layers(-1),
m_synw(NULL),
- m_synb(NULL),
- m_syn_n_layers(-1)
- {};
-
- ~cc_frame_decoder() {delete[] m_mlpw_t; delete[] m_mlpb; delete[] m_synw; delete[] m_synb;};
+ m_synb(NULL)
+#if defined(CCDECAPI_AVX2)
+#if defined(CCDECAPI_AVX2_OPTIONAL)
+ ,m_use_avx2(use_avx2)
+#else
+ ,m_use_avx2(true)
+#endif
+#endif
+ ,m_verbosity(verbosity)
+ ,m_arm_pad(0)
+ ,m_ups_pad(0)
+ ,m_max_pad(0)
+ {
+ if (m_output_bitdepth == 0)
+ m_output_bitdepth = m_gop_header.bitdepth;
+ if (m_output_chroma_format == 0)
+ m_output_chroma_format = m_gop_header.frame_data_type == 1 ? 420 : 444;
+ };
+
+ ~cc_frame_decoder() {delete[] m_mlpw_t; delete[] m_mlpb; delete[] m_upsw; delete[] m_upsw_preconcat; delete[] m_synw; delete[] m_synb;};
+public:
+ int get_syn_idx(int branch_idx, int layer_idx) { return branch_idx*m_syn_n_layers + layer_idx; }
public:
struct cc_bs_gop_header &m_gop_header;
+ int m_output_bitdepth;
+ int m_output_chroma_format;
public:
struct frame_memory *decode_frame(struct cc_bs_frame *frame_symbols);
@@ -30,14 +62,24 @@ class cc_frame_decoder
weights_biases m_mlpbOUT;
int m_mlp_n_hidden_layers_arm;
- weights_biases m_upsw_t_4x4;
+ int m_ups_n;
+ weights_biases *m_upsw;
+ int m_ups_n_preconcat;
+ weights_biases *m_upsw_preconcat;
- weights_biases *m_synw;
- weights_biases *m_synb;
+ int m_syn_n_branches;
int m_syn_n_layers;
+ weights_biases m_syn_blends; // m_syn_n_branches blend values.
+ weights_biases *m_synw; // [branchidx*m_syn_n_layers+synidx]
+ weights_biases *m_synb; // [branchidx*m_syn_n_layers+synidx]
buffer m_syn3x3_linebuffer;
+#if defined(CCDECAPI_AVX2)
+ bool m_use_avx2;
+#endif
+ int m_verbosity;
+
private:
// for arm decode destination, and upsampling.
std::vector m_h_pyramid;
@@ -49,8 +91,12 @@ class cc_frame_decoder
int m_ups_pad;
int m_max_pad;
- frame_memory m_ups_1; // !!! really for_syn_0
- frame_memory m_ups_2; // !!! really for_syn_1
+ frame_memory m_ups_h2w2; // internal refinement target prior to upsample.
+ frame_memory m_ups_hw; // during 2-pass refinement and 2-pass upsample.
+ frame_memory m_syn_1;
+ frame_memory m_syn_2;
+ frame_memory m_syn_3;
+ frame_memory m_syn_tmp;
private:
// set by arm to indicate no content, avoid ups for no content.
@@ -65,6 +111,9 @@ class cc_frame_decoder
void run_arm(struct cc_bs_frame *frame_symbols);
void run_ups(struct cc_bs_frame *frame_symbols);
+ frame_memory *run_syn_blend1(struct cc_bs_frame *frame_symbols, int n_planes, int branch_no, frame_memory *syn_in, frame_memory *syn_out);
+ frame_memory *run_syn_blend2(struct cc_bs_frame *frame_symbols, int n_planes, int branch_no, frame_memory *syn_in, frame_memory *syn_out);
+ frame_memory *run_syn_branch(struct cc_bs_frame *frame_symbols, int branch_no, frame_memory *syn_in, frame_memory *syn_out, frame_memory *syn_tmp);
frame_memory *run_syn(struct cc_bs_frame *frame_symbols);
};
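
The branch-major layout behind `get_syn_idx` (and the `[branchidx*m_syn_n_layers+synidx]` comments on `m_synw` and `m_synb`) can be illustrated with a small stand-in. `SynWeights` is a hypothetical type used only for this sketch; the real decoder stores `weights_biases` objects rather than integers.

```cpp
#include <vector>

// Hypothetical stand-in for per-layer weight storage, one slot per (branch, layer).
struct SynWeights {
    int n_branches;
    int n_layers;
    std::vector<int> flat;  // branch-major: all layers of branch 0, then branch 1, ...

    SynWeights(int branches, int layers)
        : n_branches(branches), n_layers(layers), flat(branches * layers, 0) {}

    // Same arithmetic as get_syn_idx in the patch.
    int syn_idx(int branch_idx, int layer_idx) const
    {
        return branch_idx * n_layers + layer_idx;
    }
};
```

Keeping every layer of a branch contiguous means a branch can be walked with a simple `syn_idx(b, 0) + layer` offset, which appears to be what `syn_wb_idx` does in the synthesis loop.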
diff --git a/coolchic/cpp/ccdecapi.hpp b/coolchic/cpp/ccdecapi.cpp
similarity index 77%
rename from coolchic/cpp/ccdecapi.hpp
rename to coolchic/cpp/ccdecapi.cpp
index 905d9610..e754766c 100644
--- a/coolchic/cpp/ccdecapi.hpp
+++ b/coolchic/cpp/ccdecapi.cpp
@@ -4,8 +4,14 @@
* #define CCDECAPI_CPU
* or
* #define CCDECAPI_AVX2
+ * or
+ * #define CCDECAPI_AVX2_OPTIONAL
*/
+#if defined(CCDECAPI_AVX2_OPTIONAL)
+#define CCDECAPI_AVX2
+#endif
+
#include
#include
#include
@@ -28,55 +34,97 @@ float time_bac_seconds = 0.0;
float time_arm_seconds = 0.0;
float time_ups_seconds = 0.0;
float time_syn_seconds = 0.0;
+float time_blend_seconds = 0.0;
float time_warp_seconds = 0.0;
float time_bpred_seconds = 0.0;
float time_all_seconds = 0.0;
+// like std::byteswap, but that's only in c++23 onwards!
+inline unsigned short byteswap(unsigned short x)
+{
+ unsigned short hi = x&0xFF00;
+ unsigned short lo = x&0x00FF;
+ return (lo<<8)|(hi>>8);
+}
+
+// like ends_with in c++20 and onwards.
+inline bool ends_with(const std::string& a, const std::string& b)
+{
+ if (b.size() > a.size())
+ return false;
+ return std::equal(a.begin() + a.size() - b.size(), a.end(), b.begin());
+}
+
// only use nc == 3
-void ppm_out(int nc, int pixel_depth, struct frame_memory &in_info, char const *outname)
+void ppm_out(int nc, int bit_depth, struct frame_memory &in_info, FILE *fout)
{
int const h = in_info.h;
int const w = in_info.w;
int const stride = in_info.stride;
int const plane_stride = in_info.plane_stride;
+ int const max_sample_val = (1<<bit_depth)-1;
 [...]
- int r = ((*inR++)*255+(1<<(SYN_LAYER_PRECISION-1))) >> SYN_LAYER_PRECISION;
- int g = ((*inG++)*255+(1<<(SYN_LAYER_PRECISION-1))) >> SYN_LAYER_PRECISION;
- int b = ((*inB++)*255+(1<<(SYN_LAYER_PRECISION-1))) >> SYN_LAYER_PRECISION;
+ int r = ((*inR++)*max_sample_val+(1<<(SYN_LAYER_PRECISION-1))) >> SYN_LAYER_PRECISION;
+ int g = ((*inG++)*max_sample_val+(1<<(SYN_LAYER_PRECISION-1))) >> SYN_LAYER_PRECISION;
+ int b = ((*inB++)*max_sample_val+(1<<(SYN_LAYER_PRECISION-1))) >> SYN_LAYER_PRECISION;
if (r < 0) r = 0;
if (g < 0) g = 0;
if (b < 0) b = 0;
- if (r > 255) r = 255;
- if (g > 255) g = 255;
- if (b > 255) b = 255;
+ if (r > max_sample_val) r = max_sample_val;
+ if (g > max_sample_val) g = max_sample_val;
+ if (b > max_sample_val) b = max_sample_val;
unsigned char pix[3];
pix[0] = r;
pix[1] = g;
pix[2] = b;
- fwrite(pix, 3, 1, fout);
+ fwrite(pix, 3, sizeof(pix[0]), fout);
+ }
+ }
+ else
+ {
+ for (int y = 0; y < h; y++, inR += stride-w, inG += stride-w, inB += stride-w)
+ for (int x = 0; x < w; x++)
+ {
+ // precision is SYN_LAYER_PRECISION
+ int r = ((*inR++)*max_sample_val+(1<<(SYN_LAYER_PRECISION-1))) >> SYN_LAYER_PRECISION;
+ int g = ((*inG++)*max_sample_val+(1<<(SYN_LAYER_PRECISION-1))) >> SYN_LAYER_PRECISION;
+ int b = ((*inB++)*max_sample_val+(1<<(SYN_LAYER_PRECISION-1))) >> SYN_LAYER_PRECISION;
+ if (r < 0) r = 0;
+ if (g < 0) g = 0;
+ if (b < 0) b = 0;
+ if (r > max_sample_val) r = max_sample_val;
+ if (g > max_sample_val) g = max_sample_val;
+ if (b > max_sample_val) b = max_sample_val;
+ unsigned short pix[3];
+ pix[0] = byteswap(r);
+ pix[1] = byteswap(g);
+ pix[2] = byteswap(b);
+ fwrite(pix, 3, sizeof(pix[0]), fout);
}
}
- fclose(fout);
+ fflush(fout);
}
// we 'stabilise' the yuv to (0..255) as integers, as well as uv subsample.
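
The per-sample arithmetic in the updated `ppm_out` can be sketched in isolation as follows. `kPrecision` is an assumed value standing in for `SYN_LAYER_PRECISION`, whose real definition is not part of this diff, and `to_sample` and `byteswap16` are hypothetical helper names. The sketch scales a fixed-point value to the output range, rounds to nearest and clamps; for bit depths above 8 the binary PPM format stores each sample as two bytes, most significant byte first, hence the byte swap on a little-endian host.

```cpp
#include <cstdint>

// Assumed for illustration only; the real SYN_LAYER_PRECISION is defined elsewhere.
constexpr int kPrecision = 12;

// Convert a fixed-point value (1.0 == 1 << kPrecision) to an integer sample
// in [0, 2^bit_depth - 1], rounding to nearest and clamping.
inline int to_sample(int32_t v, int bit_depth)
{
    const int64_t max_sample_val = (int64_t(1) << bit_depth) - 1;
    // 64-bit product so that large inputs cannot overflow before the shift.
    int64_t s = (int64_t(v) * max_sample_val + (int64_t(1) << (kPrecision - 1))) >> kPrecision;
    if (s < 0) s = 0;
    if (s > max_sample_val) s = max_sample_val;
    return static_cast<int>(s);
}

// Swap the two bytes of a 16-bit sample.
inline uint16_t byteswap16(uint16_t x)
{
    return static_cast<uint16_t>((x << 8) | (x >> 8));
}
```

The patch swaps bytes unconditionally, which is correct on little-endian hosts; a big-endian build would need to skip the swap.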
@@ -294,7 +342,7 @@ unsigned short *get_raw_444_10b(struct frame_memory &in)
// incoming 420 is planar, outgoing 444 is planar.
void convert_420_444_8b(frame_memory &out, unsigned char *in, int h, int w)
{
- out.update_to(h, w, 3, 0);
+ out.update_to(h, w, 0, 3);
unsigned char *src = in;
int32_t *dst = out.plane_origin(0);
@@ -329,7 +377,7 @@ void convert_420_444_8b(frame_memory &out, unsigned char *in, int h, int w)
// incoming 420 is planar, outgoing 444 is planar.
void convert_420_444_10b(frame_memory &out, unsigned short *in, int h, int w)
{
- out.update_to(h, w, 3, 0);
+ out.update_to(h, w, 0, 3);
unsigned short *src = in;
int32_t *dst = out.plane_origin(0);
@@ -397,7 +445,7 @@ void dump_yuv444_10b(unsigned short *frame_444, int h, int w, FILE *fout, int fr
// incoming 444 is planar, outgoing 444 is planar.
void store_444_8b(frame_memory &out, unsigned char *in, int h, int w)
{
- out.update_to(h, w, 3, 0);
+ out.update_to(h, w, 0, 3);
unsigned char *src = in;
int32_t *dst = out.plane_origin(0);
@@ -421,7 +469,7 @@ void store_444_8b(frame_memory &out, unsigned char *in, int h, int w)
// incoming 444 is planar, outgoing 444 is planar.
void store_444_10b(frame_memory &out, unsigned short *in, int h, int w)
{
- out.update_to(h, w, 3, 0);
+ out.update_to(h, w, 0, 3);
unsigned short *src = in;
int32_t *dst = out.plane_origin(0);
@@ -449,7 +497,7 @@ void warp(struct frame_memory &warp_result, struct frame_memory &raw_info, int r
{
const auto time_warp_start = std::chrono::steady_clock::now();
- warp_result.update_to(ref.h, ref.w, 3, 0);
+ warp_result.update_to(ref.h, ref.w, 0, 3);
int32_t *src = ref.origin();
int const src_stride = ref.stride;
@@ -560,7 +608,7 @@ void bpred(struct frame_memory &bpred_result, struct frame_memory &raw_info, int
int raw_stride = raw_info.stride;
int raw_plane_stride = raw_info.plane_stride;
- bpred_result.update_to(h, w, 3, 0);
+ bpred_result.update_to(h, w, 0, 3);
int32_t *out = bpred_result.origin();
int32_t *src0 = pred0.origin();
@@ -622,14 +670,14 @@ void process_inter(struct frame_memory &pred, struct frame_memory &raw_cc_output
}
}
-#ifdef CCDECAPI_CPU
-int cc_decode_cpu(std::string &bitstream_filename, std::string &out_filename)
-#else
-#ifdef CCDECAPI_AVX2
-int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename)
+#if defined(CCDECAPI_CPU)
+int cc_decode_cpu(std::string &bitstream_filename, std::string &out_filename, int output_bitdepth, int output_chroma_format, int verbosity)
+#elif defined(CCDECAPI_AVX2_OPTIONAL)
+int cc_decode_avx2_optional(std::string &bitstream_filename, std::string &out_filename, int output_bitdepth, int output_chroma_format, bool use_avx2, int verbosity)
+#elif defined(CCDECAPI_AVX2)
+int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename, int output_bitdepth, int output_chroma_format, int verbosity)
#else
-#error must have one of CCDECAPI_CPU or CCDECAPI_AVX2 defined.
-#endif
+#error must have one of CCDECAPI_CPU, CCDECAPI_AVX2 or CCDECAPI_AVX2_OPTIONAL defined.
#endif
{
if (bitstream_filename == "")
@@ -641,7 +689,7 @@ int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename)
const auto time_all_start = std::chrono::steady_clock::now();
cc_bs bs;
- if (!bs.open(bitstream_filename))
+ if (!bs.open(bitstream_filename, verbosity))
{
printf("cannot open %s for reading\n", bitstream_filename.c_str());
return 1;
@@ -661,7 +709,11 @@ int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename)
}
}
- struct cc_frame_decoder frame_decoder(gop_header);
+#if defined(CCDECAPI_AVX2_OPTIONAL)
+ struct cc_frame_decoder frame_decoder(gop_header, output_bitdepth, output_chroma_format, use_avx2, verbosity);
+#else
+ struct cc_frame_decoder frame_decoder(gop_header, output_bitdepth, output_chroma_format, verbosity);
+#endif
for (int frame_coding_idx = 0; frame_coding_idx <= gop_header.intra_period; frame_coding_idx++)
{
// raw_cc_output either
@@ -670,7 +722,7 @@ int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename)
// [9] (B: residue+xy+alpha+xy+beta)
// from bitstream.
- cc_bs_frame *frame_symbols = bs.decode_frame();
+ cc_bs_frame *frame_symbols = bs.decode_frame(verbosity);
if (frame_symbols == NULL)
{
return 1;
@@ -706,7 +758,7 @@ int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename)
if (frames_444[idx].raw() != NULL)
{
ref_prev = &frames_444[idx];
- printf("refprev: %d\n", idx);
+ //printf("refprev: %d\n", idx);
break;
}
// find next if B.
@@ -715,7 +767,7 @@ int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename)
if (frames_444[idx].raw() != NULL)
{
ref_next = &frames_444[idx];
- printf("refnext: %d\n", idx);
+ //printf("refnext: %d\n", idx);
break;
}
process_inter(frame_444, *raw_cc_output, gop_header.img_h, gop_header.img_w, ref_prev, ref_next, frame_header.flow_gain);
@@ -723,9 +775,9 @@ int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename)
}
// YUV 420
- if (gop_header.frame_data_type == 1)
+ if (ends_with(out_filename, ".yuv") && frame_decoder.m_output_chroma_format == 420)
{
- if (gop_header.bitdepth == 8)
+ if (frame_decoder.m_output_bitdepth == 8)
{
unsigned char *frame_420 = convert_444_420_8b(*frame_444p);
if (out_filename != "")
@@ -733,7 +785,7 @@ int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename)
convert_420_444_8b(frames_444[frame_header.display_index], frame_420, gop_header.img_h, gop_header.img_w);
delete[] frame_420;
}
- else if (gop_header.bitdepth == 10)
+ else if (frame_decoder.m_output_bitdepth == 10)
{
unsigned short *frame_420 = convert_444_420_10b(*frame_444p);
if (out_filename != "")
@@ -744,13 +796,14 @@ int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename)
}
else
{
- printf("Unkown YUV bitdepth %d. Should be 8 or 10.", gop_header.bitdepth);
+ printf("Unknown YUV bitdepth %d. Should be 8 or 10.\n", frame_decoder.m_output_bitdepth);
+ exit(1);
}
}
// YUV 444
- else if (gop_header.frame_data_type == 2)
+ else if (ends_with(out_filename, ".yuv") && frame_decoder.m_output_chroma_format == 444)
{
- if (gop_header.bitdepth == 8)
+ if (frame_decoder.m_output_bitdepth == 8)
{
unsigned char *raw_frame_444 = get_raw_444_8b(*frame_444p);
if (out_filename != "")
@@ -758,7 +811,7 @@ int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename)
store_444_8b(frames_444[frame_header.display_index], raw_frame_444, gop_header.img_h, gop_header.img_w);
}
- else if (gop_header.bitdepth == 10)
+ else if (frame_decoder.m_output_bitdepth == 10)
{
unsigned short *raw_frame_444 = get_raw_444_10b(*frame_444p);
if (out_filename != "")
@@ -767,15 +820,15 @@ int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename)
}
else
{
- printf("Unkown YUV bitdepth %d. Should be 8 or 10.", gop_header.bitdepth);
+ printf("Unknown YUV bitdepth %d. Should be 8 or 10.\n", frame_decoder.m_output_bitdepth);
+ exit(1);
}
}
-
else
{
// rgb
if (out_filename != "")
- ppm_out(3, 3, *frame_444p, out_filename.c_str());
+ ppm_out(3, frame_decoder.m_output_bitdepth, *frame_444p, fout);
if (gop_header.intra_period > 0)
{
printf("do not want to copy rgb in rgb video\n");
@@ -788,23 +841,39 @@ int cc_decode_avx2(std::string &bitstream_filename, std::string &out_filename)
const std::chrono::duration elapsed_all = (time_all_end-time_all_start);
time_all_seconds = (float)elapsed_all.count();
- printf("time: arm %g ups %g syn %g warp %g bpred %g all %g\n",
- time_arm_seconds, time_ups_seconds, time_syn_seconds, time_warp_seconds, time_bpred_seconds, time_all_seconds);
+ if (verbosity >= 1)
+ printf("time: arm %g ups %g syn %g blend %g warp %g bpred %g all %g\n",
+ time_arm_seconds, time_ups_seconds, time_syn_seconds, time_blend_seconds, time_warp_seconds, time_bpred_seconds, time_all_seconds);
fflush(stdout);
}
if (fout != NULL)
+ {
fclose(fout);
-
- printf("decode done\n");
+ printf("%s created\n", out_filename.c_str());
+ }
return 0;
}
-#if 0
+#ifdef CCDEC_EXE
+
+// A main() and associated parameters for standalone executable.
+#ifndef CCDECAPI_AVX2_OPTIONAL
+// we are expecting 'optional' -- we have a run-time parameter for choosing cpu or avx2 or auto.
+#error CCDEC_EXE needs CCDECAPI_AVX2_OPTIONAL
+no.
+#endif
+
char const *param_bitstream = "--input=";
char const *param_out = "--output=";
+char const *param_bitdepth = "--output_bitdepth=";
+char const *param_chroma = "--output_chroma_format=";
+char const *param_cpu = "--cpu";
+char const *param_auto = "--auto";
+char const *param_avx2 = "--avx2";
+char const *param_v = "--v=";
void usage(char const *msg = NULL)
{
@@ -812,7 +881,11 @@ void usage(char const *msg = NULL)
printf("%s\n", msg);
printf("Usage:\n");
printf(" %s: .hevc to decode\n", param_bitstream);
- printf(" %s: reconstruction if desired: ppm (image) or yuv420 (video) only\n", param_out);
+ printf(" [%s]: reconstruction if desired: ppm (image) or yuv420 (video) only\n", param_out);
+ printf(" [%s]: output bitdepth 0 => take from bitstream\n", param_bitdepth);
+ printf(" [%s]: output chroma subsampling 420 or 444 for yuv; 0 => take from bitstream\n", param_chroma);
+ printf(" [%s|%s|%s]: optimized instruction set with which to decode\n", param_cpu, param_avx2, param_auto);
+ printf(" [%s]: verbosity\n", param_v);
exit(msg != NULL);
}
@@ -821,6 +894,12 @@ int main(int argc, const char* argv[])
std::string bitstream_filename;
std::string out = "out";
+ int output_bitdepth = 0;
+ int output_chroma_format = 0;
+ bool explicit_instruction_set = false;
+ bool use_avx2 = false;
+ bool use_auto = false;
+ int verbosity = 0;
for (int i = 1; i < argc; i++)
{
@@ -828,6 +907,27 @@ int main(int argc, const char* argv[])
bitstream_filename = argv[i]+strlen(param_bitstream);
else if (strncmp(argv[i], param_out, strlen(param_out)) == 0)
out = argv[i]+strlen(param_out);
+ else if (strncmp(argv[i], param_bitdepth, strlen(param_bitdepth)) == 0)
+ output_bitdepth = atoi(argv[i]+strlen(param_bitdepth));
+ else if (strncmp(argv[i], param_chroma, strlen(param_chroma)) == 0)
+ output_chroma_format = atoi(argv[i]+strlen(param_chroma));
+ else if (strncmp(argv[i], param_avx2, strlen(param_avx2)) == 0
+ || strncmp(argv[i], param_cpu, strlen(param_cpu)) == 0
+ || strncmp(argv[i], param_auto, strlen(param_auto)) == 0)
+ {
+ if (explicit_instruction_set)
+ {
+ printf("%s: only a single --auto, --avx2 or --cpu\n", argv[i]);
+ exit(1);
+ }
+ explicit_instruction_set = true;
+ use_avx2 = strncmp(argv[i], param_avx2, strlen(param_avx2)) == 0;
+ use_auto = strncmp(argv[i], param_auto, strlen(param_auto)) == 0;
+ }
+ else if (strncmp(argv[i], param_v, strlen(param_v)) == 0)
+ {
+ verbosity = atoi(argv[i]+strlen(param_v));
+ }
else
{
std::string error_message = "unknown parameter ";
@@ -839,12 +939,16 @@ int main(int argc, const char* argv[])
if (bitstream_filename == "")
usage("must specify a bitstream");
-#ifdef CCDECAPI_CPU
- int result = cc_decode_cpu(bitstream_filename, out);
-#else
- int result = cc_decode_avx2(bitstream_filename, out);
-#endif
- printf("decode exit code: %d\n", result);
+ if (!explicit_instruction_set)
+ use_auto = true;
+ if (use_auto)
+ {
+ if (__builtin_cpu_supports("avx2"))
+ {
+ use_avx2 = true;
+ }
+ }
+ int result = cc_decode_avx2_optional(bitstream_filename, out, output_bitdepth, output_chroma_format, use_avx2, verbosity);
return result;
}
#endif
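
The instruction-set selection performed in `main()` can be summarised as a pure function. `resolve_use_avx2` is a hypothetical helper written for this sketch; `cpu_has_avx2` stands in for the runtime check, which the patch performs with the GCC/Clang builtin `__builtin_cpu_supports("avx2")`.

```cpp
// Hypothetical helper mirroring the flag handling in main().
// explicit_instruction_set: one of --cpu, --avx2 or --auto was given.
// use_avx2 / use_auto: which of --avx2 / --auto was given (both false for --cpu).
// cpu_has_avx2: result of the runtime CPU feature check.
inline bool resolve_use_avx2(bool explicit_instruction_set, bool use_avx2,
                             bool use_auto, bool cpu_has_avx2)
{
    if (!explicit_instruction_set)
        use_auto = true;          // no flag behaves like --auto
    if (use_auto && cpu_has_avx2)
        use_avx2 = true;          // auto picks AVX2 only when the CPU reports it
    return use_avx2;
}
```

Note that, as written in the patch, an explicit `--avx2` is honoured even when the CPU does not report AVX2 support; only `--auto` (or no flag) consults the runtime check.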
diff --git a/coolchic/cpp/ccdecapi_avx2.cpp b/coolchic/cpp/ccdecapi_avx2.cpp
index adb0f078..9f9fd3b6 100644
--- a/coolchic/cpp/ccdecapi_avx2.cpp
+++ b/coolchic/cpp/ccdecapi_avx2.cpp
@@ -19,13 +19,18 @@ namespace py = pybind11;
// encode latents layer to a file.
int cc_decode_avx2(
std::string &in_bitstream_filename,
- std::string &out_ppm_filename);
+ std::string &out_ppm_filename,
+ int output_bitdepth = 0,
+ int output_chroma_format = 0,
+ int verbosity = 0);
PYBIND11_MODULE(ccdecapi_avx2, m) {
m.doc() = "ccdecoding"; // optional module docstring
m.def("cc_decode_avx2", &cc_decode_avx2, "decode a bitstream");
}
+#ifndef CCDECAPI_AVX2
#define CCDECAPI_AVX2
-#include "ccdecapi.hpp"
+#endif
+#include "ccdecapi.cpp"
#undef CCDECAPI_AVX2
diff --git a/coolchic/cpp/ccdecapi_cpu.cpp b/coolchic/cpp/ccdecapi_cpu.cpp
index 918c8c44..e5a7bbca 100644
--- a/coolchic/cpp/ccdecapi_cpu.cpp
+++ b/coolchic/cpp/ccdecapi_cpu.cpp
@@ -19,13 +19,18 @@ namespace py = pybind11;
// encode latents layer to a file.
int cc_decode_cpu(
std::string &in_bitstream_filename,
- std::string &out_ppm_filename);
+ std::string &out_ppm_filename,
+ int output_bitdepth = 0,
+ int output_chroma_format = 0,
+ int verbosity = 0);
PYBIND11_MODULE(ccdecapi_cpu, m) {
m.doc() = "ccdecoding"; // optional module docstring
m.def("cc_decode_cpu", &cc_decode_cpu, "decode a bitstream");
}
+#ifndef CCDECAPI_CPU
#define CCDECAPI_CPU
-#include "ccdecapi.hpp"
+#endif
+#include "ccdecapi.cpp"
#undef CCDECAPI_CPU
diff --git a/coolchic/cpp/ccencapi.cpp b/coolchic/cpp/ccencapi.cpp
index 08272ec0..ce255184 100644
--- a/coolchic/cpp/ccencapi.cpp
+++ b/coolchic/cpp/ccencapi.cpp
@@ -96,6 +96,7 @@ void code_val(TEncBinCABAC &layer_BAC, MuSigGTs *coding_ctxs, int val_to_code)
// we return the best index.
 int cc_code_wb_bac(std::string &out_name, std::vector<int> &xs, int use_count)
{
+#if 0
// !!! check for all zero, emit empty file.
bool all_zero = true;
for (int i = 0; i < (int)xs.size(); i++)
@@ -106,10 +107,11 @@ int cc_code_wb_bac(std::string &out_name, std::vector<int> &xs, int use_count)
break;
}
}
- if (all_zero)
- {
- printf("all weights/biases zero -- think of empty file\n");
- }
+ //if (all_zero)
+ //{
+ // printf("all weights/biases zero -- think of empty file\n");
+ //}
+#endif
TEncBinCABAC layer_BAC;
@@ -120,6 +122,16 @@ int cc_code_wb_bac(std::string &out_name, std::vector<int> &xs, int use_count)
int test_max = 12;
if (use_count >= 0)
test_min = test_max = use_count;
+#if 0
+ if (1) // xs.size() == 5 || xs.size() == 8)
+ {
+ printf("encoding weight/bias vector %d to %s:", (int)xs.size(), out_name.c_str());
+ for (int i = 0; i < (int)xs.size(); i++)
+ printf(" %d", xs[i]);
+ printf("\n");
+ fflush(stdout);
+ }
+#endif
for (int exgolomb_count = test_min; exgolomb_count <= test_max; exgolomb_count++)
{
//auto layer_BAC = CABACEncoder();
@@ -141,7 +153,7 @@ int cc_code_wb_bac(std::string &out_name, std::vector<int> &xs, int use_count)
{
best_exgolomb_count = exgolomb_count;
best_exgolomb_bytes = bsBAC.getFifo();
- printf("better exgolomb bytes %d at count=%d\n", (int)best_exgolomb_bytes.size(), best_exgolomb_count);
+ //printf("better exgolomb bytes %d at count=%d\n", (int)best_exgolomb_bytes.size(), best_exgolomb_count);
}
}
@@ -158,7 +170,7 @@ int cc_code_wb_bac(std::string &out_name, std::vector<int> &xs, int use_count)
exit(1);
}
fclose(fout);
- printf("%s created\n", out_name.c_str());
+ //printf("%s created\n", out_name.c_str());
// return best sig index used for coding
return best_exgolomb_count;
@@ -172,8 +184,6 @@ void cc_code_latent_layer_bac(
int layer_height, int layer_width,
int hls_sig_blksize)
{
- printf("called cc_code_latent_layer_bac: file=%s\n", out_name.c_str());
-
// get significant blocks.
bool hls_sig_update = hls_sig_blksize < 0;
if (hls_sig_update)
@@ -245,7 +255,7 @@ void cc_code_latent_layer_bac(
// want to bother? For significant blocks, we now say no.
// SIG
- printf("nz %d vs %d\n", n_zero, nby*nbx);
+ // printf("nz %d vs %d\n", n_zero, nby*nbx);
//if (n_zero <= nby*nbx/20)
if (1) // no longer use significance blocks.
{
@@ -259,7 +269,7 @@ void cc_code_latent_layer_bac(
else
{
// signal block significance.
- printf("sig1 %s\n", hls_sig_update ? "(update)" : "(noupdate)");
+ // printf("sig1 %s\n", hls_sig_update ? "(update)" : "(noupdate)");
layer_BAC.encodeBinEP(1);
auto ctx = BinProbModel_Std(PROBA_50_STATE);
for (int by = 0; by < nby; by++)
@@ -281,7 +291,7 @@ void cc_code_latent_layer_bac(
// FLAT?
- printf("nflat %d vs %d\n", n_flat, nby*nbx);
+ // printf("nflat %d vs %d\n", n_flat, nby*nbx);
if (n_flat <= nby*nbx/20)
{
layer_BAC.encodeBinEP(0);
@@ -292,7 +302,7 @@ void cc_code_latent_layer_bac(
layer_BAC.encodeBinEP(1);
// signal flat for sig blocks.
auto ctx = BinProbModel_Std(PROBA_50_STATE);
- printf("flat1\n");
+ // printf("flat1\n");
for (int by = 0; by < nby; by++)
{
for (int bx = 0; bx < nbx; bx++)
@@ -367,7 +377,7 @@ void cc_code_latent_layer_bac(
exit(1);
}
fclose(fout);
- printf("%s created\n", out_name.c_str());
+ // printf("%s created\n", out_name.c_str());
delete[] blk_sig;
delete[] blk_flat;
diff --git a/coolchic/cpp/frame-memory.cpp b/coolchic/cpp/frame-memory.cpp
index 9428c818..6861eb0e 100644
--- a/coolchic/cpp/frame-memory.cpp
+++ b/coolchic/cpp/frame-memory.cpp
@@ -1,24 +1,24 @@
#include "frame-memory.h"
-void frame_memory::custom_pad_replicate_plane_in_place_i(int plane, int pad)
+void frame_memory::custom_pad_replicate_plane_in_place_i(int plane, int padlr, int padtb)
{
int32_t *scan_in = plane_origin(plane);
- int32_t *scan_out = scan_in - pad*stride-pad;
- int w_in_padded = w+2*pad;
+ int32_t *scan_out = scan_in - padtb*stride-padlr;
+ int w_in_padded = w+2*padlr;
// leading lines
- for (int y = 0; y < pad; y++)
+ for (int y = 0; y < padtb; y++)
{
if (y == 0)
{
// construct first padded line.
- for (int x = 0; x < pad; x++)
+ for (int x = 0; x < padlr; x++)
{
*scan_out++ = *scan_in;
}
memcpy(scan_out, scan_in, w*sizeof(scan_out[0]));
scan_out += w;
- for (int x = 0; x < pad; x++)
+ for (int x = 0; x < padlr; x++)
{
*scan_out++ = scan_in[w-1];
}
@@ -31,31 +31,74 @@ void frame_memory::custom_pad_replicate_plane_in_place_i(int plane, int pad)
scan_out += stride;
}
}
+
// internal lines: pad left and right edges.
- for (int y = 0; y < h; y++)
+ if (padlr > 0)
{
- scan_out = scan_in-pad;
- for (int x = 0; x < pad; x++)
- {
- *scan_out++ = *scan_in;
- }
- scan_out += w;
- for (int x = 0; x < pad; x++)
+ for (int y = 0; y < h; y++)
{
- *scan_out++ = scan_in[w-1];
+ scan_out = scan_in-padlr;
+ for (int x = 0; x < padlr; x++)
+ {
+ *scan_out++ = *scan_in;
+ }
+ scan_out += w;
+ for (int x = 0; x < padlr; x++)
+ {
+ *scan_out++ = scan_in[w-1];
+ }
+ scan_in += stride;
+ scan_out += stride-w_in_padded;
}
- scan_in += stride;
+ }
+ else
+ {
+ scan_in += h*stride;
+ scan_out += h*stride;
}
// trailing lines -- we copy the previous padded.
scan_in -= stride; // beginning of last line, padding to the left.
- scan_out = scan_in+stride-pad;
- for (int y = 0; y < pad; y++)
+ scan_out = scan_in+stride-padlr;
+ for (int y = 0; y < padtb; y++)
{
memcpy(scan_out, scan_out-stride, w_in_padded*sizeof(scan_out[0]));
scan_out += stride;
}
}
+void frame_memory::custom_pad_zero_plane_in_place_i(int plane, int padlr, int padtb)
+{
+ int32_t *scan_out = plane_origin(plane) - padtb*stride-padlr;
+ int w_in_padded = w+2*padlr;
+ // leading lines
+ for (int y = 0; y < padtb; y++)
+ {
+ memset(scan_out, 0, w_in_padded*sizeof(scan_out[0]));
+ scan_out += stride;
+ }
+
+ // internal lines: pad left and right edges.
+ if (padlr > 0)
+ {
+ for (int y = 0; y < h; y++)
+ {
+ memset(scan_out, 0, padlr*sizeof(scan_out[0]));
+ memset(scan_out+padlr+w, 0, padlr*sizeof(scan_out[0]));
+ scan_out += stride;
+ }
+ }
+ else
+ {
+ scan_out += h*stride;
+ }
+
+ for (int y = 0; y < padtb; y++)
+ {
+ memset(scan_out, 0, w_in_padded*sizeof(scan_out[0]));
+ scan_out += stride;
+ }
+}
+
void frame_memory::zero_pad(int plane, int pad)
{
// !!! just the padding.
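
Note on the padding changes above: the patch splits the single `pad` into separate left/right (`padlr`) and top/bottom (`padtb`) amounts and adds a zero-padding variant. The sketch below shows both behaviours on a small standalone buffer. `PaddedPlane`, `pad_replicate` and `pad_zero` are hypothetical names for illustration, and the per-sample loops are a deliberately simple stand-in for the `memcpy`/`memset` row operations used in `frame_memory`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical single plane with independent horizontal and vertical padding.
struct PaddedPlane
{
    int h, w, padlr, padtb, stride;
    std::vector<int32_t> data;

    PaddedPlane(int h_, int w_, int padlr_, int padtb_)
        : h(h_), w(w_), padlr(padlr_), padtb(padtb_),
          stride(w_ + 2 * padlr_),
          data(static_cast<std::size_t>(h_ + 2 * padtb_) * (w_ + 2 * padlr_), -1) {}

    // Pointer to sample (y, x); coordinates outside [0,h) x [0,w) address the padding.
    int32_t *at(int y, int x)
    {
        assert(y >= -padtb && y < h + padtb && x >= -padlr && x < w + padlr);
        return data.data() + (y + padtb) * stride + (x + padlr);
    }
};

// Replicate padding: every padding sample copies the nearest image sample.
inline void pad_replicate(PaddedPlane &p)
{
    for (int y = -p.padtb; y < p.h + p.padtb; y++)
    {
        int ys = y < 0 ? 0 : (y >= p.h ? p.h - 1 : y);
        for (int x = -p.padlr; x < p.w + p.padlr; x++)
        {
            int xs = x < 0 ? 0 : (x >= p.w ? p.w - 1 : x);
            if (y != ys || x != xs)
                *p.at(y, x) = *p.at(ys, xs);
        }
    }
}

// Zero padding: every padding sample is set to 0, image samples are untouched.
inline void pad_zero(PaddedPlane &p)
{
    for (int y = -p.padtb; y < p.h + p.padtb; y++)
        for (int x = -p.padlr; x < p.w + p.padlr; x++)
            if (y < 0 || y >= p.h || x < 0 || x >= p.w)
                *p.at(y, x) = 0;
}
```

With `padlr = 0`, both helpers touch only the rows above and below the image. This corresponds to the `else` branches in the patch, which skip the per-line left/right work.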
diff --git a/coolchic/cpp/frame-memory.h b/coolchic/cpp/frame-memory.h
index 0b55c1a3..f0308dea 100644
--- a/coolchic/cpp/frame-memory.h
+++ b/coolchic/cpp/frame-memory.h
@@ -7,6 +7,8 @@ struct frame_memory
int origin_idx;
int h;
int w;
+ int pad;
+ int planes;
int stride;
int plane_stride;
public:
@@ -14,15 +16,19 @@ struct frame_memory
origin_idx(0),
h(0),
w(0),
+ pad(0),
+ planes(0),
stride(0),
plane_stride(0)
{};
~frame_memory() {};
public:
- void update_to(int frame_h, int frame_w, int frame_planes, int frame_pad)
+ void update_to(int frame_h, int frame_w, int frame_pad, int frame_planes)
{
h = frame_h;
w = frame_w;
+ pad = frame_pad;
+ planes = frame_planes;
stride = frame_pad+frame_w+frame_pad;
plane_stride = (frame_pad+frame_h+frame_pad)*stride;
origin_idx = frame_pad*stride+frame_pad;
@@ -30,9 +36,10 @@ struct frame_memory
}
int32_t *raw() { return raw_frame.data; }
int32_t *origin() { return raw_frame.data+origin_idx; }
- int32_t *pad_origin(int pad) { return raw_frame.data+origin_idx - pad*stride - pad; }
int32_t *plane_origin(int plane = 0) { return origin()+plane*plane_stride; }
- int32_t *pad_origin(int plane, int pad) { return raw_frame.data + origin_idx-pad*stride-pad + plane*plane_stride; }
+ int32_t *plane_pixel(int plane, int y, int x) { return plane_origin(plane)+y*stride+x; }
+ int32_t *pad_origin(int plane, int pad) { return pad_origin(plane, pad, pad); }
+ int32_t *pad_origin(int plane, int padlr, int padtb) { return raw_frame.data + origin_idx-padtb*stride-padlr + plane*plane_stride; }
void print_ranges(char const *msg, int nplanes, int precision)
{
int minval = plane_origin(0)[0];
@@ -48,9 +55,22 @@ struct frame_memory
}
printf("%s: minval=%g maxval=%g\n", msg == NULL ? "" : msg, (float)minval/(1< 0 ? n : h, n > 0 ? n : w, msg == NULL ? "" : msg);
+ for (int y = 0; y < (n > 0 ? n : h); y++)
+ {
+ for (int x = 0; x < (n > 0 ? n : w); x++)
+ printf(" %g", plane_origin(plane)[y*stride+x]/((1<
-
-#include "TDecBinCoderCABAC.h"
-#include "cc-contexts.h"
-#include "common.h"
-#include "rest_avx2.h"
-
-// we have timing stuff here as well.
-#include // timing.
-
-// FLOAT
-float hsum_ps_sse3(__m128 v) {
- __m128 shuf = _mm_movehdup_ps(v); // broadcast elements 3,1 to 2,0
- __m128 sums = _mm_add_ps(v, shuf);
- shuf = _mm_movehl_ps(shuf, sums); // high half -> low half
- sums = _mm_add_ss(sums, shuf);
- return _mm_cvtss_f32(sums);
-}
-
-float hsum256_ps_avx(__m256 v) {
- __m128 vlow = _mm256_castps256_ps128(v);
- __m128 vhigh = _mm256_extractf128_ps(v, 1); // high 128
- vlow = _mm_add_ps(vlow, vhigh); // add the low 128
- return hsum_ps_sse3(vlow); // and inline the sse3 version, which is optimal for AVX
- // (no wasted instructions, and all of them are the 4B minimum)
-}
-
-// INT
-uint32_t hsum_epi32_avx(__m128i x)
-{
- __m128i hi64 = _mm_unpackhi_epi64(x, x); // 3-operand non-destructive AVX lets us save a byte without needing a movdqa
- __m128i sum64 = _mm_add_epi32(hi64, x);
- __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); // Swap the low two elements
- __m128i sum32 = _mm_add_epi32(sum64, hi32);
- return _mm_cvtsi128_si32(sum32); // movd
-}
-
-// only needs AVX2
-uint32_t hsum_8x32(__m256i v)
-{
- __m128i sum128 = _mm_add_epi32(
- _mm256_castsi256_si128(v),
- _mm256_extracti128_si256(v, 1)); // silly GCC uses a longer AXV512VL instruction if AVX512 is enabled :/
- return hsum_epi32_avx(sum128);
-}
-
-// we are applying 4 4x4 filters on a non-expanded source.
-//
-void custom_conv_ups_4x41_avx2(float *kw, int h_in, int w_in, float *in, float *out)
-{
- int const ks = 4;
- int const stride = 1;
- int offs0 = 0;
-
- int32_t indexes[] = { 0, 1, 2, 3, w_in, w_in+1, w_in+2, w_in+3 };
- __m256i ind_intel = _mm256_loadu_si256((__m256i *)&indexes[0]);
-
- __m256 kernel_v0h0_top = _mm256_loadu_ps(kw+(0*2+0)*(ks*ks)+0);
- __m256 kernel_v0h0_bot = _mm256_loadu_ps(kw+(0*2+0)*(ks*ks)+8);
- __m256 kernel_v0h1_top = _mm256_loadu_ps(kw+(0*2+1)*(ks*ks)+0);
- __m256 kernel_v0h1_bot = _mm256_loadu_ps(kw+(0*2+1)*(ks*ks)+8);
-
- __m256 kernel_v1h0_top = _mm256_loadu_ps(kw+(1*2+0)*(ks*ks)+0);
- __m256 kernel_v1h0_bot = _mm256_loadu_ps(kw+(1*2+0)*(ks*ks)+8);
- __m256 kernel_v1h1_top = _mm256_loadu_ps(kw+(1*2+1)*(ks*ks)+0);
- __m256 kernel_v1h1_bot = _mm256_loadu_ps(kw+(1*2+1)*(ks*ks)+8);
-
-
- for (int y = 0; y < h_in-ks+1; y += stride, offs0 += w_in)
- {
- int offs = offs0;
- float *out_next = out+2*(w_in-ks+1);
-
- for (int x = 0; x < w_in-ks+1; x += stride, offs += stride)
- {
- __m256 in_0 = _mm256_i32gather_ps(&in[offs+0*w_in], ind_intel, sizeof(float)); // 8 32-bit floats
- __m256 in_1 = _mm256_i32gather_ps(&in[offs+2*w_in], ind_intel, sizeof(float)); // 8 32-bit floats
-
- __m256 sum_0 = kernel_v0h0_top*in_0;
- sum_0 += kernel_v0h0_bot*in_1;
- __m256 sum_1 = kernel_v0h1_top*in_0;
- sum_1 += kernel_v0h1_bot*in_1;
-
- *out++ = hsum256_ps_avx(sum_0);
- *out++ = hsum256_ps_avx(sum_1);
-
- sum_0 = kernel_v1h0_top*in_0;
- sum_0 += kernel_v1h0_bot*in_1;
- sum_1 = kernel_v1h1_top*in_0;
- sum_1 += kernel_v1h1_bot*in_1;
-
- *out_next++ = hsum256_ps_avx(sum_0);
- *out_next++ = hsum256_ps_avx(sum_1);
- }
-
- // we've done two output lines.
- out = out_next;
- }
-}
diff --git a/coolchic/cpp/rest_avx2.h b/coolchic/cpp/rest_avx2.h
deleted file mode 100644
index 01e7db8a..00000000
--- a/coolchic/cpp/rest_avx2.h
+++ /dev/null
@@ -1,11 +0,0 @@
-/*
- Software Name: Cool-Chic
- SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange
- SPDX-License-Identifier: BSD 3-Clause "New"
-
- This software is distributed under the BSD-3-Clause license.
- Authors: see CONTRIBUTORS.md
-*/
-
-
-void custom_conv_ups_4x41_avx2(float *kw, int h_in, int w_in, float *in, float *out);
diff --git a/coolchic/cpp/syn_avx2.cpp b/coolchic/cpp/syn_avx2.cpp
index afea11e5..276e9edc 100644
--- a/coolchic/cpp/syn_avx2.cpp
+++ b/coolchic/cpp/syn_avx2.cpp
@@ -99,6 +99,13 @@
#define SYN_NAME custom_conv_ksX_inX_outX_avx2
#include "syn_avx2.hpp"
+#define SYN_NAME custom_conv_ks1_in7_hidden48_out3_avx2
+#define SYN_KS 1
+#define SYN_N_IN 7
+#define SYN_N_HIDDEN 48
+#define SYN_N_OUT 3
+#include "synfused_avx2.hpp"
+
#define SYN_NAME custom_conv_ks1_in7_hidden40_out3_avx2
#define SYN_KS 1
#define SYN_N_IN 7
@@ -120,6 +127,13 @@
#define SYN_N_OUT 9
#include "synfused_avx2.hpp"
+#define SYN_NAME custom_conv_ks1_in7_hidden32_out3_avx2
+#define SYN_KS 1
+#define SYN_N_IN 7
+#define SYN_N_HIDDEN 32
+#define SYN_N_OUT 3
+#include "synfused_avx2.hpp"
+
#define SYN_NAME custom_conv_ks1_in7_hidden16_out3_avx2
#define SYN_KS 1
#define SYN_N_IN 7
@@ -161,3 +175,5 @@
#define SYN_NAME custom_conv_ks3_inX_outX_lb_avx2
#define SYN_KS 3
#include "synlb_avx2.hpp"
+
+#include "synblend_avx2.hpp"
diff --git a/coolchic/cpp/syn_avx2.h b/coolchic/cpp/syn_avx2.h
index 15e0cf3e..4b217a92 100644
--- a/coolchic/cpp/syn_avx2.h
+++ b/coolchic/cpp/syn_avx2.h
@@ -22,9 +22,11 @@ void custom_conv_ks1_inX_out9_avx2(int KS, int32_t *kw, int32_t *kb, int h_in, i
void custom_conv_ksX_inX_outX_avx2(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_in, int plane_stride_in, int residue_origin_offset, int N_IN, int32_t *in, int N_OUT, int32_t *out, int residue, int relu);
+void custom_conv_ks1_in7_hidden48_out3_avx2(int KS, int32_t *kw7_40, int32_t *kb40, int32_t *kw40_3, int32_t *kb3, int h_in, int w_in, int pad_in, int plane_stride_in, int N_IN, int N_HIDDEN, int32_t *in, int N_OUT, int32_t *out);
void custom_conv_ks1_in7_hidden40_out3_avx2(int KS, int32_t *kw7_40, int32_t *kb40, int32_t *kw40_3, int32_t *kb3, int h_in, int w_in, int pad_in, int plane_stride_in, int N_IN, int N_HIDDEN, int32_t *in, int N_OUT, int32_t *out);
void custom_conv_ks1_in7_hidden40_out6_avx2(int KS, int32_t *kw7_40, int32_t *kb40, int32_t *kw40_3, int32_t *kb3, int h_in, int w_in, int pad_in, int plane_stride_in, int N_IN, int N_HIDDEN, int32_t *in, int N_OUT, int32_t *out);
void custom_conv_ks1_in7_hidden40_out9_avx2(int KS, int32_t *kw7_40, int32_t *kb40, int32_t *kw40_3, int32_t *kb3, int h_in, int w_in, int pad_in, int plane_stride_in, int N_IN, int N_HIDDEN, int32_t *in, int N_OUT, int32_t *out);
+void custom_conv_ks1_in7_hidden32_out3_avx2(int KS, int32_t *kw7_40, int32_t *kb40, int32_t *kw40_3, int32_t *kb3, int h_in, int w_in, int pad_in, int plane_stride_in, int N_IN, int N_HIDDEN, int32_t *in, int N_OUT, int32_t *out);
void custom_conv_ks1_in7_hidden16_out3_avx2(int KS, int32_t *kw7_40, int32_t *kb40, int32_t *kw40_3, int32_t *kb3, int h_in, int w_in, int pad_in, int plane_stride_in, int N_IN, int N_HIDDEN, int32_t *in, int N_OUT, int32_t *out);
void custom_conv_ks1_in7_hidden8_out3_avx2(int KS, int32_t *kw7_40, int32_t *kb40, int32_t *kw40_3, int32_t *kb3, int h_in, int w_in, int pad_in, int plane_stride_in, int N_IN, int N_HIDDEN, int32_t *in, int N_OUT, int32_t *out);
@@ -33,3 +35,6 @@ void custom_conv_ks3_in6_out6_lb_avx2(int KS, int32_t *kw, int32_t *kb, int h_in
void custom_conv_ks3_in9_out6_lb_avx2(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_in, int plane_stride_in, int residue_origin_offset, int N_IN, int32_t *in, int N_OUT, int32_t *out, int32_t *line_buffer, int residue, int relu);
void custom_conv_ks3_in9_out9_lb_avx2(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_in, int plane_stride_in, int residue_origin_offset, int N_IN, int32_t *in, int N_OUT, int32_t *out, int32_t *line_buffer, int residue, int relu);
void custom_conv_ks3_inX_outX_lb_avx2(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_in, int plane_stride_in, int residue_origin_offset, int N_IN, int32_t *in, int N_OUT, int32_t *out, int32_t *line_buffer, int residue, int relu);
+
+void syn_blend1_avx2(int h_in, int w_in, int stride_in, int plane_stride_in, int N_INOUT, int32_t *in, int32_t blend_val_in, int32_t *out);
+void syn_blend2_avx2(int h_in, int w_in, int stride_in, int plane_stride_in, int N_INOUT, int32_t *in, int32_t blend_val_in, int32_t *out, int32_t blend_val_out);
diff --git a/coolchic/cpp/syn_avx2.hpp b/coolchic/cpp/syn_avx2.hpp
index 8d8d5816..2f2c1886 100644
--- a/coolchic/cpp/syn_avx2.hpp
+++ b/coolchic/cpp/syn_avx2.hpp
@@ -61,21 +61,10 @@ void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_i
in_layer[0] = in;
for (int i = 1; i < std::max(n_in, n_out); i++)
in_layer[i] = in_layer[i-1]+plane_stride_in;
-//#if SYN_KS == 1
-//// possible in-place for ks==1 -- we do not check here.
-//#define out_layer in_layer
-// if (out != NULL && out != in)
-// {
-// printf("%s: ks=%d n_in=%d n_out=%d: bad call: should be in-place, but out supplied\n", xstr(SYN_NAME), KS, N_IN, N_OUT);
-// exit(1);
-// }
-//#else
- //not in-place
int32_t *out_layer[N_OUT];
out_layer[0] = out;
for (int i = 1; i < N_OUT; i++)
out_layer[i] = out_layer[i-1]+plane_stride_in;
-//#endif
const __m256i z = _mm256_setzero_si256();
@@ -159,32 +148,44 @@ void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_i
}
}
}
+ int n_outs = (x_blk < xlim_blk-1) ? 8 : xlast_blk_size;
+ int32_t store[8];
if (relu)
{
+ // -ves to zero, >> for rest.
for (int ol = 0; ol < atatime; ol++)
{
out_avx2[ol] = _mm256_blendv_epi8(z, out_avx2[ol], _mm256_cmpgt_epi32(out_avx2[ol], z));
- }
- }
- if (x_blk < xlim_blk-1)
- {
- for (int ol = 0; ol < atatime; ol++)
- {
out_avx2[ol] = _mm256_srai_epi32(out_avx2[ol], SYN_MUL_PRECISION);
- _mm256_storeu_si256((__m256i_u*)&out_layer[olbase+ol][offso], out_avx2[ol]);
+ if (n_outs == 8)
+ {
+ _mm256_storeu_si256((__m256i_u*)&out_layer[olbase+ol][offso], out_avx2[ol]);
+ }
+ else
+ {
+ _mm256_storeu_si256((__m256i_u*)&store[0], out_avx2[ol]);
+ memcpy(&out_layer[olbase+ol][offso], &store[0], n_outs*sizeof(store[0]));
+ }
}
}
else
{
- int n_outs = xlast_blk_size;
- // partial last line.
+ // need different treatment for -ve and +ve.
for (int ol = 0; ol < atatime; ol++)
{
- int32_t store[8];
- out_avx2[ol] = _mm256_srai_epi32(out_avx2[ol], SYN_MUL_PRECISION);
- _mm256_storeu_si256((__m256i_u*)&store[0], out_avx2[ol]);
- memcpy(&out_layer[olbase+ol][offso], &store[0], n_outs*sizeof(store[0]));
+ __m256i sr = _mm256_srai_epi32(out_avx2[ol], SYN_MUL_PRECISION);
+ __m256i negsrneg = _mm256_sub_epi32(z, _mm256_srai_epi32(_mm256_sub_epi32(z, out_avx2[ol]), SYN_MUL_PRECISION));
+ out_avx2[ol] = _mm256_blendv_epi8(negsrneg, sr, _mm256_cmpgt_epi32(out_avx2[ol], z));
+ if (n_outs == 8)
+ {
+ _mm256_storeu_si256((__m256i_u*)&out_layer[olbase+ol][offso], out_avx2[ol]);
+ }
+ else
+ {
+ _mm256_storeu_si256((__m256i_u*)&store[0], out_avx2[ol]);
+ memcpy(&out_layer[olbase+ol][offso], &store[0], n_outs*sizeof(store[0]));
+ }
}
}
} // olbase
diff --git a/coolchic/cpp/syn_cpu.cpp b/coolchic/cpp/syn_cpu.cpp
index 7169ae33..d29117a9 100644
--- a/coolchic/cpp/syn_cpu.cpp
+++ b/coolchic/cpp/syn_cpu.cpp
@@ -10,7 +10,6 @@
#include
#include
-#include // !!! abs
#include
#include "common.h"
@@ -30,3 +29,5 @@
#define SYN_NAME custom_conv_ks3_inX_outX_lb
#define SYN_KS 3
#include "synlb_cpu.hpp"
+
+#include "synblend_cpu.hpp"
diff --git a/coolchic/cpp/syn_cpu.h b/coolchic/cpp/syn_cpu.h
index 6998fdd7..526c1210 100644
--- a/coolchic/cpp/syn_cpu.h
+++ b/coolchic/cpp/syn_cpu.h
@@ -13,3 +13,6 @@ void custom_conv_ksX_inX_outX(int KS, int32_t *kw, int32_t *kb, int h_in, int w_
void custom_conv_ks1_inX_hiddenX_outX(int KS, int32_t *kw7_40, int32_t *kb40, int32_t *kw40_3, int32_t *kb3, int h_in, int w_in, int pad_in, int plane_stride_in, int N_IN, int N_HIDDEN, int32_t *in, int N_OUT, int32_t *out);
void custom_conv_ks3_inX_outX_lb(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_in, int plane_stride_in, int residue_origin_offset, int N_IN, int32_t *in, int N_OUT, int32_t *out, int32_t *line_buffer, int residue, int relu);
+
+void syn_blend1(int h_in, int w_in, int stride_in, int plane_stride_in, int N_INOUT, int32_t *in, int32_t blend_val_in, int32_t *out);
+void syn_blend2(int h_in, int w_in, int stride_in, int plane_stride_in, int N_INOUT, int32_t *in, int32_t blend_val_in, int32_t *out, int32_t blend_val_out);
diff --git a/coolchic/cpp/syn_cpu.hpp b/coolchic/cpp/syn_cpu.hpp
index 56e55bb6..4ca52901 100644
--- a/coolchic/cpp/syn_cpu.hpp
+++ b/coolchic/cpp/syn_cpu.hpp
@@ -20,7 +20,7 @@
// stride and plane_stride are assumed the same for in and out.
void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_in, int plane_stride_in, int residue_origin_offset, int N_IN, int32_t *in, int N_OUT, int32_t *out, int residue, int relu)
{
- printf("%s(ks=%d N_IN=%d N_OUT=%d, residue=%d relu=%d\n", xstr(SYN_NAME), KS, N_IN, N_OUT, residue, relu);
+ //printf("%s(ks=%d N_IN=%d N_OUT=%d, residue=%d relu=%d\n", xstr(SYN_NAME), KS, N_IN, N_OUT, residue, relu);
int const kstride = 1;
#ifdef SYN_KS
@@ -90,9 +90,16 @@ void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_i
sum += xxres;
}
}
- sum >>= SYN_MUL_PRECISION; // take multiplied sum to output. // !!! check sign?
- if (relu && sum < 0)
- sum = 0;
+ // take multiplied sum to output after reluing.
+ if (sum < 0)
+ {
+ if (relu)
+ sum = 0;
+ else
+ sum = -(-sum >> SYN_MUL_PRECISION);
+ }
+ else
+ sum >>= SYN_MUL_PRECISION;
out_cache[ol] = sum;
}
// flush.
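
Note on the rounding change above: a plain `sum >>= SYN_MUL_PRECISION` rounds negative sums toward negative infinity on the usual two's-complement targets, while the new `-(-sum >> SYN_MUL_PRECISION)` form rounds them toward zero, so positive and negative values are treated symmetrically. The sketch below illustrates this under assumptions: `kMulPrecision` is a made-up value for the example (the real `SYN_MUL_PRECISION` is defined elsewhere in the decoder), and the helper names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical precision, chosen only to keep the numbers small.
constexpr int kMulPrecision = 4;

// Sign-aware shift as in the patch: negate, shift, negate back,
// so negative values round toward zero like positive ones do.
inline int32_t shift_toward_zero(int32_t sum)
{
    assert(sum != INT32_MIN); // -sum would overflow for this value
    return sum < 0 ? -(-sum >> kMulPrecision) : (sum >> kMulPrecision);
}

// Same logic with the optional ReLU folded in, following the CPU branch structure:
// negatives become 0 under ReLU, otherwise they are shifted toward zero.
inline int32_t finalize_sum(int32_t sum, bool relu)
{
    assert(sum != INT32_MIN);
    if (sum < 0)
        return relu ? 0 : -(-sum >> kMulPrecision);
    return sum >> kMulPrecision;
}
```

For example, with a precision of 4 bits, -17 maps to -1 and +17 maps to +1. An arithmetic shift alone would typically give -2 for -17.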
diff --git a/coolchic/cpp/synblend_avx2.hpp b/coolchic/cpp/synblend_avx2.hpp
new file mode 100644
index 00000000..f428ce05
--- /dev/null
+++ b/coolchic/cpp/synblend_avx2.hpp
@@ -0,0 +1,12 @@
+
+void syn_blend1_avx2(int h_in, int w_in, int stride_in, int plane_stride_in, int N_INOUT, int32_t *in, int32_t blend_val, int32_t *out)
+{
+ printf("should not be here\n");
+ exit(1);
+}
+
+void syn_blend2_avx2(int h_in, int w_in, int stride_in, int plane_stride_in, int N_INOUT, int32_t *in, int32_t blend_val_in, int32_t *out, int32_t blend_val_out)
+{
+ printf("should not be here\n");
+ exit(1);
+}
diff --git a/coolchic/cpp/synblend_cpu.hpp b/coolchic/cpp/synblend_cpu.hpp
new file mode 100644
index 00000000..21d4dd85
--- /dev/null
+++ b/coolchic/cpp/synblend_cpu.hpp
@@ -0,0 +1,50 @@
+
+// out = out + in*blend_val
+void syn_blend1(int h_in, int w_in, int stride_in, int plane_stride_in, int N_INOUT, int32_t *in, int32_t blend_val, int32_t *out)
+{
+ for (int p = 0; p < N_INOUT; p++)
+ {
+ int32_t *src = in+plane_stride_in*p;
+ int32_t *dst = out+plane_stride_in*p;
+ for (int y = 0; y < h_in; y++, src += stride_in-w_in, dst += stride_in-w_in)
+ {
+ for (int x = 0; x < w_in; x++, src++, dst++)
+ {
+ int x0 = *src;
+ if (x0 < 0)
+ x0 = 0;
+ else if (x0 > (1<> SYN_LAYER_PRECISION);
+ }
+ }
+ }
+}
+
+// out = out*blend_val_out + in*blend_val_in
+void syn_blend2(int h_in, int w_in, int stride_in, int plane_stride_in, int N_INOUT, int32_t *in, int32_t blend_val_in, int32_t *out, int32_t blend_val_out)
+{
+ for (int p = 0; p < N_INOUT; p++)
+ {
+ int32_t *src = in+plane_stride_in*p;
+ int32_t *dst = out+plane_stride_in*p;
+ for (int y = 0; y < h_in; y++, src += stride_in-w_in, dst += stride_in-w_in)
+ {
+ for (int x = 0; x < w_in; x++, src++, dst++)
+ {
+ int x0 = *src;
+ if (x0 < 0)
+ x0 = 0;
+ else if (x0 > (1< (1<> SYN_LAYER_PRECISION;
+ }
+ }
+ }
+}
diff --git a/coolchic/cpp/synfused_avx2.hpp b/coolchic/cpp/synfused_avx2.hpp
index 2b602086..d22d8e6b 100644
--- a/coolchic/cpp/synfused_avx2.hpp
+++ b/coolchic/cpp/synfused_avx2.hpp
@@ -13,12 +13,12 @@
// N_IN (7)
// N_HIDDEN (16 or 40)
// N_OUT (3)
-// HIDDEN always RELU, OUT always NONE
+// HIDDEN always RELU
void SYN_NAME(int KS,
int32_t *kw7_40, int32_t *kb40, int32_t *kw40_3, int32_t *kb3,
int h_in, int w_in, int stride_in, int plane_stride_in, int N_IN, int N_HIDDEN, int32_t *in, int N_OUT, int32_t *out)
{
- printf("%s(ks=%d N_IN=%d N_HIDDEN=%d N_OUT=%d\n", xstr(SYN_NAME), KS, N_IN, N_HIDDEN, N_OUT);
+ //printf("%s(ks=%d N_IN=%d N_HIDDEN=%d N_OUT=%d\n", xstr(SYN_NAME), KS, N_IN, N_HIDDEN, N_OUT);
#ifdef SYN_KS
int const ks = SYN_KS;
@@ -46,6 +46,7 @@ void SYN_NAME(int KS,
int const n_hidden8 = n_hidden/8;
int32_t *src = in;
+ int32_t *dst = out;
if (KS != ks)
{
@@ -71,8 +72,8 @@ void SYN_NAME(int KS,
const __m256i rotate_right = _mm256_setr_epi32(1, 2, 3, 4, 5, 6, 7, 0);
const __m256i z = _mm256_setzero_si256();
- for (int y = 0; y < h_in; y++, src += src_pad+src_pad) // eol of this, and bol of next.
- for (int x = 0; x < w_in; x++, src++)
+ for (int y = 0; y < h_in; y++, src += src_pad+src_pad, dst += src_pad+src_pad) // eol of this, and bol of next.
+ for (int x = 0; x < w_in; x++, src++, dst++)
{
__m256i input_avx2_src = _mm256_i32gather_epi32(&src[0], ind_intel_0, sizeof(int32_t));
__m256i hidden_avx2[n_hidden/8];
@@ -123,8 +124,11 @@ void SYN_NAME(int KS,
// horizontal sum
int32_t sum = kb[ol] + hsum_8x32(out_avx2);
- sum >>= SYN_MUL_PRECISION; // take multiplied sum to output. // !!! check sign?
- src[ol*plane_stride_in] = sum;
+ if (sum < 0)
+ sum = -(-sum >> SYN_MUL_PRECISION);
+ else
+ sum >>= SYN_MUL_PRECISION;
+ dst[ol*plane_stride_in] = sum;
}
} // x, y
}
diff --git a/coolchic/cpp/synfused_cpu.hpp b/coolchic/cpp/synfused_cpu.hpp
index 2c90932d..849b0a8b 100644
--- a/coolchic/cpp/synfused_cpu.hpp
+++ b/coolchic/cpp/synfused_cpu.hpp
@@ -13,12 +13,12 @@
// N_IN (7)
// N_HIDDEN (16 or 40)
// N_OUT (3)
-// HIDDEN always RELU, OUT always NONE
+// HIDDEN always RELU
void SYN_NAME(int KS,
int32_t *kw7_40, int32_t *kb40, int32_t *kw40_3, int32_t *kb3,
int h_in, int w_in, int stride_in, int plane_stride_in, int N_IN, int N_HIDDEN, int32_t *in, int N_OUT, int32_t *out)
{
- printf("%s(ks=%d N_IN=%d N_HIDDEN=%d N_OUT=%d\n", xstr(SYN_NAME), KS, N_IN, N_HIDDEN, N_OUT);
+ //printf("%s(ks=%d N_IN=%d N_HIDDEN=%d N_OUT=%d\n", xstr(SYN_NAME), KS, N_IN, N_HIDDEN, N_OUT);
#ifdef SYN_KS
int const ks = SYN_KS;
@@ -40,7 +40,6 @@ void SYN_NAME(int KS,
#else
int const n_out = N_OUT;
#endif
- int const pad_in = (stride_in-w_in)/2;
if (KS != ks)
{
@@ -48,11 +47,12 @@ void SYN_NAME(int KS,
exit(1);
}
int32_t *src = in;
+ int32_t *dst = out;
int32_t elements_7[n_in]; // in
int32_t elements_40[n_hidden]; // hidden
- for (int y = 0; y < h_in; y++, src += pad_in+pad_in) // pads are: eol of this, and bol of next.
- for (int x = 0; x < w_in; x++, src++)
+ for (int y = 0; y < h_in; y++, src += stride_in-w_in, dst += stride_in-w_in)
+ for (int x = 0; x < w_in; x++, src++, dst++)
{
int32_t *inputs = &elements_7[0];
int32_t *hidden = &elements_40[0];
@@ -98,8 +98,12 @@ void SYN_NAME(int KS,
for (int il = 0; il < n_hidden; il++)
sum += hidden[il]*kw[il];
// no relu
- sum >>= SYN_MUL_PRECISION; // take multiplied sum to output. // !!! check sign?
- src[ol*plane_stride_in] = sum;
+ // take multiplied sum to output.
+ if (sum < 0)
+ sum = -(-sum >> SYN_MUL_PRECISION);
+ else
+ sum >>= SYN_MUL_PRECISION;
+ dst[ol*plane_stride_in] = sum;
}
} // x, y
}
diff --git a/coolchic/cpp/synlb_avx2.hpp b/coolchic/cpp/synlb_avx2.hpp
index 8423de88..ac3be67f 100644
--- a/coolchic/cpp/synlb_avx2.hpp
+++ b/coolchic/cpp/synlb_avx2.hpp
@@ -55,13 +55,17 @@ void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_i
#endif
#endif
- int32_t *in_layer[std::max(n_in, n_out)];
+ int32_t *in_layer[n_in];
in_layer[0] = in;
- for (int i = 1; i < std::max(n_in, n_out); i++)
+ for (int i = 1; i < n_in; i++)
in_layer[i] = in_layer[i-1]+plane_stride_in;
+ int32_t *out_layer[n_out];
+ out_layer[0] = out;
+ for (int i = 1; i < n_out; i++)
+ out_layer[i] = out_layer[i-1]+plane_stride_in;
+
// in-place, must have line buffer.
- int32_t **out_layer;
int32_t *lb[2]; // two line buffer pointers.
int h_out = h_in-ks+1;
int w_out = w_in-ks+1;
@@ -70,23 +74,14 @@ void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_i
printf("%s: bad call: in-place lb must have ks=3\n", xstr(SYN_NAME));
exit(1);
}
- if (out == NULL || out == in)
- {
- // must have a line buffer.
- if (line_buffer == NULL)
- {
- printf("%s: bad call, no line buffer supplied\n", xstr(SYN_NAME));
- exit(1);
- }
- out_layer = in_layer;
- lb[0] = line_buffer;
- lb[1] = line_buffer+w_out*n_out;
- }
- else
+ // must have a line buffer.
+ if (line_buffer == NULL)
{
- printf("%s: bad call should have lb and in-place\n", xstr(SYN_NAME));
+ printf("%s: bad call, no line buffer supplied\n", xstr(SYN_NAME));
exit(1);
}
+ lb[0] = line_buffer;
+ lb[1] = line_buffer+w_out*n_out;
const __m256i z = _mm256_setzero_si256();
@@ -170,32 +165,43 @@ void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_i
}
}
}
+ int n_outs = (x_blk < xlim_blk-1) ? 8 : xlast_blk_size;
+ int32_t store[8];
if (relu)
{
+ // -ves to zero, >> for rest.
for (int ol = 0; ol < atatime; ol++)
{
out_avx2[ol] = _mm256_blendv_epi8(z, out_avx2[ol], _mm256_cmpgt_epi32(out_avx2[ol], z));
- }
- }
- if (x_blk < xlim_blk-1)
- {
- for (int ol = 0; ol < atatime; ol++)
- {
out_avx2[ol] = _mm256_srai_epi32(out_avx2[ol], SYN_MUL_PRECISION);
- //_mm256_storeu_si256((__m256i_u*)&out_layer[olbase+ol][offso], out_avx2[ol]);
- _mm256_storeu_si256((__m256i_u*)&lb[y%2][(olbase+ol)*w_out+x_blk*8], out_avx2[ol]);
+ if (n_outs == 8)
+ {
+ _mm256_storeu_si256((__m256i_u*)&lb[y%2][(olbase+ol)*w_out+x_blk*8], out_avx2[ol]);
+ }
+ else
+ {
+ _mm256_storeu_si256((__m256i_u*)&store[0], out_avx2[ol]);
+ memcpy(&lb[y%2][(olbase+ol)*w_out+x_blk*8], &store[0], n_outs*sizeof(store[0]));
+ }
}
}
else
{
- int n_outs = xlast_blk_size;
- // partial last line.
+ // need different treatment for -ve and +ve.
for (int ol = 0; ol < atatime; ol++)
{
- int32_t store[8];
- out_avx2[ol] = _mm256_srai_epi32(out_avx2[ol], SYN_MUL_PRECISION);
- _mm256_storeu_si256((__m256i_u*)&store[0], out_avx2[ol]);
- memcpy(&lb[y%2][(olbase+ol)*w_out+x_blk*8], &store[0], n_outs*sizeof(store[0]));
+ __m256i sr = _mm256_srai_epi32(out_avx2[ol], SYN_MUL_PRECISION);
+ __m256i negsrneg = _mm256_sub_epi32(z, _mm256_srai_epi32(_mm256_sub_epi32(z, out_avx2[ol]), SYN_MUL_PRECISION));
+ out_avx2[ol] = _mm256_blendv_epi8(negsrneg, sr, _mm256_cmpgt_epi32(out_avx2[ol], z));
+ if (n_outs == 8)
+ {
+ _mm256_storeu_si256((__m256i_u*)&lb[y%2][(olbase+ol)*w_out+x_blk*8], out_avx2[ol]);
+ }
+ else
+ {
+ _mm256_storeu_si256((__m256i_u*)&store[0], out_avx2[ol]);
+ memcpy(&lb[y%2][(olbase+ol)*w_out+x_blk*8], &store[0], n_outs*sizeof(store[0]));
+ }
}
}
} // olbase
@@ -205,12 +211,12 @@ void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_i
if (y >= 1)
{
for (int ol = 0; ol < n_out; ol++)
- memcpy(&out_layer[ol][offsi_base-stride_in+residue_origin_offset], &lb[(y-1)%2][ol*w_out], w_out*sizeof(int32_t));
+ memcpy(&out_layer[ol][offsi_base-stride_in], &lb[(y-1)%2][ol*w_out], w_out*sizeof(int32_t));
}
} // y
// flush final line.
for (int ol = 0; ol < n_out; ol++)
- memcpy(&out_layer[ol][offsi_base-stride_in+residue_origin_offset], &lb[(h_out-1)%2][ol*w_out], w_out*sizeof(int32_t));
+ memcpy(&out_layer[ol][offsi_base-stride_in], &lb[(h_out-1)%2][ol*w_out], w_out*sizeof(int32_t));
}
#undef tostr
diff --git a/coolchic/cpp/synlb_cpu.hpp b/coolchic/cpp/synlb_cpu.hpp
index 7fb53794..4092808e 100644
--- a/coolchic/cpp/synlb_cpu.hpp
+++ b/coolchic/cpp/synlb_cpu.hpp
@@ -21,7 +21,7 @@
// dedicated to 3x3 kernels that use a line-buffer for temporary storage, allowing in-place convolution.
void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_in, int plane_stride_in, int residue_origin_offset, int N_IN, int32_t *in, int N_OUT, int32_t *out, int32_t *line_buffer, int residue, int relu)
{
- printf("%s(ks=%d N_IN=%d N_OUT=%d, residue=%d relu=%d\n", xstr(SYN_NAME), KS, N_IN, N_OUT, residue, relu);
+ //printf("%s(ks=%d N_IN=%d N_OUT=%d, residue=%d relu=%d\n", xstr(SYN_NAME), KS, N_IN, N_OUT, residue, relu);
int const kstride = 1;
#ifdef SYN_KS
@@ -40,12 +40,17 @@ void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_i
int const n_out = N_OUT;
#endif
- int32_t *in_layer[std::max(n_in, n_out)];
+ int32_t *in_layer[n_in];
in_layer[0] = in;
- for (int i = 1; i < std::max(n_in, n_out); i++)
+ for (int i = 1; i < n_in; i++)
in_layer[i] = in_layer[i-1]+plane_stride_in;
+
+ int32_t *out_layer[n_out];
+ out_layer[0] = out;
+ for (int i = 1; i < n_out; i++)
+ out_layer[i] = out_layer[i-1]+plane_stride_in;
+
// in-place, must have line buffer.
- int32_t **out_layer;
int32_t *lb[2]; // two line buffer pointers.
int h_out = h_in-ks+1;
int w_out = w_in-ks+1;
@@ -54,23 +59,14 @@ void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_i
printf("%s: bad call: in-place lb must have ks=3\n", xstr(SYN_NAME));
exit(1);
}
- if (out == NULL || out == in)
- {
- // must have a line buffer.
- if (line_buffer == NULL)
- {
- printf("%s: bad call, no line buffer supplied\n", xstr(SYN_NAME));
- exit(1);
- }
- out_layer = in_layer;
- lb[0] = line_buffer;
- lb[1] = line_buffer+w_out*n_out;
- }
- else
+ // must have a line buffer.
+ if (line_buffer == NULL)
{
- printf("%s: bad call should have lb and in-place\n", xstr(SYN_NAME));
+ printf("%s: bad call, no line buffer supplied\n", xstr(SYN_NAME));
exit(1);
}
+ lb[0] = line_buffer;
+ lb[1] = line_buffer+w_out*n_out;
// here we collect the output during processing, and flush later.
@@ -102,9 +98,16 @@ void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_i
sum += xxres;
}
}
- sum >>= SYN_MUL_PRECISION; // take multiplied sum to output. // !!! check sign?
- if (relu && sum < 0)
- sum = 0;
+ // scale the accumulated sum back to output precision; ReLU clamps negatives to 0, otherwise negatives round toward zero.
+ if (sum < 0)
+ {
+ if (relu)
+ sum = 0;
+ else
+ sum = -(-sum >> SYN_MUL_PRECISION);
+ }
+ else
+ sum >>= SYN_MUL_PRECISION;
lb[y%2][ol*w_out+x] = sum;
}
}
@@ -112,12 +115,12 @@ void SYN_NAME(int KS, int32_t *kw, int32_t *kb, int h_in, int w_in, int stride_i
if (y >= 1)
{
for (int ol = 0; ol < n_out; ol++)
- memcpy(&out_layer[ol][offs0-stride_in+residue_origin_offset], &lb[(y-1)%2][ol*w_out], w_out*sizeof(int32_t));
+ memcpy(&out_layer[ol][offs0-stride_in], &lb[(y-1)%2][ol*w_out], w_out*sizeof(int32_t));
}
}
// flush final line.
for (int ol = 0; ol < n_out; ol++)
- memcpy(&out_layer[ol][offs0-stride_in+residue_origin_offset], &lb[(h_out-1)%2][ol*w_out], w_out*sizeof(int32_t));
+ memcpy(&out_layer[ol][offs0-stride_in], &lb[(h_out-1)%2][ol*w_out], w_out*sizeof(int32_t));
}
#undef tostr
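The rounding change in synlb_cpu.hpp above replaces a plain arithmetic right shift with a sign-aware one. A minimal Python sketch of the difference follows. The function names and the precision of 2 bits are hypothetical, chosen only for illustration; Python's `>>` floors negative integers, which matches an arithmetic right shift on two's-complement targets.

```python
def shift_floor(value: int, precision: int) -> int:
    """Plain arithmetic right shift: rounds toward negative infinity."""
    return value >> precision


def shift_toward_zero(value: int, precision: int) -> int:
    """Shift the magnitude, then restore the sign: rounds toward zero."""
    if value < 0:
        return -((-value) >> precision)
    return value >> precision


if __name__ == "__main__":
    # Compare both roundings on a few values; they differ only for
    # negative inputs that are not exact multiples of 2**precision.
    for v in (-5, -4, -1, 0, 1, 5):
        print(v, shift_floor(v, 2), shift_toward_zero(v, 2))
```

With the sign-aware variant, `x` and `-x` produce results of equal magnitude, which is presumably why the diff applies it to every fixed-point rescale.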
diff --git a/coolchic/cpp/ups_avx2.cpp b/coolchic/cpp/ups_avx2.cpp
index 9f98b2c5..5646300f 100644
--- a/coolchic/cpp/ups_avx2.cpp
+++ b/coolchic/cpp/ups_avx2.cpp
@@ -1,35 +1,25 @@
-/*
- Software Name: Cool-Chic
- SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange
- SPDX-License-Identifier: BSD 3-Clause "New"
- This software is distributed under the BSD-3-Clause license.
- Authors: see CONTRIBUTORS.md
-*/
-
-
-#include
-#include
-#include
-#include
#include "common.h"
+#include "frame-memory.h"
+#include "ups_cpu.h"
+#include
+
+#define KS 7
+#define UPSNAME ups_refine_ks7_avx2
+#include "ups_refine_avx2.hpp"
-#define SYN_NAME ups_4x4x4_fromups_avx2
-#define SYN_KS 4
-#define UPS_MUL_PRECISION UPS_PRECISION
-#include "ups_avx2.hpp"
+#define UPSNAME ups_refine_ksX_avx2
+#include "ups_refine_avx2.hpp"
-#define SYN_NAME ups_4x2x2_fromups_avx2
-#define SYN_KS 2
-#define UPS_MUL_PRECISION UPS_PRECISION
-#include "ups_avx2.hpp"
+#define KS 8
+#define UPS_SRC_PRECISION ARM_PRECISION
+#define UPSNAME ups_upsample_ks8_ARMPREC_avx2
+#include "ups_upsample_avx2.hpp"
-#define SYN_NAME ups_4x4x4_fromarm_avx2
-#define SYN_KS 4
-#define UPS_MUL_PRECISION ARM_PRECISION
-#include "ups_avx2.hpp"
+#define KS 8
+#define UPS_SRC_PRECISION UPS_PRECISION
+#define UPSNAME ups_upsample_ks8_UPSPREC_avx2
+#include "ups_upsample_avx2.hpp"
-#define SYN_NAME ups_4x2x2_fromarm_avx2
-#define SYN_KS 2
-#define UPS_MUL_PRECISION ARM_PRECISION
-#include "ups_avx2.hpp"
+#define UPSNAME ups_upsample_ksX_avx2
+#include "ups_upsample_avx2.hpp"
diff --git a/coolchic/cpp/ups_avx2.h b/coolchic/cpp/ups_avx2.h
index 87d29593..6d9a656e 100644
--- a/coolchic/cpp/ups_avx2.h
+++ b/coolchic/cpp/ups_avx2.h
@@ -8,7 +8,8 @@
*/
-void ups_4x4x4_fromarm_avx2(int KS, int32_t *kw, int h_in, int w_in, int stride_in, int32_t *in, int stride_out, int32_t *out);
-void ups_4x4x4_fromups_avx2(int KS, int32_t *kw, int h_in, int w_in, int stride_in, int32_t *in, int stride_out, int32_t *out);
-void ups_4x2x2_fromarm_avx2(int KS, int32_t *kw, int h_in, int w_in, int stride_int, int32_t *in, int stride_out, int32_t *out);
-void ups_4x2x2_fromups_avx2(int KS, int32_t *kw, int h_in, int w_in, int stride_int, int32_t *in, int stride_out, int32_t *out);
+void ups_refine_ks7_avx2(int ks, int32_t *kw, frame_memory &in, frame_memory &out, int ups_src_precision, frame_memory &tmp);
+void ups_refine_ksX_avx2(int ks, int32_t *kw, frame_memory &in, frame_memory &out, int ups_src_precision, frame_memory &tmp);
+void ups_upsample_ks8_UPSPREC_avx2(int ksx2, int32_t *kw, frame_memory &in, frame_memory &out, int out_plane, int ups_src_precision, frame_memory &tmp);
+void ups_upsample_ks8_ARMPREC_avx2(int ksx2, int32_t *kw, frame_memory &in, frame_memory &out, int out_plane, int ups_src_precision, frame_memory &tmp);
+void ups_upsample_ksX_avx2(int ksx2, int32_t *kw, frame_memory &in, frame_memory &out, int out_plane, int ups_src_precision, frame_memory &tmp);
diff --git a/coolchic/cpp/ups_avx2.hpp b/coolchic/cpp/ups_avx2.hpp
deleted file mode 100644
index 057f3404..00000000
--- a/coolchic/cpp/ups_avx2.hpp
+++ /dev/null
@@ -1,142 +0,0 @@
-/*
- Software Name: Cool-Chic
- SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange
- SPDX-License-Identifier: BSD 3-Clause "New"
-
- This software is distributed under the BSD-3-Clause license.
- Authors: see CONTRIBUTORS.md
-*/
-
-
-#define tostr(x) #x
-#define xstr(x) tostr(x)
-
-// kw: weights for 4 4x4 kernels.
-void SYN_NAME(int KS, int32_t *kw, int h_in, int w_in, int stride_in, int32_t *in, int stride_out, int32_t *out)
-{
- //printf("%s(ks=%d N_IN=%d N_OUT=%d, residue=%d, relu=%d\n", xstr(SYN_NAME), KS, N_IN, N_OUT, residue, relu);
- int const kstride = 1;
-#ifdef SYN_KS
- int const ks = SYN_KS;
-#else
- int const ks = KS;
-#endif
- int const n_out = 4; // number of kernels, actually, always 4.
-
- if (ks != KS)
- {
- printf("%s: ks=%d: bad call\n", xstr(SYN_NAME), KS);
- exit(1);
- }
-
- int const ks2 = ks*ks;
-
- // we generate 8 outputs at a time, per kernel.
- // these outputs are emitted onto two output lines.
- // we advance 8 pixels in the input after that.
- int offsi_base = 0;
- int offso_base = 0;
- // the last block in x might not be full.
-
- int xlim_blk = ((w_in-ks+1)+7)/8; // pixels in number of horizontal blocks
- int xlast_blk_size = 8-(xlim_blk*8 - (w_in-ks+1)); // number of pixels to emit per filter in last block.
- for (int y = 0; y < h_in-ks+1; y += kstride, offsi_base += stride_in, offso_base += 2*stride_out)
- {
- int offsi_line = offsi_base;
- int offso = offso_base;
- for (int x_blk = 0; x_blk < xlim_blk; x_blk += kstride, offsi_line += kstride*8)
- {
- __m256i out_avx2[n_out];
- for (int ol = 0; ol < n_out; ol++)
- {
- out_avx2[ol] = _mm256_setzero_si256();
- }
- int offsi_block = offsi_line;
- int koffs = 0;
- for (int yy = 0; yy < ks; yy++, offsi_block += stride_in-ks)
- {
- for (int xx = 0; xx < ks; xx++, offsi_block++, koffs++)
- {
- __m256i input = _mm256_loadu_si256((__m256i_u*)&in[offsi_block]);
- for (int ol = 0; ol < n_out; ol++)
- {
- __m256i kk = _mm256_set1_epi32(kw[ol*ks2+koffs]);
- __m256i mul = _mm256_mullo_epi32(input, kk);
- out_avx2[ol] = _mm256_add_epi32(out_avx2[ol], mul);
- }
- }
- }
- // merge the outputs into the destination!
- // We have:
- // ol0 A0 A1 A2 A3.. A7
- // ol1 B0 B1 B2 B3.. B7
- // ol2 C0 C1 C2 C3.. C7
- // ol3 D0 D1 D2 D3.. D7
- //
- // to go to 2 lines as:
- // A0 B0 A1 B1 A2 B2.. A7 B7
- // C0 D0 C1 D1 C2 D2.. C7 D7
-
- out_avx2[0] = _mm256_srai_epi32(out_avx2[0], UPS_MUL_PRECISION);
- out_avx2[1] = _mm256_srai_epi32(out_avx2[1], UPS_MUL_PRECISION);
- out_avx2[2] = _mm256_srai_epi32(out_avx2[2], UPS_MUL_PRECISION);
- out_avx2[3] = _mm256_srai_epi32(out_avx2[3], UPS_MUL_PRECISION);
-
- // spread horizontally and over two lines.
- //float outputs[4][8];
- __m256i outA0 = _mm256_unpacklo_epi32(out_avx2[0], out_avx2[1]);
- __m256i outA1 = _mm256_unpackhi_epi32(out_avx2[0], out_avx2[1]);
- __m256i outC0 = _mm256_unpacklo_epi32(out_avx2[2], out_avx2[3]);
- __m256i outC1 = _mm256_unpackhi_epi32(out_avx2[2], out_avx2[3]);
-
- // We need to remangle A0,A1 and C0,C1, switching their high and low 128-bit halves.
- __m256i outA00 = _mm256_permute2f128_si256(outA0, outA1, 0x20);
- __m256i outA01 = _mm256_permute2f128_si256(outA0, outA1, 0x31);
- __m256i outC00 = _mm256_permute2f128_si256(outC0, outC1, 0x20);
- __m256i outC01 = _mm256_permute2f128_si256(outC0, outC1, 0x31);
- if (x_blk < xlim_blk-1)
- {
- _mm256_storeu_si256((__m256i_u*)&out[offso+0], outA00);
- _mm256_storeu_si256((__m256i_u*)&out[offso+8], outA01);
- _mm256_storeu_si256((__m256i_u*)&out[offso+stride_out+0], outC00);
- _mm256_storeu_si256((__m256i_u*)&out[offso+stride_out+8], outC01);
- offso += 2*8;
- }
- else
- {
- int n_outs = 2*xlast_blk_size; // note -- doubled.
- int32_t tmp[8];
- for (int line = 0; line < 2; line++)
- {
- if (n_outs >= 8)
- {
- _mm256_storeu_si256((__m256i_u*)&out[offso+line*stride_out+0], outA00);
- if (n_outs >= 16)
- _mm256_storeu_si256((__m256i_u*)&out[offso+line*stride_out+8], outA01);
- else if (n_outs > 8)
- {
- // remainder over 8.
- _mm256_storeu_si256((__m256i_u*)&tmp[0], outA01);
- memcpy(&out[offso+line*stride_out+8], &tmp[0], (n_outs-8)*sizeof(tmp[0]));
- }
- }
- else
- {
- // remainder over 0.
- _mm256_storeu_si256((__m256i_u*)&tmp[0], outA00);
- memcpy(&out[offso+line*stride_out+0], &tmp[0], n_outs*sizeof(tmp[0]));
- }
- outA00 = outC00;
- outA01 = outC01;
- }
- offso += n_outs;
- }
- }
- }
-}
-
-#undef tostr
-#undef xstr
-#undef SYN_NAME
-#undef SYN_KS
-#undef UPS_MUL_PRECISION
diff --git a/coolchic/cpp/ups_cpu.cpp b/coolchic/cpp/ups_cpu.cpp
new file mode 100644
index 00000000..df012827
--- /dev/null
+++ b/coolchic/cpp/ups_cpu.cpp
@@ -0,0 +1,18 @@
+
+#include "common.h"
+#include "frame-memory.h"
+#include "ups_cpu.h"
+
+#define KS 7
+#define UPSNAME ups_refine_ks7_cpu
+#include "ups_refine_cpu.hpp"
+
+#define UPSNAME ups_refine_ksX_cpu
+#include "ups_refine_cpu.hpp"
+
+#define KS 8
+#define UPSNAME ups_upsample_ks8_cpu
+#include "ups_upsample_cpu.hpp"
+
+#define UPSNAME ups_upsample_ksX_cpu
+#include "ups_upsample_cpu.hpp"
diff --git a/coolchic/cpp/ups_cpu.h b/coolchic/cpp/ups_cpu.h
new file mode 100644
index 00000000..dcbc6dc0
--- /dev/null
+++ b/coolchic/cpp/ups_cpu.h
@@ -0,0 +1,7 @@
+
+
+void ups_refine_ks7_cpu(int ks, int32_t *kw, frame_memory &in, frame_memory &out, int ups_src_precision, frame_memory &tmp);
+void ups_refine_ksX_cpu(int ks, int32_t *kw, frame_memory &in, frame_memory &out, int ups_src_precision, frame_memory &tmp);
+
+void ups_upsample_ks8_cpu(int ksx2, int32_t *kw, frame_memory &in, frame_memory &out, int out_plane, int ups_src_precision, frame_memory &tmp);
+void ups_upsample_ksX_cpu(int ksx2, int32_t *kw, frame_memory &in, frame_memory &out, int out_plane, int ups_src_precision, frame_memory &tmp);
diff --git a/coolchic/cpp/ups_refine_avx2.hpp b/coolchic/cpp/ups_refine_avx2.hpp
new file mode 100644
index 00000000..ed75da04
--- /dev/null
+++ b/coolchic/cpp/ups_refine_avx2.hpp
@@ -0,0 +1,135 @@
+
+// defines expected:
+// KS
+// UPSNAME
+
+#define tostr(x) #x
+#define xstr(x) tostr(x)
+
+// ups_src_precision is always ARM_PRECISION (from arm src)
+void UPSNAME(int ks_param, int32_t *kw, frame_memory &in, frame_memory &out, int ups_src_precision, frame_memory &tmp)
+{
+#ifdef KS
+ int const ks = KS;
+ if (ks != ks_param)
+ {
+ printf("%s: bad call, ks_param %d\n", xstr(UPSNAME), ks_param);
+ exit(1);
+ }
+
+ // kernel in registers
+ __m256i KW[KS];
+ for (int i = 0; i < KS; i++)
+ KW[i] = _mm256_set1_epi32(kw[i]);
+#else
+ int const ks = ks_param;
+#endif
+ // temporary check.
+ if (ups_src_precision != ARM_PRECISION)
+ {
+ printf("%s: bad call: input precision is %d, not %d\n", xstr(UPSNAME), ups_src_precision, ARM_PRECISION);
+ exit(1);
+ }
+ int const UPS_SRC_PRECISION = ARM_PRECISION;
+ int kshalf = ks/2;
+ int pad = kshalf;
+ // prepare output.
+ out.update_to(in.h, in.w, out.pad, 1);
+ // pad input.
+ in.custom_pad_zero_plane_in_place_i(0, pad, 0); // LR pad.
+
+ // prepare temporary to hold h x w
+ tmp.update_to(in.h, in.w, tmp.pad, tmp.planes);
+ int32_t *src = in.pad_origin(0, pad, 0);
+ int32_t *dst = tmp.origin();
+ const __m256i z = _mm256_setzero_si256();
+
+ // h (horizontal treatment) to temporary.
+ int xlim_blk = (in.w+7)/8; // pixels in number of horizontal blocks
+ int xlast_blk_size = 8-(xlim_blk*8 - in.w); // number of pixels to emit per filter in last block.
+ int32_t store[8];
+
+ for (int y = 0; y < in.h; y++, src += in.stride-xlim_blk*8, dst += tmp.stride-xlim_blk*8)
+ {
+ for (int x_blk = 0; x_blk < xlim_blk; x_blk++, src += 8, dst += 8)
+ {
+ int n_outs = (x_blk < xlim_blk-1) ? 8 : xlast_blk_size;
+ __m256i out_avx2 = _mm256_setzero_si256();
+
+ for (int xx = 0; xx < ks; xx++)
+ {
+ __m256i input = _mm256_loadu_si256((__m256i_u*)&src[xx]);
+#ifdef KS
+ __m256i mul = _mm256_mullo_epi32(input, KW[xx]);
+#else
+ __m256i kk = _mm256_set1_epi32(kw[xx]);
+ __m256i mul = _mm256_mullo_epi32(input, kk);
+#endif
+ out_avx2 = _mm256_add_epi32(out_avx2, mul);
+ }
+ // round toward zero: negative values use -((-x) >> n), positive values use x >> n.
+ __m256i sr = _mm256_srai_epi32(out_avx2, ARM_PRECISION);
+ __m256i negsrneg = _mm256_sub_epi32(z, _mm256_srai_epi32(_mm256_sub_epi32(z, out_avx2), ARM_PRECISION));
+ out_avx2 = _mm256_blendv_epi8(negsrneg, sr, _mm256_cmpgt_epi32(out_avx2, z));
+
+ if (n_outs == 8)
+ {
+ _mm256_storeu_si256((__m256i_u*)dst, out_avx2);
+ }
+ else
+ {
+ _mm256_storeu_si256((__m256i_u*)&store[0], out_avx2);
+ memcpy(dst, &store[0], n_outs*sizeof(store[0]));
+ }
+ }
+ }
+
+ // v (vertical treatment) to output.
+ tmp.custom_pad_zero_plane_in_place_i(0, 0, pad); // TB pad.
+ int32_t *hsrc = tmp.pad_origin(0, 0, pad);
+ src = in.origin();
+ dst = out.origin();
+ int residue_shift = UPS_PRECISION-UPS_SRC_PRECISION; // src is arm output at lower precision.
+
+ for (int y = 0; y < in.h; y++, hsrc += tmp.stride-xlim_blk*8, src += in.stride-xlim_blk*8, dst += out.stride-xlim_blk*8)
+ {
+ for (int x_blk = 0; x_blk < xlim_blk; x_blk++, hsrc += 8, src += 8, dst += 8)
+ {
+ int n_outs = (x_blk < xlim_blk-1) ? 8 : xlast_blk_size;
+ __m256i out_avx2 = _mm256_loadu_si256((__m256i_u*)&src[0]);
+ out_avx2 = _mm256_slli_epi32(out_avx2, residue_shift+UPS_PRECISION);
+
+ for (int yy = 0; yy < ks; yy++)
+ {
+ __m256i input = _mm256_loadu_si256((__m256i_u*)&hsrc[yy*tmp.stride]);
+#ifdef KS
+ __m256i mul = _mm256_mullo_epi32(input, KW[yy]);
+#else
+ __m256i kk = _mm256_set1_epi32(kw[yy]);
+ __m256i mul = _mm256_mullo_epi32(input, kk);
+#endif
+ out_avx2 = _mm256_add_epi32(out_avx2, mul);
+ }
+
+ // round toward zero: negative values use -((-x) >> n), positive values use x >> n.
+ __m256i sr = _mm256_srai_epi32(out_avx2, UPS_PRECISION);
+ __m256i negsrneg = _mm256_sub_epi32(z, _mm256_srai_epi32(_mm256_sub_epi32(z, out_avx2), UPS_PRECISION));
+ out_avx2 = _mm256_blendv_epi8(negsrneg, sr, _mm256_cmpgt_epi32(out_avx2, z));
+
+ if (n_outs == 8)
+ {
+ _mm256_storeu_si256((__m256i_u*)dst, out_avx2);
+ }
+ else
+ {
+ _mm256_storeu_si256((__m256i_u*)&store[0], out_avx2);
+ memcpy(dst, &store[0], n_outs*sizeof(store[0]));
+ }
+ }
+ }
+}
+
+#undef KS
+#undef UPSNAME
+#undef tostr
+#undef xstr
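The AVX2 loops above walk each row in blocks of 8 lanes and shorten only the store of the last block. A small Python sketch of that block arithmetic, mirroring `xlim_blk` and `xlast_blk_size`; the helper name `block_sizes` is hypothetical and the lane count of 8 corresponds to eight 32-bit values in a 256-bit register.

```python
from typing import List


def block_sizes(width: int, lanes: int = 8) -> List[int]:
    """Return the number of valid outputs for each horizontal block."""
    # Number of blocks needed to cover the row (ceiling division).
    xlim_blk = (width + lanes - 1) // lanes
    # Valid outputs in the last block; equals `lanes` when width is a multiple.
    xlast_blk_size = lanes - (xlim_blk * lanes - width)
    return [lanes if blk < xlim_blk - 1 else xlast_blk_size
            for blk in range(xlim_blk)]


if __name__ == "__main__":
    for w in (1, 8, 20):
        print(w, block_sizes(w))
```

Note that the C++ code still loads a full 8 lanes in the last block; only the write goes through the small `store` buffer and `memcpy`, so the source buffers need enough padding for the over-read.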
diff --git a/coolchic/cpp/ups_refine_cpu.hpp b/coolchic/cpp/ups_refine_cpu.hpp
new file mode 100644
index 00000000..47c0d4b8
--- /dev/null
+++ b/coolchic/cpp/ups_refine_cpu.hpp
@@ -0,0 +1,84 @@
+
+// defines expected:
+// KS
+// UPSNAME
+
+#define tostr(x) #x
+#define xstr(x) tostr(x)
+
+// ups_src_precision is always ARM_PRECISION (from arm src)
+// tmp is guaranteed to hold a horizontally filtered in (same h x w; refinement does not upsample).
+void UPSNAME(int ks_param, int32_t *kw, frame_memory &in, frame_memory &out, int ups_src_precision, frame_memory &tmp)
+{
+#ifdef KS
+ int const ks = KS;
+ if (ks != ks_param)
+ {
+ printf("%s: bad call, ks_param %d\n", xstr(UPSNAME), ks_param);
+ exit(1);
+ }
+#else
+ int const ks = ks_param;
+#endif
+ int kshalf = ks/2;
+ int pad = kshalf;
+ // prepare output.
+ out.update_to(in.h, in.w, out.pad, 1);
+ // pad input.
+ in.custom_pad_zero_plane_in_place_i(0, pad, 0); // LR pad.
+
+ // prepare temporary to hold h x w
+ tmp.update_to(in.h, in.w, tmp.pad, tmp.planes);
+ int32_t *src = in.pad_origin(0, pad, 0);
+ int32_t *dst = tmp.origin();
+
+ // h (horizontal treatment) to temporary.
+ for (int y = 0; y < in.h; y++, src += in.stride-in.w, dst += tmp.stride-in.w)
+ {
+ for (int x = 0; x < in.w; x++, src++, dst++)
+ {
+#if 1
+ // ignore symmetric.
+ int sum = 0;
+ for (int xx = 0; xx < ks; xx++)
+ sum += src[xx]*kw[xx];
+#endif
+#if 0
+ // use symmetric.
+ int sum = src[kshalf]*kw[kshalf];
+ for (int xx = 0; xx < kshalf; xx++)
+ sum += (src[xx]+src[ks-xx-1])*kw[xx];
+#endif
+ if (sum < 0)
+ *dst = -(-sum >> ups_src_precision);
+ else
+ *dst = sum >> ups_src_precision;
+ }
+ }
+
+ // v (vertical treatment) to output.
+ tmp.custom_pad_zero_plane_in_place_i(0, 0, pad); // TB pad.
+ int32_t *hsrc = tmp.pad_origin(0, 0, pad);
+ src = in.origin();
+ dst = out.origin();
+ int residue_shift = UPS_PRECISION-ups_src_precision; // src is arm output at lower precision.
+ for (int y = 0; y < in.h; y++, hsrc += tmp.stride-in.w, dst += out.stride-in.w, src += in.stride-in.w)
+ {
+ for (int x = 0; x < in.w; x++, hsrc++, dst++, src++)
+ {
+ int sum = 0;
+ for (int yy = 0; yy < ks; yy++)
+ sum += hsrc[yy*tmp.stride]*kw[yy];
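+ sum += (*src << (residue_shift+UPS_PRECISION)); // residual: align src with the accumulator precision.
+ if (sum < 0)
+ *dst = -(-sum >> UPS_PRECISION);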
+ else
+ *dst = sum >> UPS_PRECISION;
+ }
+ }
+}
+
+#undef KS
+#undef UPSNAME
+#undef tostr
+#undef xstr
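The refinement in ups_refine_cpu.hpp above is a separable filter: a horizontal pass into a temporary, then a vertical pass that adds the input back as a residual. The following Python sketch restates that flow on nested lists. It is an illustration under assumptions, not the decoder's implementation: `refine_plane` is a hypothetical name, and the precisions 8 and 12 used in the usage lines are placeholders, since the actual values of `ARM_PRECISION` and `UPS_PRECISION` are not shown in this diff.

```python
from typing import List


def shift_toward_zero(value: int, precision: int) -> int:
    """Fixed-point rescale that rounds toward zero for both signs."""
    if value < 0:
        return -((-value) >> precision)
    return value >> precision


def refine_plane(plane: List[List[int]], kw: List[int],
                 src_precision: int, ups_precision: int) -> List[List[int]]:
    ks = len(kw)
    pad = ks // 2
    height, width = len(plane), len(plane[0])

    # Horizontal pass over a zero-padded row, result stored in a temporary.
    tmp = []
    for row in plane:
        padded = [0] * pad + row + [0] * pad
        tmp.append([
            shift_toward_zero(
                sum(padded[x + i] * kw[i] for i in range(ks)), src_precision)
            for x in range(width)
        ])

    # Vertical pass over the zero-padded temporary, plus the input as residual.
    zero_row = [0] * width
    padded_rows = [zero_row] * pad + tmp + [zero_row] * pad
    residue_shift = ups_precision - src_precision
    out = []
    for y in range(height):
        out_row = []
        for x in range(width):
            # Bring the residual to the same precision as the accumulator.
            acc = plane[y][x] << (residue_shift + ups_precision)
            acc += sum(padded_rows[y + i][x] * kw[i] for i in range(ks))
            out_row.append(shift_toward_zero(acc, ups_precision))
        out.append(out_row)
    return out


if __name__ == "__main__":
    # An all-zero kernel leaves only the residual, rescaled to the
    # higher (illustrative) precision.
    print(refine_plane([[256, -256]], [0, 0, 0], 8, 12))
```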
diff --git a/coolchic/cpp/ups_upsample_avx2.hpp b/coolchic/cpp/ups_upsample_avx2.hpp
new file mode 100644
index 00000000..3171b8df
--- /dev/null
+++ b/coolchic/cpp/ups_upsample_avx2.hpp
@@ -0,0 +1,211 @@
+
+// defines expected:
+// KS
+// optional UPS_SRC_PRECISION
+// UPSNAME
+
+
+#define tostr(x) #x
+#define xstr(x) tostr(x)
+
+// ups_src_precision is either ARM_PRECISION (from arm src, unrefined lowest layer) or UPS_PRECISION (from ups refinement src)
+// incoming kernel is actually two filters, interleaved, to produce two outputs.
+// tmp is guaranteed to hold a horizontally upsampled in.
+void UPSNAME(int ksx2_param, int32_t *kw, frame_memory &in, frame_memory &out, int out_plane, int ups_src_precision, frame_memory &tmp)
+{
+#ifdef KS
+ int const ksx2 = KS;
+ if (ksx2 != ksx2_param)
+ {
+ printf("%s: bad call, ksx2_param %d\n", xstr(UPSNAME), ksx2_param);
+ exit(1);
+ }
+#else
+ int const ksx2 = ksx2_param;
+#endif
+ int const ks = ksx2/2; // 2 kernels of size ks.
+ int pad = ks/2; // incoming ks is two even-sized filters.
+ in.custom_pad_replicate_plane_in_place_i(0, pad, 0); // LR pad.
+
+#ifdef UPS_SRC_PRECISION
+ if (ups_src_precision != UPS_SRC_PRECISION)
+ {
+ printf("%s: bad call: input precision is %d, not %d\n", xstr(UPSNAME), ups_src_precision, UPS_SRC_PRECISION);
+ exit(1);
+ }
+#endif
+
+#ifdef KS
+ // kernel in registers
+ __m256i KW_EVEN[KS];
+ __m256i KW_ODD[KS];
+ for (int i = 0; i < KS; i++)
+ {
+ KW_EVEN[i] = _mm256_set1_epi32(kw[i*2]);
+ KW_ODD[i] = _mm256_set1_epi32(kw[i*2+1]);
+ }
+#else
+ int32_t kw_even[ks];
+ int32_t kw_odd[ks];
+ for (int i = 0; i < ks; i++)
+ {
+ kw_even[i] = kw[i*2];
+ kw_odd[i] = kw[i*2+1];
+ }
+#endif
+
+ // prepare temporary to hold h x 2*w
+ tmp.update_to(in.h, 2*in.w, tmp.pad, tmp.planes);
+ int32_t *src = in.pad_origin(0, pad, 0);
+ int32_t *dst = tmp.origin();
+ const __m256i z = _mm256_setzero_si256();
+
+ // h (horizontal scale) to temporary.
+ int xlim_blk = (in.w+7)/8; // pixels in number of horizontal blocks
+ int xlast_blk_size = 8-(xlim_blk*8 - in.w); // number of pixels to emit per filter in last block.
+ int32_t store[8];
+
+ for (int y = 0; y < in.h; y++, src += in.stride-xlim_blk*8, dst += tmp.stride-xlim_blk*8*2)
+ {
+ for (int x_blk = 0; x_blk < xlim_blk; x_blk++, src += 8, dst += 8*2)
+ {
+ int n_outs = (x_blk < xlim_blk-1) ? 8 : xlast_blk_size;
+ __m256i out_even_avx2 = _mm256_setzero_si256();
+ __m256i out_odd_avx2 = _mm256_setzero_si256();
+ for (int xx = 0; xx < ks; xx++)
+ {
+ __m256i input_even = _mm256_loadu_si256((__m256i_u*)&src[xx]);
+ __m256i input_odd = _mm256_loadu_si256((__m256i_u*)&src[xx+1]); // !!! strange we need a +1 for odd.
+#ifdef KS
+ __m256i mul = _mm256_mullo_epi32(input_even, KW_EVEN[xx]);
+ out_even_avx2 = _mm256_add_epi32(out_even_avx2, mul);
+ mul = _mm256_mullo_epi32(input_odd, KW_ODD[xx]);
+ out_odd_avx2 = _mm256_add_epi32(out_odd_avx2, mul);
+#else
+ __m256i kk = _mm256_set1_epi32(kw_even[xx]);
+ __m256i mul = _mm256_mullo_epi32(input_even, kk);
+ out_even_avx2 = _mm256_add_epi32(out_even_avx2, mul);
+ kk = _mm256_set1_epi32(kw_odd[xx]);
+ mul = _mm256_mullo_epi32(input_odd, kk);
+ out_odd_avx2 = _mm256_add_epi32(out_odd_avx2, mul);
+#endif
+ }
+ // round toward zero: negative values use -((-x) >> n), positive values use x >> n.
+#ifdef UPS_SRC_PRECISION
+ __m256i sr = _mm256_srai_epi32(out_even_avx2, UPS_SRC_PRECISION);
+ __m256i negsrneg = _mm256_sub_epi32(z, _mm256_srai_epi32(_mm256_sub_epi32(z, out_even_avx2), UPS_SRC_PRECISION));
+ out_even_avx2 = _mm256_blendv_epi8(negsrneg, sr, _mm256_cmpgt_epi32(out_even_avx2, z));
+
+ sr = _mm256_srai_epi32(out_odd_avx2, UPS_SRC_PRECISION);
+ negsrneg = _mm256_sub_epi32(z, _mm256_srai_epi32(_mm256_sub_epi32(z, out_odd_avx2), UPS_SRC_PRECISION));
+ out_odd_avx2 = _mm256_blendv_epi8(negsrneg, sr, _mm256_cmpgt_epi32(out_odd_avx2, z));
+#else
+ __m256i sr = _mm256_srai_epi32(out_even_avx2, ups_src_precision);
+ __m256i negsrneg = _mm256_sub_epi32(z, _mm256_srai_epi32(_mm256_sub_epi32(z, out_even_avx2), ups_src_precision));
+ out_even_avx2 = _mm256_blendv_epi8(negsrneg, sr, _mm256_cmpgt_epi32(out_even_avx2, z));
+
+ sr = _mm256_srai_epi32(out_odd_avx2, ups_src_precision);
+ negsrneg = _mm256_sub_epi32(z, _mm256_srai_epi32(_mm256_sub_epi32(z, out_odd_avx2), ups_src_precision));
+ out_odd_avx2 = _mm256_blendv_epi8(negsrneg, sr, _mm256_cmpgt_epi32(out_odd_avx2, z));
+#endif
+
+ // spread horizontally
+ __m256i outA0 = _mm256_unpacklo_epi32(out_even_avx2, out_odd_avx2);
+ __m256i outA1 = _mm256_unpackhi_epi32(out_even_avx2, out_odd_avx2);
+ // We need to remangle A0,A1 switching their high and low 128-bit halves.
+ out_even_avx2 = _mm256_permute2f128_si256(outA0, outA1, 0x20);
+ out_odd_avx2 = _mm256_permute2f128_si256(outA0, outA1, 0x31);
+ if (n_outs == 8)
+ {
+ _mm256_storeu_si256((__m256i_u*)(dst+0), out_even_avx2);
+ _mm256_storeu_si256((__m256i_u*)(dst+8), out_odd_avx2);
+ }
+ else
+ {
+ n_outs *= 2;
+ if (n_outs >= 8)
+ {
+ _mm256_storeu_si256((__m256i_u*)(dst+0), out_even_avx2);
+ if (n_outs >= 16)
+ _mm256_storeu_si256((__m256i_u*)(dst+8), out_odd_avx2);
+ else if (n_outs > 8)
+ {
+ // remainder over 8.
+ _mm256_storeu_si256((__m256i_u*)&store[0], out_odd_avx2);
+ memcpy(&dst[8], &store[0], (n_outs-8)*sizeof(store[0]));
+ }
+ }
+ else
+ {
+ // remainder over 0.
+ _mm256_storeu_si256((__m256i_u*)&store[0], out_even_avx2);
+ memcpy(&dst[0], &store[0], n_outs*sizeof(store[0]));
+ }
+ }
+ }
+ }
+
+ // v (vertical) to output.
+ tmp.custom_pad_replicate_plane_in_place_i(0, 0, pad); // TB pad.
+ int32_t *hsrc = tmp.pad_origin(0, 0, pad);
+ dst = out.plane_origin(out_plane);
+
+ xlim_blk = (tmp.w+7)/8; // pixels in number of horizontal blocks
+ xlast_blk_size = 8-(xlim_blk*8 - tmp.w); // number of pixels to emit per filter in last block.
+
+ for (int y = 0; y < out.h; y +=2, hsrc += tmp.stride-xlim_blk*8, dst += out.stride-xlim_blk*8+out.stride)
+ {
+ for (int x_blk = 0; x_blk < xlim_blk; x_blk++, hsrc += 8, dst += 8)
+ {
+ int n_outs = (x_blk < xlim_blk-1) ? 8 : xlast_blk_size;
+ __m256i out_even_avx2 = _mm256_setzero_si256();
+ __m256i out_odd_avx2 = _mm256_setzero_si256();
+ for (int yy = 0; yy < ks; yy++)
+ {
+ __m256i input_even = _mm256_loadu_si256((__m256i_u*)&hsrc[yy*tmp.stride]);
+ __m256i input_odd = _mm256_loadu_si256((__m256i_u*)&hsrc[(yy+1)*tmp.stride]); // !!! strange we need a +1 for odd.
+#ifdef KS
+ __m256i mul = _mm256_mullo_epi32(input_even, KW_EVEN[yy]);
+ out_even_avx2 = _mm256_add_epi32(out_even_avx2, mul);
+ mul = _mm256_mullo_epi32(input_odd, KW_ODD[yy]);
+ out_odd_avx2 = _mm256_add_epi32(out_odd_avx2, mul);
+#else
+ __m256i kk = _mm256_set1_epi32(kw_even[yy]);
+ __m256i mul = _mm256_mullo_epi32(input_even, kk);
+ out_even_avx2 = _mm256_add_epi32(out_even_avx2, mul);
+ kk = _mm256_set1_epi32(kw_odd[yy]);
+ mul = _mm256_mullo_epi32(input_odd, kk);
+ out_odd_avx2 = _mm256_add_epi32(out_odd_avx2, mul);
+#endif
+ }
+ // round toward zero: negative values use -((-x) >> n), positive values use x >> n.
+ __m256i sr = _mm256_srai_epi32(out_even_avx2, UPS_PRECISION);
+ __m256i negsrneg = _mm256_sub_epi32(z, _mm256_srai_epi32(_mm256_sub_epi32(z, out_even_avx2), UPS_PRECISION));
+ out_even_avx2 = _mm256_blendv_epi8(negsrneg, sr, _mm256_cmpgt_epi32(out_even_avx2, z));
+
+ sr = _mm256_srai_epi32(out_odd_avx2, UPS_PRECISION);
+ negsrneg = _mm256_sub_epi32(z, _mm256_srai_epi32(_mm256_sub_epi32(z, out_odd_avx2), UPS_PRECISION));
+ out_odd_avx2 = _mm256_blendv_epi8(negsrneg, sr, _mm256_cmpgt_epi32(out_odd_avx2, z));
+
+ if (n_outs == 8)
+ {
+ _mm256_storeu_si256((__m256i_u*)(dst+0*out.stride), out_even_avx2);
+ _mm256_storeu_si256((__m256i_u*)(dst+1*out.stride), out_odd_avx2);
+ }
+ else
+ {
+ _mm256_storeu_si256((__m256i_u*)&store[0], out_even_avx2);
+ memcpy(&dst[0], &store[0], n_outs*sizeof(store[0]));
+ _mm256_storeu_si256((__m256i_u*)&store[0], out_odd_avx2);
+ memcpy(&dst[out.stride], &store[0], n_outs*sizeof(store[0]));
+ }
+ }
+ }
+}
+
+
+#undef KS
+#undef UPS_SRC_PRECISION
+#undef UPSNAME
+#undef tostr
+#undef xstr
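The horizontal pass in ups_upsample_avx2.hpp above interleaves the even and odd outputs with `_mm256_unpacklo_epi32`, `_mm256_unpackhi_epi32` and `_mm256_permute2f128_si256`. The Python sketch below simulates those three intrinsics on 8-element lists to show why the extra permute is needed: the unpack instructions work per 128-bit lane. The Python function names are stand-ins, and the permute model ignores the zeroing bits of the immediate, which the diff does not use.

```python
from typing import List


def unpacklo_epi32(a: List[int], b: List[int]) -> List[int]:
    """Interleave the low two elements of each 128-bit lane."""
    return [a[0], b[0], a[1], b[1], a[4], b[4], a[5], b[5]]


def unpackhi_epi32(a: List[int], b: List[int]) -> List[int]:
    """Interleave the high two elements of each 128-bit lane."""
    return [a[2], b[2], a[3], b[3], a[6], b[6], a[7], b[7]]


def permute2f128(a: List[int], b: List[int], imm: int) -> List[int]:
    """Pick one 128-bit half for each output half (zeroing bits ignored)."""
    halves = [a[:4], a[4:], b[:4], b[4:]]
    return halves[imm & 0x3] + halves[(imm >> 4) & 0x3]


def interleave(even: List[int], odd: List[int]) -> List[int]:
    """Produce e0 o0 e1 o1 ... e7 o7 from two 8-lane vectors."""
    lo = unpacklo_epi32(even, odd)
    hi = unpackhi_epi32(even, odd)
    # 0x20 joins the low halves, 0x31 joins the high halves.
    return permute2f128(lo, hi, 0x20) + permute2f128(lo, hi, 0x31)


if __name__ == "__main__":
    print(interleave(list(range(0, 16, 2)), list(range(1, 16, 2))))
```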
diff --git a/coolchic/cpp/ups_upsample_cpu.hpp b/coolchic/cpp/ups_upsample_cpu.hpp
new file mode 100644
index 00000000..bb404604
--- /dev/null
+++ b/coolchic/cpp/ups_upsample_cpu.hpp
@@ -0,0 +1,97 @@
+
+// defines expected:
+// KS
+// UPSNAME
+
+#define tostr(x) #x
+#define xstr(x) tostr(x)
+
+// ups_src_precision is either ARM_PRECISION (from arm src, unrefined lowest layer) or UPS_PRECISION (from ups refinement src)
+// incoming kernel is actually two filters, interleaved, to produce two outputs.
+// tmp is guaranteed to hold a horizontally upsampled in.
+void UPSNAME(int ksx2_param, int32_t *kw, frame_memory &in, frame_memory &out, int out_plane, int ups_src_precision, frame_memory &tmp)
+{
+#ifdef KS
+ int const ksx2 = KS;
+ if (ksx2 != ksx2_param)
+ {
+ printf("%s: bad call, ksx2_param %d\n", xstr(UPSNAME), ksx2_param);
+ exit(1);
+ }
+#else
+ int const ksx2 = ksx2_param;
+#endif
+ int const ks = ksx2/2; // 2 kernels of size ks.
+ int pad = ks/2; // incoming ks is two even-sized filters.
+ in.custom_pad_replicate_plane_in_place_i(0, pad, 0); // LR pad.
+
+ int32_t kw_even[ks];
+ int32_t kw_odd[ks];
+ for (int i = 0; i < ks; i++)
+ {
+ kw_even[i] = kw[i*2];
+ kw_odd[i] = kw[i*2+1];
+ }
+
+ // prepare temporary to hold h x 2*w
+ tmp.update_to(in.h, 2*in.w, tmp.pad, tmp.planes);
+ int32_t *src = in.pad_origin(0, pad, 0);
+ int32_t *dst = tmp.origin();
+
+ // h (horizontal scale) to temporary.
+ for (int y = 0; y < in.h; y++, src += in.stride-in.w, dst += tmp.stride-2*in.w)
+ {
+ for (int x = 0; x < in.w; x++)
+ {
+ // even
+ int sum_even = 0;
+ int sum_odd = 0;
+ for (int xx = 0; xx < ks; xx++)
+ {
+ sum_even += src[xx]*kw_even[xx];
+ sum_odd += src[xx+1]*kw_odd[xx]; // !!! strange we need a +1 for odd.
+ }
+ if (sum_even < 0)
+ *dst++ = -(-sum_even >> ups_src_precision);
+ else
+ *dst++ = sum_even >> ups_src_precision;
+ if (sum_odd < 0)
+ *dst++ = -(-sum_odd >> ups_src_precision);
+ else
+ *dst++ = sum_odd >> ups_src_precision;
+ src++;
+ }
+ }
+
+ // v (vertical) to output.
+ tmp.custom_pad_replicate_plane_in_place_i(0, 0, pad); // TB pad.
+ src = tmp.pad_origin(0, 0, pad);
+ dst = out.plane_origin(out_plane);
+ for (int y = 0; y < out.h; y += 2, src += tmp.stride-out.w, dst += out.stride-out.w+out.stride)
+ {
+ for (int x = 0; x < out.w; x++, src++, dst++)
+ {
+ int sum_even = 0;
+ int sum_odd = 0;
+ for (int yy = 0; yy < ks; yy++)
+ {
+ sum_even += src[yy*tmp.stride]*kw_even[yy];
+ sum_odd += src[(yy+1)*tmp.stride]*kw_odd[yy]; // !!! strange we need a +1 for odd.
+ }
+ if (sum_even < 0)
+ dst[0] = -(-sum_even >> UPS_PRECISION);
+ else
+ dst[0] = sum_even >> UPS_PRECISION;
+ if (sum_odd < 0)
+ dst[0+out.stride] = -(-sum_odd >> UPS_PRECISION);
+ else
+ dst[0+out.stride] = sum_odd >> UPS_PRECISION;
+ }
+ }
+}
+
+
+#undef KS
+#undef UPSNAME
+#undef tostr
+#undef xstr
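The upsampler in ups_upsample_cpu.hpp above receives one kernel that holds two interleaved filters and emits two outputs per input sample. A one-dimensional Python sketch of the horizontal pass is given below; the vertical pass applies the same idea along columns. The name `upsample_row`, the 4-bit precision and the example kernel are assumptions for illustration only.

```python
from typing import List


def shift_toward_zero(value: int, precision: int) -> int:
    """Fixed-point rescale that rounds toward zero for both signs."""
    if value < 0:
        return -((-value) >> precision)
    return value >> precision


def upsample_row(row: List[int], kw: List[int], precision: int) -> List[int]:
    ks = len(kw) // 2              # two filters of ks taps each
    assert ks % 2 == 0, "the padding below assumes even-sized filters"
    pad = ks // 2
    kw_even = kw[0::2]             # taps producing even output samples
    kw_odd = kw[1::2]              # taps producing odd output samples
    # Replicate the border samples on the left and right.
    padded = [row[0]] * pad + row + [row[-1]] * pad
    out = []
    for x in range(len(row)):
        sum_even = sum(padded[x + i] * kw_even[i] for i in range(ks))
        # The odd filter reads one sample further right, as in the C++ code.
        sum_odd = sum(padded[x + i + 1] * kw_odd[i] for i in range(ks))
        out.append(shift_toward_zero(sum_even, precision))
        out.append(shift_toward_zero(sum_odd, precision))
    return out


if __name__ == "__main__":
    # Kernel whose even and odd filters both select the centre sample,
    # so every input value is simply repeated twice.
    nearest = [0, 0, 0, 16, 16, 0, 0, 0]
    print(upsample_row([1, 2, 3], nearest, 4))
```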
diff --git a/coolchic/dec/nn.py b/coolchic/dec/nn.py
index 1f38f370..2930eee4 100644
--- a/coolchic/dec/nn.py
+++ b/coolchic/dec/nn.py
@@ -7,11 +7,10 @@
# Authors: see CONTRIBUTORS.md
-import numpy as np
import torch
import torch.nn as nn
-from enc.utils.misc import FIXED_POINT_FRACTIONAL_MULT, DescriptorNN
+from enc.utils.misc import DescriptorNN
from CCLIB.ccencapi import cc_decode_wb
@@ -40,7 +39,7 @@ def decode_network(
Returns:
nn.Module: The decoded module
"""
- have_bias = q_step_nn.bias > 0
+ have_bias = bitstream_path.bias != ""
# Instantiate two range coder objects to decode simultaneously weight and bias
bac_ctx_weight = cc_decode_wb(bitstream_path.weight)
@@ -49,12 +48,10 @@ def decode_network(
loaded_param = {}
for k, v in empty_module.named_parameters():
- if k.endswith('.w') or k.endswith('.weight'):
- cur_scale = scale_nn.weight
+ if "weight" in k:
cur_q_step = q_step_nn.weight
cur_param = bac_ctx_weight.decode_wb_continue(len(v.flatten()), scale_nn.weight)
- elif k.endswith('.b') or k.endswith('.bias'):
- cur_scale = scale_nn.bias
+ elif "bias" in k and have_bias:
cur_q_step = q_step_nn.bias
cur_param = bac_ctx_bias.decode_wb_continue(len(v.flatten()), scale_nn.bias)
else:
@@ -68,5 +65,5 @@ def decode_network(
if "arm" in bitstream_path.weight:
empty_module.set_param_from_float(loaded_param)
else:
- empty_module.load_state_dict(loaded_param)
+ empty_module.load_state_dict(loaded_param, strict = have_bias)
return empty_module
diff --git a/coolchic/decode.py b/coolchic/decode.py
index 1e489504..c5fcbd2f 100644
--- a/coolchic/decode.py
+++ b/coolchic/decode.py
@@ -18,9 +18,39 @@
if __name__ == "__main__":
# =========================== Parse arguments =========================== #
parser = argparse.ArgumentParser()
- parser.add_argument( "--input", "-i", type=str, default="./bitstream.cool", help="Bitstream path.")
- parser.add_argument( "--output", "-o", default="", help="output ppm (rgb) or yuv")
- parser.add_argument( "--no_avx2", action='store_true', help="Disable AVX2 support")
+ parser.add_argument(
+ "--input", "-i", type=str, default="./bitstream.cool", help="Bitstream path."
+ )
+ parser.add_argument("--output", "-o", default="", help="output ppm (rgb) or yuv")
+ parser.add_argument("--no_avx2", action="store_true", help="Disable AVX2 usage")
+ parser.add_argument(
+ "--verbosity", type=int, default=0,
+ help=""
+ "0 does not output anything; "
+ "1 prints the runtime of each step; "
+ "2 is for debug."
+ )
+ parser.add_argument(
+ "--output_chroma_format",
+ type=int,
+ default=0,
+ help=
+ "Use 0 to infer this from the bitstream header. "
+ "Otherwise, specify '420' or '444' to change the chroma sampling for the "
+ "YUV output. "
+ "Ignored for RGB output."
+ )
+ parser.add_argument(
+ "--output_bitdepth",
+ type=int,
+ default=0,
+ help=
+ "Use 0 to infer this from the bitstream header. "
+ "Otherwise, specify an integer in [8, 16] to set the output bitdepth."
+ )
args = parser.parse_args()
# =========================== Parse arguments =========================== #
@@ -38,7 +68,21 @@
if use_avx2:
from CCLIB.ccdecapi_avx2 import cc_decode_avx2
- print("Using AVX2 instructions for faster decoding")
- cc_decode_avx2(args.input, args.output)
+
+ if args.verbosity >= 2:
+ print("Using AVX2 instructions for faster decoding")
+ cc_decode_avx2(
+ args.input,
+ args.output,
+ args.output_bitdepth,
+ args.output_chroma_format,
+ args.verbosity,
+ )
else:
- cc_decode_cpu(args.input, args.output)
\ No newline at end of file
+ cc_decode_cpu(
+ args.input,
+ args.output,
+ args.output_bitdepth,
+ args.output_chroma_format,
+ args.verbosity,
+ )
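The two new output flags use 0 to mean "infer from the bitstream header". The sketch below shows one way their values could be validated before calling the decoder. It assumes a hypothetical helper, `check_output_args`, which is not part of this diff; the accepted values are taken from the help strings above.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Only the two new flags are reproduced here, with the same types and defaults.
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_chroma_format", type=int, default=0)
    parser.add_argument("--output_bitdepth", type=int, default=0)
    return parser


def check_output_args(args: argparse.Namespace) -> None:
    """Hypothetical helper (not in the diff): reject values the help text excludes."""
    if args.output_chroma_format not in (0, 420, 444):
        raise ValueError(
            f"output_chroma_format must be 0, 420 or 444. Found {args.output_chroma_format}."
        )
    if args.output_bitdepth != 0 and not 8 <= args.output_bitdepth <= 16:
        raise ValueError(
            f"output_bitdepth must be 0 or in [8, 16]. Found {args.output_bitdepth}."
        )


args = build_parser().parse_args(["--output_bitdepth", "10", "--output_chroma_format", "444"])
check_output_args(args)  # passes silently for valid values
```

Whether the C decoder already performs this check is not visible in this diff, so the helper is only an illustration of the documented value ranges.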
diff --git a/coolchic/enc/bitstream/armint.py b/coolchic/enc/bitstream/armint.py
new file mode 100644
index 00000000..9767145d
--- /dev/null
+++ b/coolchic/enc/bitstream/armint.py
@@ -0,0 +1,277 @@
+# Software Name: Cool-Chic
+# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange
+# SPDX-License-Identifier: BSD 3-Clause "New"
+#
+# This software is distributed under the BSD-3-Clause license.
+#
+# Authors: see CONTRIBUTORS.md
+
+
+"""Fixed point implementation of the ARM to avoid floating point drift."""
+
+from typing import OrderedDict, Tuple
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+
+class ArmIntLinear(nn.Module):
+ """Create a Linear layer of the Auto-Regressive Module (ARM). This is a
+ wrapper around the usual ``nn.Linear`` layer of PyTorch, with a custom
+ initialization. It performs the following operations:
+
+ * :math:`\\mathbf{x}_{out} = \\mathbf{W}\\mathbf{x}_{in} + \\mathbf{b}` if
+ ``residual`` is ``False``
+
+ * :math:`\\mathbf{x}_{out} = \\mathbf{W}\\mathbf{x}_{in} + \\mathbf{b} +
+ \\mathbf{x}_{in}` if ``residual`` is ``True``.
+
+ The input :math:`\\mathbf{x}_{in}` is a :math:`[B, C_{in}]` tensor, the
+ output :math:`\\mathbf{x}_{out}` is a :math:`[B, C_{out}]` tensor.
+
+ The layer weight and bias shapes are :math:`\\mathbf{W} \\in
+ \\mathbb{R}^{C_{out} \\times C_{in}}` and :math:`\\mathbf{b} \\in
+ \\mathbb{R}^{C_{out}}`.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ fpfm: int = 0,
+ pure_int: bool = False,
+ residual: bool = False,
+ ):
+ """
+ Args:
+ in_channels: Number of input features :math:`C_{in}`.
+ out_channels: Number of output features :math:`C_{out}`.
+ fpfm: Fixed-point multiplier for integer computation. **No need to
+ modify this**. Defaults to 0.
+ residual: True to add a residual connection to the layer. Defaults to
+ False.
+ """
+
+ super().__init__()
+
+ self.fpfm = fpfm
+ self.pure_int = pure_int
+ self.residual = residual
+
+ # -------- Instantiate empty parameters, set by a later load
+ if self.pure_int:
+ self.weight = nn.Parameter(
+ torch.empty((out_channels, in_channels), dtype=torch.int32),
+ requires_grad=False,
+ )
+ self.bias = nn.Parameter(
+ torch.empty((out_channels), dtype=torch.int32), requires_grad=False
+ )
+ else:
+ self.weight = nn.Parameter(
+ torch.empty((out_channels, in_channels), dtype=torch.float),
+ requires_grad=False,
+ )
+ self.bias = nn.Parameter(
+ torch.empty((out_channels), dtype=torch.float), requires_grad=False
+ )
+ # -------- Instantiate empty parameters, set by a later load
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Perform the forward pass of this layer.
+
+ Args:
+ x: Input tensor of shape :math:`[B, C_{in}]`.
+
+ Returns:
+ Tensor with shape :math:`[B, C_{out}]`.
+ """
+ if self.residual:
+ xx = F.linear(x, self.weight, bias=self.bias) + x * self.fpfm
+ else:
+ xx = F.linear(x, self.weight, bias=self.bias)
+
+ # Renorm by fpfm after our (x*fpfm)*(qw*fpfm) multiplication.
+ # WE MAKE INTEGER DIVISION OBEY C++ (TO-ZERO) SEMANTICS, NOT PYTHON (TO-NEGATIVE-INFINITY) SEMANTICS
+ if self.pure_int:
+ xx = xx + torch.sign(xx) * self.fpfm // 2
+ # We separate out -ve and non-ve.
+ neg_result = -((-xx) // self.fpfm)
+ pos_result = xx // self.fpfm
+ result = torch.where(xx < 0, neg_result, pos_result)
+ else:
+ xx = xx + torch.sign(xx) * self.fpfm / 2
+ # We separate out -ve and non-ve.
+ neg_result = -((-xx) / self.fpfm)
+ pos_result = xx / self.fpfm
+ result = torch.where(xx < 0, neg_result, pos_result)
+ result = result.to(torch.int32).to(torch.float)
+
+ return result
+
+
+class ArmInt(nn.Module):
+ """Instantiate an autoregressive probability module, modelling the
+ conditional distribution :math:`p_{\\psi}(\\hat{y}_i \\mid
+ \\mathbf{c}_i)` of a (quantized) latent pixel :math:`\\hat{y}_i`,
+ conditioned on neighboring already decoded context pixels
+ :math:`\\mathbf{c}_i \\in \\mathbb{Z}^C`, where :math:`C` denotes the
+ number of context pixels.
+
+ The distribution :math:`p_{\\psi}` is assumed to follow a Laplace
+ distribution, parameterized by an expectation :math:`\\mu` and a scale
+ :math:`b`, where the scale and the variance :math:`\\sigma^2` are
+ related as follows :math:`\\sigma^2 = 2 b ^2`.
+
+ The parameters of the Laplace distribution for a given latent pixel
+ :math:`\\hat{y}_i` are obtained by passing its context pixels
+ :math:`\\mathbf{c}_i` through an MLP :math:`f_{\\psi}`:
+
+ .. math::
+
+ p_{\\psi}(\\hat{y}_i \\mid \\mathbf{c}_i) \\sim \\mathcal{L}(\\mu_i,
+ b_i), \\text{ where } \\mu_i, b_i = f_{\\psi}(\\mathbf{c}_i).
+
+ .. attention::
+
+ The MLP :math:`f_{\\psi}` has a few constraints on its architecture:
+
+ * The width of all hidden layers (i.e. the output of all layers except
+ the final one) is identical to the number of context pixels
+ :math:`C`;
+
+ * All layers except the last one are residual layers, followed by a
+ ``ReLU`` non-linearity;
+
+ * :math:`C` must be a multiple of 8.
+
+ The MLP :math:`f_{\\psi}` is made of custom Linear layers instantiated
+ from the ``ArmIntLinear`` class.
+ """
+
+ def __init__(
+ self, dim_arm: int, n_hidden_layers_arm: int, fpfm: int, pure_int: bool
+ ):
+ """
+ Args:
+ dim_arm: Number of context pixels AND dimension of all hidden
+ layers :math:`C`.
+ n_hidden_layers_arm: Number of hidden layers. Set it to 0 for
+ a linear ARM.
+ """
+ super().__init__()
+
+ assert dim_arm % 8 == 0, (
+ f"ARM context size and hidden layer dimension must be "
+ f"a multiple of 8. Found {dim_arm}."
+ )
+
+ self.FPFM = fpfm # fixed-point: multiplication to get int.
+ self.pure_int = pure_int # weights and biases are actual int (cpu only), or just int values in floats (gpu friendly).
+
+ # ======================== Construct the MLP ======================== #
+ layers_list = nn.ModuleList()
+
+ # Construct the hidden layer(s)
+ for i in range(n_hidden_layers_arm):
+ layers_list.append(
+ ArmIntLinear(dim_arm, dim_arm, self.FPFM, self.pure_int, residual=True)
+ )
+ layers_list.append(nn.ReLU())
+
+ # Construct the output layer. It always has 2 outputs (mu and scale)
+ layers_list.append(
+ ArmIntLinear(dim_arm, 2, self.FPFM, self.pure_int, residual=False)
+ )
+ self.mlp = nn.Sequential(*layers_list)
+ # ======================== Construct the MLP ======================== #
+
+ def set_param_from_float(self, float_param: OrderedDict[str, Tensor]) -> None:
+ # Convert the floating point parameters to fixed-point integers.
+
+ # Weights are scaled by FPFM, biases by FPFM * FPFM, then rounded and stored.
+ integerised_param = {}
+ for k in float_param:
+ if "weight" in k:
+ float_v = float_param[k] * self.FPFM
+ else:
+ float_v = float_param[k] * self.FPFM * self.FPFM
+
+ float_v = float_v + torch.sign(float_v) * 0.5
+ neg_result = -(-float_v).to(torch.int32)
+ pos_result = float_v.to(torch.int32)
+ int_v = torch.where(float_v < 0, neg_result, pos_result)
+ if not self.pure_int:
+ int_v = int_v.to(torch.float)
+ integerised_param[k] = nn.parameter.Parameter(int_v, requires_grad=False)
+
+ self.load_state_dict(integerised_param, assign=True)
+
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+ """Perform the auto-regressive module (ARM) forward pass. The ARM takes
+ as input a tensor of shape :math:`[B, C]` i.e. :math:`B` contexts with
+ :math:`C` context pixels. The ARM outputs :math:`[B, 2]` values
+ corresponding to :math:`\\mu, b` for each of the :math:`B` input pixels.
+
+ .. warning::
+
+ Note that the ARM expects input to be flattened i.e. spatial
+ dimensions :math:`H, W` are collapsed into a single batch-like
+ dimension :math:`B = HW`, leading to an input of shape
+ :math:`[B, C]`, gathering the :math:`C` contexts for each of the
+ :math:`B` pixels to model.
+
+ .. note::
+
+ The ARM MLP does not output directly the scale :math:`b`. Denoting
+ :math:`s` the raw output of the MLP, the scale is obtained as
+ follows:
+
+ .. math::
+
+ b = e^{s - 4}
+
+ Args:
+ x: Concatenation of all input contexts
+ :math:`\\mathbf{c}_i`. Tensor of shape :math:`[B, C]`.
+
+ Returns:
+ Concatenation of all Laplace distribution parameters :math:`\\mu, b`.
+ Tensors of shape :math:`[B]`. Also return the *log scale*
+ :math:`s` as described above. Tensor of shape :math:`[B]`.
+ """
+ xint = x.clone().detach()
+ xint = xint * self.FPFM
+ if self.pure_int:
+ xint = xint.to(torch.int32)
+
+ for idx_l, layer in enumerate(self.mlp.children()):
+ xint = layer(xint)
+
+ # float the result.
+ raw_proba_param = xint / self.FPFM
+
+ mu = raw_proba_param[:, 0]
+ log_scale = raw_proba_param[:, 1]
+
+ # no scale smaller than exp(-4.6) ~ 1e-2 or bigger than exp(5.0) ~ 148
+ scale = torch.exp(torch.clamp(log_scale - 4, min=-4.6, max=5.0))
+
+ return mu, scale, log_scale
+
+ def get_param(self) -> OrderedDict[str, Tensor]:
+ """Return **a copy** of the weights and biases inside the module.
+
+ Returns:
+ A copy of all weights & biases in the layers.
+ """
+ # Detach & clone to create a copy
+ return OrderedDict({k: v.detach().clone() for k, v in self.named_parameters()})
+
+ def set_param(self, param: OrderedDict[str, Tensor]) -> None:
+ """Replace the current parameters of the module with param.
+
+ Args:
+ param: Parameters to be set.
+ """
+ self.load_state_dict(param)
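The fixed-point arithmetic in `ArmIntLinear.forward` and `ArmInt.set_param_from_float` can be illustrated without PyTorch. The sketch below is a pure-Python approximation of the same steps: weights are scaled by the multiplier, biases by its square, and the accumulator is renormalised with a division that rounds to nearest and truncates toward zero, as in C. The value `FPFM = 256` and all function names are assumptions made for this example; the multiplier is assumed to be even.

```python
from typing import List

FPFM = 1 << 8  # hypothetical fixed-point multiplier; the real value is set by the encoder


def round_half_away(v: float) -> int:
    """Round to nearest, ties away from zero (int() truncates toward zero)."""
    return int(v + 0.5) if v >= 0 else -int(-v + 0.5)


def div_round_to_zero(x: int, d: int) -> int:
    """x / d rounded to nearest, using truncation toward zero instead of floor."""
    if x >= 0:
        return (x + d // 2) // d
    return -((-x + d // 2) // d)


def fixed_point_linear(
    x_int: List[int], w_int: List[List[int]], b_int: List[int], residual: bool = False
) -> List[int]:
    """Integer linear layer: inputs and weights carry FPFM, biases carry FPFM ** 2."""
    out = []
    for j, (row, b) in enumerate(zip(w_int, b_int)):
        acc = sum(w * x for w, x in zip(row, x_int)) + b
        if residual:
            acc += x_int[j] * FPFM  # bring the residual input to the FPFM ** 2 scale
        out.append(div_round_to_zero(acc, FPFM))  # back to the FPFM scale
    return out


# Float layer: y = 0.5 * 1.0 - 0.25 * 2.0 + 0.125 = 0.125
w_int = [[round_half_away(w * FPFM) for w in (0.5, -0.25)]]
b_int = [round_half_away(0.125 * FPFM * FPFM)]
x_int = [round_half_away(x * FPFM) for x in (1.0, 2.0)]
y = fixed_point_linear(x_int, w_int, b_int)[0] / FPFM
```

The split between negative and non-negative values matters because Python's `//` rounds toward negative infinity: `-383 // 256` is `-2`, whereas `div_round_to_zero(-383, 256)` gives `-1`, which is what a C decoder using the same rounding offset would compute.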
diff --git a/coolchic/enc/bitstream/encode.py b/coolchic/enc/bitstream/encode.py
index 0ac8b1cd..01dc1511 100644
--- a/coolchic/enc/bitstream/encode.py
+++ b/coolchic/enc/bitstream/encode.py
@@ -6,7 +6,6 @@
#
# Authors: see CONTRIBUTORS.md
-import math
import os
import subprocess
import time
@@ -18,7 +17,7 @@
from enc.bitstream.utils import get_sub_bitstream_path
from enc.bitstream.header import write_frame_header, write_gop_header
from CCLIB.ccencapi import cc_code_latent_layer_bac, cc_code_wb_bac
-from enc.component.core.arm import Arm, ArmInt
+from enc.bitstream.armint import ArmInt
from enc.component.core.synthesis import Synthesis
from enc.component.core.upsampling import Upsampling
from enc.component.frame import FrameEncoder
@@ -54,7 +53,7 @@ def get_ac_max_val_nn(frame_encoder: FrameEncoder) -> int:
# Retrieve all the weights and biases for the ARM MLP
for k, v in module_to_encode.named_parameters():
- if k.endswith('.w') or k.endswith('.weight'):
+ if "weight" in k:
cur_possible_q_step = POSSIBLE_Q_STEP.get(cur_module_name).get("weight")
# Find the index of the closest quantization step in the list of
@@ -74,12 +73,20 @@ def get_ac_max_val_nn(frame_encoder: FrameEncoder) -> int:
torch.round((v/FIXED_POINT_FRACTIONAL_MULT) / cur_possible_q_step[cur_q_step_index]).flatten()
)
else:
+ # No longer relevant without the bi-branch synthesis!
+ # # # Blending -- we get the transformed weight, not the underlying sigmoid parameter.
+ # # # plus: only if >1 branch.
+ # # if cur_module_name == "synthesis" and k.endswith(".parametrizations.weight.original"):
+ # # if "branch_blender" in k and frame_encoder.coolchic_encoder_param.n_synth_branch == 1:
+ # # continue # Do not emit unused blender weight.
+ # # xformed_weights = getattr(module_to_encode, k.replace(".parametrizations.weight.original", "")).weight
+ # # v = xformed_weights
model_param_quant.append(
torch.round(v / cur_possible_q_step[cur_q_step_index]).flatten()
)
- elif k.endswith('.b') or k.endswith('.bias'):
+ elif "bias" in k:
# Find the index of the closest quantization step in the list of
# the possible quantization step.
cur_possible_q_step = POSSIBLE_Q_STEP.get(cur_module_name).get("bias")
@@ -158,9 +165,6 @@ def encode_video(video_encoder: VideoEncoder, bitstream_path: str, hls_sig_blksi
# ======================== GOP HEADER ======================== #
for idx_coding_order in range(video_encoder.coding_structure.get_number_of_frames()):
- frame = video_encoder.coding_structure.get_frame_from_coding_order(idx_coding_order)
- # assert frame.already_encoded, f'Frame {frame.display_order} has not been encoded yet!'
-
# Retrieve the frame encoder corresponding to the frame
frame_encoder, _ = video_encoder.all_frame_encoders.get(str(idx_coding_order))
@@ -208,6 +212,12 @@ def encode_frame(
frame_encoder.set_to_eval()
frame_encoder.to_device('cpu')
+ # upsampling has bias parameters, but we do not use them.
+ have_bias = { "arm": True,
+ "upsampling": False,
+ "synthesis": True,
+ }
+
subprocess.call(f'rm -f {bitstream_path}', shell=True)
# Load the references
@@ -219,7 +229,6 @@ def encode_frame(
# Move to pure-int Arm. Transfer the quantized weights from the fp Arm.
arm_fp_param = frame_encoder.coolchic_encoder.arm.get_param()
- print("recovered arm params", arm_fp_param.keys())
arm_int = ArmInt(
frame_encoder.coolchic_encoder.param.dim_arm,
frame_encoder.coolchic_encoder.param.n_hidden_layers_arm,
@@ -228,7 +237,6 @@ def encode_frame(
)
frame_encoder.coolchic_encoder.arm = arm_int
frame_encoder.coolchic_encoder.arm.set_param_from_float(arm_fp_param)
- print("set armint(pureint) params")
# ================= Encode the MLP into a bitstream file ================ #
ac_max_val_nn = get_ac_max_val_nn(frame_encoder)
@@ -254,7 +262,7 @@ def encode_frame(
Q_STEPS = POSSIBLE_Q_STEP.get(cur_module_name)
- if k.endswith(".w") or k.endswith(".weight"):
+ if "weight" in k:
# Find the index of the closest quantization step in the list of
# the possible quantization step.
cur_possible_q_step = POSSIBLE_Q_STEP.get(cur_module_name).get("weight")
@@ -271,7 +279,6 @@ def encode_frame(
# Quantize the weight with the actual quantization step and add it
# to the list of (quantized) weights
- # print(cur_module_name, k, v)
if cur_module_name == "arm":
# Our weights are stored as fixed point, we use shifts to get the integer values of quantized results.
# Our int vals are int(floatval << FPFBITS)
@@ -283,11 +290,19 @@ def encode_frame(
v = torch.where(v < 0, neg_v, pos_v)
weights.append(v.flatten())
else:
+ # No longer relevant without the bi-branch synth
+ # # # Blending -- we get the transformed weight, not the underlying sigmoid parameter.
+ # # # plus: only if >1 branch.
+ # # if cur_module_name == "synthesis" and k.endswith(".parametrizations.weight.original"):
+ # # if "branch_blender" in k and frame_encoder.coolchic_encoder_param.n_synth_branch == 1:
+ # # continue # Do not emit unused blender weight.
+ # # xformed_weights = getattr(module_to_encode, k.replace(".parametrizations.weight.original", "")).weight
+ # # v = xformed_weights
weights.append(
torch.round(v / cur_possible_q_step[cur_q_step_index]).flatten()
)
- elif k.endswith(".b") or k.endswith(".bias"):
+ elif "bias" in k and have_bias[cur_module_name]:
# Find the index of the closest quantization step in the list of
# the Q_STEPS quantization step.
cur_possible_q_step = POSSIBLE_Q_STEP.get(cur_module_name).get("bias")
@@ -321,14 +336,15 @@ def encode_frame(
# Gather them
weights = torch.cat(weights).flatten()
- have_bias = len(bias) != 0
- if have_bias:
+ if have_bias[cur_module_name]:
bias = torch.cat(bias).flatten()
+ else:
+ q_step_index_nn[cur_module_name]['bias'] = 0 # we actually send this in the header.
# ----------------- Actual entropy coding
# It happens on cpu
weights = weights.cpu()
- if have_bias:
+ if have_bias[cur_module_name]:
bias = bias.cpu()
cur_bitstream_path = f'{bitstream_path}_{cur_module_name}_weight'
@@ -346,7 +362,7 @@ def encode_frame(
n_bytes_nn[cur_module_name]['weight'] = os.path.getsize(cur_bitstream_path)
- if have_bias:
+ if have_bias[cur_module_name]:
cur_bitstream_path = f'{bitstream_path}_{cur_module_name}_bias'
# either code directly (normal), or search for best (backwards compatible).
@@ -362,6 +378,7 @@ def encode_frame(
n_bytes_nn[cur_module_name]['bias'] = os.path.getsize(cur_bitstream_path)
else:
+ scale_index_nn[cur_module_name]['bias'] = 0
n_bytes_nn[cur_module_name]['bias'] = 0
# ================= Encode the MLP into a bitstream file ================ #
@@ -386,18 +403,21 @@ def encode_frame(
)
elif module_name == 'upsampling':
empty_module = Upsampling(
- frame_encoder.coolchic_encoder.param.upsampling_kernel_size,
- frame_encoder.coolchic_encoder.param.static_upsampling_kernel
+ frame_encoder.coolchic_encoder.param.ups_k_size,
+ frame_encoder.coolchic_encoder.param.ups_preconcat_k_size,
+ # frame_encoder.coolchic_encoder.param.n_ups_kernel,
+ frame_encoder.coolchic_encoder.param.latent_n_grids - 1,
+ # frame_encoder.coolchic_encoder.param.n_ups_preconcat_kernel,
+ frame_encoder.coolchic_encoder.param.latent_n_grids - 1,
)
Q_STEPS = POSSIBLE_Q_STEP.get(module_name)
- have_bias = q_step_index_nn[module_name].get('bias') >= 0
loaded_module = decode_network(
empty_module,
DescriptorNN(
weight = f'{bitstream_path}_{module_name}_weight',
- bias = f'{bitstream_path}_{module_name}_bias' if have_bias else "",
+ bias = f'{bitstream_path}_{module_name}_bias' if have_bias[module_name] else "",
),
DescriptorNN (
weight=Q_STEPS["weight"][q_step_index_nn[module_name]["weight"]],
@@ -408,8 +428,7 @@ def encode_frame(
bias=(
scale_index_nn[module_name]["bias"]
)
- if have_bias
- else 0,
+ if have_bias[module_name] else 0,
),
ac_max_val_nn
)
@@ -456,7 +475,6 @@ def encode_frame(
for index_lat_feature in range(c_i):
y_this_ft = current_y[:, index_lat_feature, :, :].flatten().cpu()
mu_this_ft = current_mu[:, index_lat_feature, :, :].flatten().cpu()
- scale_this_ft = current_scale[:, index_lat_feature, :, :].flatten().cpu()
log_scale_this_ft = current_log_scale[:, index_lat_feature, :, :].flatten().cpu()
if y_this_ft.abs().max() == 0:
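The parameter selection introduced in this file now matches on substrings and consults a per-module `have_bias` table. The sketch below shows the effect of that rule on a list of parameter names. The helper `split_params` and the parameter names are hypothetical; only the dictionary contents and the two conditions are taken from the diff.

```python
from typing import Dict, List, Tuple

# Same content as the have_bias dictionary added to encode_frame().
HAVE_BIAS: Dict[str, bool] = {"arm": True, "upsampling": False, "synthesis": True}


def split_params(module_name: str, names: List[str]) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: sort parameter names the way encode_frame() does."""
    weights, biases = [], []
    for k in names:
        if "weight" in k:
            weights.append(k)
        elif "bias" in k and HAVE_BIAS[module_name]:
            biases.append(k)
    return weights, biases


# Illustrative names only; the real ones come from named_parameters().
arm_w, arm_b = split_params("arm", ["mlp.0.weight", "mlp.0.bias"])
ups_w, ups_b = split_params("upsampling", ["filters.0.weight", "filters.0.bias"])
```

With these inputs the ARM keeps one weight and one bias, while the upsampling bias is dropped, so no bias bitstream is written for that module. A substring test also accepts names such as `parametrizations.weight.original`, which the former `endswith` test would have skipped.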
diff --git a/coolchic/enc/bitstream/header.py b/coolchic/enc/bitstream/header.py
index 63f0137f..0aeb88a9 100644
--- a/coolchic/enc/bitstream/header.py
+++ b/coolchic/enc/bitstream/header.py
@@ -98,7 +98,7 @@
# Quick & dirty copy and paste from utils.coding_structure
_FRAME_DATA_TYPE = ["rgb", "yuv420", "yuv444"]
-_POSSIBLE_BITDEPTH = [8, 10]
+_POSSIBLE_BITDEPTH = [8, 9, 10, 11, 12, 13, 14, 15, 16]
_POSSIBLE_SYNTHESIS_MODE = [k for k in Synthesis.possible_mode]
_POSSIBLE_SYNTHESIS_NON_LINEARITY = [k for k in Synthesis.possible_non_linearity]
@@ -107,7 +107,7 @@ class GopHeader(TypedDict):
n_bytes_header: int # Number of bytes for the header
img_size: Tuple[int, int] # Format: (height, width)
frame_data_type: FRAME_DATA_TYPE # RGB, YUV 4:2:0, YUV 4:4:4
- bitdepth: POSSIBLE_BITDEPTH # 8 or 10
+ bitdepth: POSSIBLE_BITDEPTH # 8 through 16
intra_period: int # See coding_structure.py
p_period: int # See coding_structure.py
@@ -291,8 +291,12 @@ def write_frame_header(
n_bytes_header += (
1 # Context size and Hidden layer dimension ARM, n. hidden layers ARM
)
- n_bytes_header += 1 # Upsampling kernel size, static upsampling kernel flag
- n_bytes_header += 1 # Number hidden layer Synthesis
+
+ n_bytes_header += 1 # (n_ups_kernel << 4)|(ups_k_size)
+ n_bytes_header += 1 # (n_ups_preconcat_kernel << 4)|(ups_preconcat_k_size)
+
+ n_bytes_header += 1 # Number of synthesis branches
+ n_bytes_header += 1 # Number hidden layer Synthesis per branch
# Hidden Synthesis layer out#, kernelsz, mode+nonlinearity
n_bytes_header += 3 * len(frame_encoder.coolchic_encoder_param.layers_synthesis)
n_bytes_header += 1 # Flow gain
@@ -357,17 +361,20 @@ def write_frame_header(
+ frame_encoder.coolchic_encoder_param.n_hidden_layers_arm
).to_bytes(1, byteorder="big", signed=False)
- assert frame_encoder.coolchic_encoder_param.upsampling_kernel_size < 2**7, (
- f"Upsampling"
- f" kernel size should be small than {2 ** 7}. Found"
- f" {frame_encoder.coolchic_encoder_param.upsampling_kernel_size}"
- )
-
byte_to_write += (
- frame_encoder.coolchic_encoder_param.upsampling_kernel_size * 2**1
- + int(frame_encoder.coolchic_encoder_param.static_upsampling_kernel)
+ # (frame_encoder.coolchic_encoder_param.n_ups_kernel<<4)|(frame_encoder.coolchic_encoder_param.ups_k_size)
+ ((frame_encoder.coolchic_encoder_param.latent_n_grids-1)<<4)|(frame_encoder.coolchic_encoder_param.ups_k_size)
+ ).to_bytes(1, byteorder="big", signed=False)
+ byte_to_write += (
+ # (frame_encoder.coolchic_encoder_param.n_ups_preconcat_kernel<<4)|(frame_encoder.coolchic_encoder_param.ups_preconcat_k_size)
+ ((frame_encoder.coolchic_encoder_param.latent_n_grids-1)<<4)|(frame_encoder.coolchic_encoder_param.ups_preconcat_k_size)
).to_bytes(1, byteorder="big", signed=False)
+ # Continue to send this byte for compatibility
+ _dummy_n_synth_branch = 1
+ byte_to_write += (
+ _dummy_n_synth_branch # frame_encoder.coolchic_encoder_param.n_synth_branch
+ ).to_bytes(1, byteorder="big", signed=False)
byte_to_write += len(
frame_encoder.coolchic_encoder_param.layers_synthesis
).to_bytes(1, byteorder="big", signed=False)
@@ -464,161 +471,3 @@ def write_frame_header(
print("expected", n_bytes_header)
print("got", os.path.getsize(header_path))
exit(1)
-
-
-def read_frame_header(bitstream: bytes) -> FrameHeader:
- """Read the first few bytes of a bitstream file located at
- and parse the different information.
-
- Args:
- bitstream_path (str): Path where the bitstream is located.
-
- Returns:
- FrameHeader: The parsed info from the bitstream.
- """
-
- ptr = 0
- n_bytes_header = int.from_bytes(
- bitstream[ptr : ptr + 2], byteorder="big", signed=False
- )
- ptr += 2
-
- display_index = int.from_bytes(
- bitstream[ptr : ptr + 1], byteorder="big", signed=False
- )
- ptr += 1
-
- raw_arm = int.from_bytes(bitstream[ptr : ptr + 1], byteorder="big", signed=False)
- ptr += 1
- dim_arm = 8 * (raw_arm // (2**4))
- n_hidden_layers_arm = raw_arm % (2**4)
-
- raw_upsampling = int.from_bytes(
- bitstream[ptr : ptr + 1], byteorder="big", signed=False
- )
- ptr += 1
- upsampling_kernel_size = raw_upsampling // (2**1)
- static_upsampling_kernel = bool(raw_upsampling % (2**1))
-
- n_hidden_dim_synthesis = int.from_bytes(
- bitstream[ptr : ptr + 1], byteorder="big", signed=False
- )
- ptr += 1
- layers_synthesis = []
- for i in range(n_hidden_dim_synthesis):
- out_ft = int.from_bytes(bitstream[ptr : ptr + 1], byteorder="big", signed=False)
- ptr += 1
- kernel_size = int.from_bytes(
- bitstream[ptr : ptr + 1], byteorder="big", signed=False
- )
- ptr += 1
- mode_non_linearity = int.from_bytes(
- bitstream[ptr : ptr + 1], byteorder="big", signed=False
- )
- ptr += 1
- mode = _POSSIBLE_SYNTHESIS_MODE[mode_non_linearity // 16]
- non_linearity = _POSSIBLE_SYNTHESIS_NON_LINEARITY[mode_non_linearity % 16]
- layers_synthesis.append(f"{out_ft}-{kernel_size}-{mode}-{non_linearity}")
-
- flow_gain = int.from_bytes(bitstream[ptr : ptr + 1], byteorder="big", signed=False)
- ptr += 1
-
- ac_max_val_nn = int.from_bytes(
- bitstream[ptr : ptr + 2], byteorder="big", signed=False
- )
- ptr += 2
- ac_max_val_latent = int.from_bytes(
- bitstream[ptr : ptr + 2], byteorder="big", signed=False
- )
- ptr += 2
- hls_sig_blksize = int.from_bytes(bitstream[ptr: ptr + 1], byteorder='big', signed=True)
- ptr += 1
-
- q_step_index_nn: DescriptorCoolChic = {}
- for nn_name in ["arm", "upsampling", "synthesis"]:
- q_step_index_nn[nn_name] = {}
- for nn_param in ["weight", "bias"]:
- q_step_index_nn[nn_name][nn_param] = int.from_bytes(
- bitstream[ptr : ptr + 1], byteorder="big", signed=False
- )
- # # # Hack -- 255 -> -1 for upsampling bias. Indicating no bias.
- # # if q_step_index_nn[nn_name][nn_param] == 255:
- # # q_step_index_nn[nn_name][nn_param] = -1
- # # print("got -1:", nn_name, nn_param)
- ptr += 1
-
- scale_index_nn: DescriptorCoolChic = {}
- for nn_name in ["arm", "upsampling", "synthesis"]:
- scale_index_nn[nn_name] = {}
- for nn_param in ["weight", "bias"]:
- # if q_step_index_nn[nn_name][nn_param] < 0:
- # scale_index_nn[nn_name][nn_param] = -1
- # else:
- scale_index_nn[nn_name][nn_param] = int.from_bytes(
- bitstream[ptr : ptr + 1], byteorder="big", signed=False
- )
- ptr += 1
-
- n_bytes_nn: DescriptorCoolChic = {}
- for nn_name in ["arm", "upsampling", "synthesis"]:
- n_bytes_nn[nn_name] = {}
- for nn_param in ["weight", "bias"]:
- # if q_step_index_nn[nn_name][nn_param] < 0:
- # n_bytes_nn[nn_name][nn_param] = -1
- # else:
- n_bytes_nn[nn_name][nn_param] = int.from_bytes(
- bitstream[ptr : ptr + 2], byteorder="big", signed=False
- )
- ptr += 2
-
- latent_n_resolutions = int.from_bytes(
- bitstream[ptr : ptr + 1], byteorder="big", signed=False
- )
- ptr += 1
- latent_n_2d_grid = int.from_bytes(
- bitstream[ptr : ptr + 1], byteorder="big", signed=False
- )
- ptr += 1
-
- n_ft_per_latent = []
- for _ in range(latent_n_resolutions):
- n_ft_per_latent.append(
- int.from_bytes(bitstream[ptr : ptr + 1], byteorder="big", signed=False)
- )
- ptr += 1
-
- n_bytes_per_latent = []
- for _ in range(latent_n_2d_grid):
- n_bytes_per_latent.append(
- int.from_bytes(bitstream[ptr : ptr + 3], byteorder="big", signed=False)
- )
- ptr += 3
-
- header_info: FrameHeader = {
- "n_bytes_header": n_bytes_header,
- "latent_n_resolutions": latent_n_resolutions,
- "latent_n_2d_grid": latent_n_2d_grid,
- "n_bytes_per_latent": n_bytes_per_latent,
- "n_ft_per_latent": n_ft_per_latent,
- "n_hidden_layers_arm": n_hidden_layers_arm,
- "dim_arm": dim_arm,
- "upsampling_kernel_size": upsampling_kernel_size,
- "static_upsampling_kernel": static_upsampling_kernel,
- "flow_gain": flow_gain,
- "layers_synthesis": layers_synthesis,
- "q_step_index_nn": q_step_index_nn,
- "scale_index_nn": scale_index_nn,
- "n_bytes_nn": n_bytes_nn,
- "ac_max_val_nn": ac_max_val_nn,
- "ac_max_val_latent": ac_max_val_latent,
- "hls_sig_blksize": hls_sig_blksize,
- "display_index": display_index,
- }
-
- print("\nContent of the frame header:")
- print("------------------------------")
- for k, v in header_info.items():
- print(f"{k:>20}: {v}")
- print(" ------------------------")
-
- return header_info
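Since `read_frame_header` is removed from the Python side, the layout of the two new upsampling bytes is only documented by the comments in `write_frame_header`. The sketch below packs and unpacks one such byte as `(n_ups_kernel << 4) | ups_k_size`. The function names and the explicit 4-bit range check are assumptions for this example; the diff itself does not assert that each field fits in 4 bits.

```python
from typing import Tuple


def pack_ups_byte(n_kernel: int, k_size: int) -> bytes:
    """Pack the kernel count (high nibble) and kernel size (low nibble) in one byte."""
    if not (0 <= n_kernel < 16 and 0 <= k_size < 16):
        # Assumed guard: larger values would overlap or overflow the byte.
        raise ValueError("each field must fit in 4 bits")
    return ((n_kernel << 4) | k_size).to_bytes(1, byteorder="big", signed=False)


def unpack_ups_byte(raw: bytes) -> Tuple[int, int]:
    """Inverse of pack_ups_byte: return (n_kernel, k_size)."""
    value = int.from_bytes(raw, byteorder="big", signed=False)
    return value >> 4, value & 0x0F


# Example: 7 latent grids give latent_n_grids - 1 = 6 kernels of size 8.
raw = pack_ups_byte(6, 8)
n_kernel, k_size = unpack_ups_byte(raw)
```

With the default sizes from this diff (`ups_k_size = 8`, `ups_preconcat_k_size = 7`) both fields fit in a nibble; a kernel size of 16 or more would need a different layout.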
diff --git a/coolchic/enc/component/coolchic.py b/coolchic/enc/component/coolchic.py
index 13ee545a..dfb6a7e7 100644
--- a/coolchic/enc/component/coolchic.py
+++ b/coolchic/enc/component/coolchic.py
@@ -11,10 +11,12 @@
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, TypedDict
+from enc.visu.console import pretty_string_nn, pretty_string_ups
from torch import nn, Tensor
import torch
from fvcore.nn import FlopCountAnalysis, flop_count_table
+
from enc.component.core.arm import (
Arm,
_get_neighbor,
@@ -72,11 +74,13 @@ class CoolChicEncoderParameter:
n_hidden_layers_arm (int, Optional): Number of hidden layers in the
ARM. Set ``n_hidden_layers_arm = 0`` for a linear ARM. Defaults
to 2.
- upsampling_kernel_size (int, Optional): Kernel size for the upsampler.
- See the :doc:`upsampling documentation ` for more
- information. Defaults to 8.
- static_upsampling_kernel (bool, Optional): Set this flag to ``True`` to
- prevent learning the upsampling kernel. Defaults to ``False``.
+ ups_k_size (int, Optional): Upsampling kernel size for the transposed
+ convolutions. See the :doc:`upsampling documentation `
+ for more information. Defaults to 8.
+ ups_preconcat_k_size (int, Optional): Upsampling kernel size for the
+ pre-concatenation convolutions. See the
+ :doc:`upsampling documentation ` for more
+ information. Defaults to 7.
encoder_gain (int, Optional): Multiply the latent by this value before
quantization. See the documentation of Cool-chic forward pass.
Defaults to 16.
@@ -85,9 +89,9 @@ class CoolChicEncoderParameter:
n_ft_per_res: List[int]
dim_arm: int = 24
n_hidden_layers_arm: int = 2
- upsampling_kernel_size: int = 8
- static_upsampling_kernel: bool = False
encoder_gain: int = 16
+ ups_k_size: int = 8
+ ups_preconcat_k_size: int = 7
# ==================== Not set by the init function ===================== #
#: Automatically computed, number of different latent resolutions
@@ -192,7 +196,13 @@ def __init__(self, param: CoolChicEncoderParameter):
# ===================== Upsampling stuff ===================== #
self.upsampling = Upsampling(
- self.param.upsampling_kernel_size, self.param.static_upsampling_kernel
+ ups_k_size=self.param.ups_k_size,
+ ups_preconcat_k_size=self.param.ups_preconcat_k_size,
+ # Instantiate a distinct upsampling filter and pre-concatenation
+ # filter for each upsampling step. Could also be set to one
+ # to share the same filter across all latents.
+ n_ups_kernel=self.param.latent_n_grids - 1,
+ n_ups_preconcat_kernel=self.param.latent_n_grids - 1,
)
# ===================== Upsampling stuff ===================== #
@@ -239,17 +249,19 @@ def __init__(self, param: CoolChicEncoderParameter):
self.arm = Arm(self.param.dim_arm, self.param.n_hidden_layers_arm)
# ===================== ARM related stuff ==================== #
+ # Something like ['arm', 'synthesis', 'upsampling']
+ self.modules_to_send = [tmp.name for tmp in fields(DescriptorCoolChic)]
+
# ======================== Monitoring ======================== #
# Pretty string representing the decoder complexity
self.flops_str = ""
# Total number of multiplications to decode the image
self.total_flops = 0.0
+ self.flops_per_module = {k: 0 for k in self.modules_to_send}
# Fill the two attributes aboves
self.get_flops()
# ======================== Monitoring ======================== #
- # Something like ['arm', 'synthesis', 'upsampling']
- self.modules_to_send = [tmp.name for tmp in fields(DescriptorCoolChic)]
# Track the quantization step of each neural network, None if the
# module is not yet quantized
@@ -413,6 +425,7 @@ def forward(
additional_data["detailed_log_scale"] = []
additional_data["detailed_rate_bit"] = []
additional_data["detailed_centered_latent"] = []
+ additional_data["hpfilters"] = []
# "Pointer" for the reading of the 1D scale, mu and rate
cnt = 0
@@ -586,6 +599,9 @@ def get_flops(self) -> None:
# print("Ignoring get_flops")
# Count the number of floating point operations here. It must be done before
# torch scripting the different modules.
+
+ self = self.train(mode=False)
+
flops = FlopCountAnalysis(
self,
(
@@ -601,9 +617,14 @@ def get_flops(self) -> None:
flops.uncalled_modules_warnings(False)
self.total_flops = flops.total()
+ for k in self.flops_per_module:
+ self.flops_per_module[k] = flops.by_module()[k]
+
self.flops_str = flop_count_table(flops)
del flops
+ self = self.train(mode=True)
+
def get_network_rate(self) -> DescriptorCoolChic:
"""Return the rate (in bits) associated to the parameters
(weights and biases) of the different modules
@@ -706,3 +727,63 @@ def to_device(self, device: POSSIBLE_DEVICE) -> None:
if hasattr(layer, "qb"):
if layer.qb is not None:
self.arm.mlp[idx_layer].qb = layer.qb.to(device)
+
+ def pretty_string(self) -> str:
+ """Get a pretty string representing the layer of a ``CoolChicEncoder``"""
+
+ s = ""
+
+ if not self.flops_str:
+ self.get_flops()
+
+ n_pixels = self.param.img_size[-2] * self.param.img_size[-1]
+ total_mac_per_pix = self.get_total_mac_per_pixel()
+
+
+ title = f"Cool-chic architecture {total_mac_per_pix:.0f} MAC / pixel"
+ s += (
+ f"\n{title}\n"
+ f"{'-' * len(title)}\n\n"
+ )
+
+ complexity = self.flops_per_module['upsampling'] / n_pixels
+ share_complexity = 100 * complexity / total_mac_per_pix
+ title = f"Upsampling {complexity:.0f} MAC/pixel ; {share_complexity:.1f} % of the complexity"
+ s += (
+ f"{title}\n"
+ f"{'=' * len(title)}\n"
+ "Note: all upsampling layers are separable and symmetric "
+ "(transposed) convolutions.\n\n"
+
+ )
+ s += pretty_string_ups(self.upsampling, "")
+
+ complexity = self.flops_per_module['arm'] / n_pixels
+ share_complexity = 100 * complexity / total_mac_per_pix
+ title = f"ARM {complexity:.0f} MAC/pixel ; {share_complexity:.1f} % of the complexity"
+ s += (
+ f"\n\n\n{title}\n"
+ f"{'=' * len(title)}\n\n\n"
+
+ )
+ input_arm = f"{self.arm.dim_arm}-pixel context"
+ output_arm = "mu, log scale"
+ s += pretty_string_nn(
+ self.arm.mlp, "", input_arm, output_arm
+ )
+
+ complexity = self.flops_per_module['synthesis'] / n_pixels
+ share_complexity = 100 * complexity / total_mac_per_pix
+ title = f"Synthesis {complexity:.0f} MAC/pixel ; {share_complexity:.1f} % of the complexity"
+ s += (
+ f"\n\n\n{title}\n"
+ f"{'=' * len(title)}\n\n\n"
+
+ )
+ input_syn = f"{self.synthesis.input_ft} features"
+ output_syn = "Decoded image"
+ s += pretty_string_nn(
+ self.synthesis.layers, "", input_syn, output_syn
+ )
+
+ return s
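The titles built in `pretty_string` report, for each module, a complexity in MAC per pixel and its share of the total. The sketch below reproduces that arithmetic on made-up numbers. The helper name and the input values are hypothetical, and the total is assumed here to be the sum of the three modules, whereas the diff obtains it from `get_total_mac_per_pixel()`.

```python
from typing import Dict


def complexity_share(
    flops_per_module: Dict[str, float], n_pixels: int
) -> Dict[str, Dict[str, float]]:
    """Hypothetical helper: per-module MAC/pixel and percentage of the total."""
    # Assumption: the total complexity is the sum over the listed modules.
    total_mac_per_pix = sum(flops_per_module.values()) / n_pixels
    report = {}
    for name, flops in flops_per_module.items():
        complexity = flops / n_pixels
        report[name] = {
            "mac_per_pixel": complexity,
            "share_percent": 100 * complexity / total_mac_per_pix,
        }
    return report


# Made-up multiplication counts for a 10 x 10 image.
report = complexity_share({"arm": 40000, "upsampling": 10000, "synthesis": 50000}, n_pixels=100)
for name, r in report.items():
    print(f"{name} {r['mac_per_pixel']:.0f} MAC/pixel ; {r['share_percent']:.1f} % of the complexity")
```

The format string mirrors the one used for the section titles, so the printed lines have the same shape as the headers emitted by `pretty_string`.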
diff --git a/coolchic/enc/component/core/arm.py b/coolchic/enc/component/core/arm.py
index c5999faf..0878fd4f 100644
--- a/coolchic/enc/component/core/arm.py
+++ b/coolchic/enc/component/core/arm.py
@@ -50,6 +50,8 @@ def __init__(
super().__init__()
self.residual = residual
+ self.in_channels = in_channels
+ self.out_channels = out_channels
# -------- Instantiate empty parameters, set by the initialize function
self.weight = nn.Parameter(
@@ -95,94 +97,6 @@ def forward(self, x: Tensor) -> Tensor:
else:
return F.linear(x, self.weight, bias=self.bias)
-class ArmIntLinear(nn.Module):
- """Create a Linear layer of the Auto-Regressive Module (ARM). This is a
- wrapper around the usual ``nn.Linear`` layer of PyTorch, with a custom
- initialization. It performs the following operations:
-
- * :math:`\\mathbf{x}_{out} = \\mathbf{W}\\mathbf{x}_{in} + \\mathbf{b}` if
- ``residual`` is ``False``
-
- * :math:`\\mathbf{x}_{out} = \\mathbf{W}\\mathbf{x}_{in} + \\mathbf{b} +
- \\mathbf{x}_{in}` if ``residual`` is ``True``.
-
- The input :math:`\\mathbf{x}_{in}` is a :math:`[B, C_{in}]` tensor, the
- output :math:`\\mathbf{x}_{out}` is a :math:`[B, C_{out}]` tensor.
-
- The layer weight and bias shapes are :math:`\\mathbf{W} \\in
- \\mathbb{R}^{C_{out} \\times C_{in}}` and :math:`\\mathbf{b} \\in
- \\mathbb{R}^{C_{out}}`.
- """
-
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- fpfm: int = 0,
- pure_int: bool = False,
- residual: bool = False,
- ):
- """
- Args:
- in_channels: Number of input features :math:`C_{in}`.
- out_channels: Number of output features :math:`C_{out}`.
- fpfm: Internal stuff for integer computation. **No need to modify
- this**. Defaults to 0.
- residual: True to add a residual connexion to the layer. Defaults to
- False.
- """
-
- super().__init__()
-
- self.fpfm = fpfm
- self.pure_int = pure_int
- self.residual = residual
-
- # -------- Instantiate empty parameters, set by a later load
- if self.pure_int:
- self.weight = nn.Parameter(
- torch.empty((out_channels, in_channels), dtype=torch.int32), requires_grad=False
- )
- self.bias = nn.Parameter(torch.empty((out_channels), dtype=torch.int32), requires_grad=False)
- else:
- self.weight = nn.Parameter(
- torch.empty((out_channels, in_channels), dtype=torch.float), requires_grad=False
- )
- self.bias = nn.Parameter(torch.empty((out_channels), dtype=torch.float), requires_grad=False)
- # -------- Instantiate empty parameters, set by a later load
-
-
- def forward(self, x: Tensor) -> Tensor:
- """Perform the forward pass of this layer.
-
- Args:
- x: Input tensor of shape :math:`[B, C_{in}]`.
-
- Returns:
- Tensor with shape :math:`[B, C_{out}]`.
- """
- if self.residual:
- xx = F.linear(x, self.weight, bias=self.bias) + x*self.fpfm
- else:
- xx = F.linear(x, self.weight, bias=self.bias)
-
- # Renorm by fpfm after our (x*fpfm)*(qw*fpfm) multiplication.
- # WE MAKE INTEGER DIVISION OBEY C++ (TO-ZERO) SEMANTICS, NOT PYTHON (TO-NEGATIVE-INFINITY) SEMANTICS
- if self.pure_int:
- xx = xx + torch.sign(xx)*self.fpfm//2
- # We separate out -ve and non-ve.
- neg_result = -((-xx)//self.fpfm)
- pos_result = xx//self.fpfm
- result = torch.where(xx < 0, neg_result, pos_result)
- else:
- xx = xx + torch.sign(xx)*self.fpfm/2
- # We separate out -ve and non-ve.
- neg_result = -((-xx)/self.fpfm)
- pos_result = xx/self.fpfm
- result = torch.where(xx < 0, neg_result, pos_result)
- result = result.to(torch.int32).to(torch.float)
-
- return result
class Arm(nn.Module):
"""Instantiate an autoregressive probability module, modelling the
@@ -223,7 +137,6 @@ class Arm(nn.Module):
from the ``ArmLinear`` class.
"""
-
def __init__(self, dim_arm: int, n_hidden_layers_arm: int):
"""
Args:
@@ -238,6 +151,7 @@ def __init__(self, dim_arm: int, n_hidden_layers_arm: int):
f"ARM context size and hidden layer dimension must be "
f"a multiple of 8. Found {dim_arm}."
)
+ self.dim_arm = dim_arm
# ======================== Construct the MLP ======================== #
layers_list = nn.ModuleList()
@@ -317,165 +231,6 @@ def reinitialize_parameters(self) -> None:
if isinstance(layer, ArmLinear):
layer.initialize_parameters()
-class ArmInt(nn.Module):
- """Instantiate an autoregressive probability module, modelling the
- conditional distribution :math:`p_{\\psi}(\\hat{y}_i \\mid
- \\mathbf{c}_i)` of a (quantized) latent pixel :math:`\\hat{y}_i`,
- conditioned on neighboring already decoded context pixels
- :math:`\\mathbf{c}_i \in \\mathbb{Z}^C`, where :math:`C` denotes the
- number of context pixels.
-
- The distribution :math:`p_{\\psi}` is assumed to follow a Laplace
- distribution, parameterized by an expectation :math:`\\mu` and a scale
- :math:`b`, where the scale and the variance :math:`\\sigma^2` are
- related as follows :math:`\\sigma^2 = 2 b ^2`.
-
- The parameters of the Laplace distribution for a given latent pixel
- :math:`\\hat{y}_i` are obtained by passing its context pixels
- :math:`\\mathbf{c}_i` through an MLP :math:`f_{\\psi}`:
-
- .. math::
-
- p_{\\psi}(\\hat{y}_i \\mid \\mathbf{c}_i) \sim \mathcal{L}(\\mu_i,
- b_i), \\text{ where } \\mu_i, b_i = f_{\\psi}(\\mathbf{c}_i).
-
- .. attention::
-
- The MLP :math:`f_{\\psi}` has a few constraint on its architecture:
-
- * The width of all hidden layers (i.e. the output of all layers except
- the final one) are identical to the number of pixel contexts
- :math:`C`;
-
- * All layers except the last one are residual layers, followed by a
- ``ReLU`` non-linearity;
-
- * :math:`C` must be at a multiple of 8.
-
- The MLP :math:`f_{\\psi}` is made of custom Linear layers instantiated
- from the ``ArmLinear`` class.
- """
-
- def __init__(self, dim_arm: int, n_hidden_layers_arm: int, fpfm: int, pure_int: bool):
- """
- Args:
- dim_arm: Number of context pixels AND dimension of all hidden
- layers :math:`C`.
- n_hidden_layers_arm: Number of hidden layers. Set it to 0 for
- a linear ARM.
- """
- super().__init__()
-
- assert dim_arm % 8 == 0, (
- f"ARM context size and hidden layer dimension must be "
- f"a multiple of 8. Found {dim_arm}."
- )
-
- self.FPFM = fpfm # fixed-point: multiplication to get int.
- self.pure_int = pure_int # weights and biases are actual int (cpu only), or just int values in floats (gpu friendly).
-
- # ======================== Construct the MLP ======================== #
- layers_list = nn.ModuleList()
-
- # Construct the hidden layer(s)
- for i in range(n_hidden_layers_arm):
- layers_list.append(ArmIntLinear(dim_arm, dim_arm, self.FPFM, self.pure_int, residual=True))
- layers_list.append(nn.ReLU())
-
- # Construct the output layer. It always has 2 outputs (mu and scale)
- layers_list.append(ArmIntLinear(dim_arm, 2, self.FPFM, self.pure_int, residual=False))
- self.mlp = nn.Sequential(*layers_list)
- # ======================== Construct the MLP ======================== #
-
- def set_param_from_float(self, float_param: OrderedDict[str, Tensor]) -> None:
- # We take floating point values here, and convert them to ints.
-
- # floating point params. We convert to fixed-point integer and store them.
- integerised_param = {}
- for k in float_param:
- if "weight" in k:
- float_v = float_param[k]*self.FPFM
- else:
- float_v = float_param[k]*self.FPFM*self.FPFM
-
- float_v = float_v + torch.sign(float_v)*0.5
- neg_result = -(-float_v).to(torch.int32)
- pos_result = float_v.to(torch.int32)
- int_v = torch.where(float_v < 0, neg_result, pos_result)
- if not self.pure_int:
- int_v = int_v.to(torch.float)
- integerised_param[k] = nn.parameter.Parameter(int_v, requires_grad=False)
-
- self.load_state_dict(integerised_param, assign=True)
-
- def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
- """Perform the auto-regressive module (ARM) forward pass. The ARM takes
- as input a tensor of shape :math:`[B, C]` i.e. :math:`B` contexts with
- :math:`C` context pixels. ARM outputs :math:`[B, 2]` values correspond
- to :math:`\\mu, b` for each of the :math:`B` input pixels.
-
- .. warning::
-
- Note that the ARM expects input to be flattened i.e. spatial
- dimensions :math:`H, W` are collapsed into a single batch-like
- dimension :math:`B = HW`, leading to an input of shape
- :math:`[B, C]`, gathering the :math:`C` contexts for each of the
- :math:`B` pixels to model.
-
- .. note::
-
- The ARM MLP does not output directly the scale :math:`b`. Denoting
- :math:`s` the raw output of the MLP, the scale is obtained as
- follows:
-
- .. math::
-
- b = e^{x - 4}
-
- Args:
- x: Concatenation of all input contexts
- :math:`\\mathbf{c}_i`. Tensor of shape :math:`[B, C]`.
-
- Returns:
- Concatenation of all Laplace distributions param :math:`\\mu, b`.
- Tensor of shape :math:([B]). Also return the *log scale*
- :math:`s` as described above. Tensor of shape :math:`(B)`
- """
- xint = x.clone().detach()
- xint = xint*self.FPFM
- if self.pure_int:
- xint = xint.to(torch.int32)
-
- for idx_l, layer in enumerate(self.mlp.children()):
- xint = layer(xint)
-
- # float the result.
- raw_proba_param = xint / self.FPFM
-
- mu = raw_proba_param[:, 0]
- log_scale = raw_proba_param[:, 1]
-
- # no scale smaller than exp(-4.6) = 1e-2 or bigger than exp(5.01) = 150
- scale = torch.exp(torch.clamp(log_scale - 4, min=-4.6, max=5.0))
-
- return mu, scale, log_scale
-
- def get_param(self) -> OrderedDict[str, Tensor]:
- """Return **a copy** of the weights and biases inside the module.
-
- Returns:
- A copy of all weights & biases in the layers.
- """
- # Detach & clone to create a copy
- return OrderedDict({k: v.detach().clone() for k, v in self.named_parameters()})
-
- def set_param(self, param: OrderedDict[str, Tensor]) -> None:
- """Replace the current parameters of the module with param.
-
- Args:
- param: Parameters to be set.
- """
- self.load_state_dict(param)
@torch.jit.script
def _get_neighbor(x: Tensor, mask_size: int, non_zero_pixel_ctx_idx: Tensor) -> Tensor:
@@ -565,7 +320,7 @@ def _get_non_zero_pixel_ctx_index(dim_arm: int) -> Tensor:
Returns:
Tensor: 1D tensor with the flattened index of the context pixels.
"""
-
+ # fmt: off
if dim_arm == 8:
return torch.tensor(
[ 13,
@@ -606,3 +361,4 @@ def _get_non_zero_pixel_ctx_index(dim_arm: int) -> Tensor:
36, 37, 38, 39, #
]
)
+ # fmt: on
diff --git a/coolchic/enc/component/core/quantizer.py b/coolchic/enc/component/core/quantizer.py
index 6193c2f7..25646187 100644
--- a/coolchic/enc/component/core/quantizer.py
+++ b/coolchic/enc/component/core/quantizer.py
@@ -169,7 +169,6 @@ def quantize(
Quantized tensor
"""
# ----- Check user input
- # TODO: How long is it to do such assert?
assert quantizer_noise_type in typing.get_args(POSSIBLE_QUANTIZATION_NOISE_TYPE), (
f"quantizer_noise_type must be in {POSSIBLE_QUANTIZATION_NOISE_TYPE}"
f" found {quantizer_noise_type}"
@@ -226,7 +225,6 @@ def quantize(
# From the forward point of view (i.e. entering into the torch.no_grad()), we have
# y = softround(x) - softround(x) + round(x) = round(x). From the backward point of view
# we have y = softround(x) meaning that dy / dx = d softround(x) / dx.
- # TODO: check whether it works?
y = softround(x, soft_round_temperature)
with torch.no_grad():
y = y - softround(x, soft_round_temperature) + torch.round(x)
diff --git a/coolchic/enc/component/core/synthesis.py b/coolchic/enc/component/core/synthesis.py
index 0d56ad06..a4440fca 100644
--- a/coolchic/enc/component/core/synthesis.py
+++ b/coolchic/enc/component/core/synthesis.py
@@ -46,8 +46,11 @@ def __init__(
"""
super().__init__()
- self.pad = int((kernel_size - 1) / 2)
self.residual = residual
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.pad = int((kernel_size - 1) / 2)
# -------- Instantiate empty parameters, set by the initialize function
self.groups = 1 # Hardcoded for now
@@ -168,6 +171,9 @@ def __init__(self, input_ft: int, layers_dim: List[str]):
following the notation detailed above.
"""
super().__init__()
+
+ self.synth_branches = nn.ModuleList()
+ self.input_ft = input_ft
layers_list = nn.ModuleList()
# Construct the hidden layer(s)
@@ -211,6 +217,7 @@ def forward(self, x: Tensor) -> Tensor:
"""
return self.layers(x)
+
def get_param(self) -> OrderedDict[str, Tensor]:
"""Return **a copy** of the weights and biases inside the module.
@@ -229,7 +236,7 @@ def set_param(self, param: OrderedDict[str, Tensor]):
self.load_state_dict(param)
def reinitialize_parameters(self) -> None:
- """Re-initialize in place the parameters of all the SynthesisConv2d layer."""
+ """Re-initialize in place the params of all the ``SynthesisConv2d`` layers."""
for layer in self.layers.children():
if isinstance(layer, SynthesisConv2d):
layer.initialize_parameters()
diff --git a/coolchic/enc/component/core/upsampling.py b/coolchic/enc/component/core/upsampling.py
index f74ae844..c3b9b4f5 100644
--- a/coolchic/enc/component/core/upsampling.py
+++ b/coolchic/enc/component/core/upsampling.py
@@ -11,127 +11,293 @@
import torch
import torch.nn.functional as F
+import torch.nn.utils.parametrize as parametrize
from einops import rearrange
from torch import Tensor, nn
-class UpsamplingConvTranspose2d(nn.Module):
- """Wrapper around the usual ``nn.TransposeConv2d`` layer. It performs a 2x
- upsampling of a latent variable with a **single** input and output channel.
- It can be learned or not, depending on the flag
- ``static_upsampling_kernel``. Its initialization depends on the requested
- kernel size. If the kernel size is 4 or 6, we use the bilinear kernel with
- zero padding if necessary. Otherwise, if the kernel size is 8 or bigger, we
- rely on the bicubic kernel.
- """
+class _Parameterization_Symmetric_1d(nn.Module):
+ """This module is not meant to be instantiated. It should rather be used
+ through the ``torch.nn.utils.parametrize.register_parametrization()``
+ function to reparameterize an N-element vector into a 2N-element (or
+ (2N-1)-element) symmetric vector. For instance:
- kernel_bilinear = torch.tensor(
- [
- [0.0625, 0.1875, 0.1875, 0.0625],
- [0.1875, 0.5625, 0.5625, 0.1875],
- [0.1875, 0.5625, 0.5625, 0.1875],
- [0.0625, 0.1875, 0.1875, 0.0625],
- ]
- )
-
- kernel_bicubic = torch.tensor(
- [
- [ 0.0012359619 , 0.0037078857 ,-0.0092010498 ,-0.0308990479 ,-0.0308990479 ,-0.0092010498 , 0.0037078857 , 0.0012359619],
- [ 0.0037078857 , 0.0111236572 ,-0.0276031494 ,-0.0926971436 ,-0.0926971436 ,-0.0276031494 , 0.0111236572 , 0.0037078857],
- [-0.0092010498 ,-0.0276031494 , 0.0684967041 , 0.2300262451 , 0.2300262451 , 0.0684967041 ,-0.0276031494 ,-0.0092010498],
- [-0.0308990479 ,-0.0926971436 , 0.2300262451 , 0.7724761963 , 0.7724761963 , 0.2300262451 ,-0.0926971436 ,-0.0308990479],
- [-0.0308990479 ,-0.0926971436 , 0.2300262451 , 0.7724761963 , 0.7724761963 , 0.2300262451 ,-0.0926971436 ,-0.0308990479],
- [-0.0092010498 ,-0.0276031494 , 0.0684967041 , 0.2300262451 , 0.2300262451 , 0.0684967041 ,-0.0276031494 ,-0.0092010498],
- [ 0.0037078857 , 0.0111236572 ,-0.0276031494 ,-0.0926971436 ,-0.0926971436 ,-0.0276031494 , 0.0111236572 , 0.0037078857],
- [ 0.0012359619 , 0.0037078857 ,-0.0092010498 ,-0.0308990479 ,-0.0308990479 ,-0.0092010498 , 0.0037078857 , 0.0012359619],
- ]
- )
+ * x = a b c and target_k_size = 5 --> a b c b a
+ * x = a b c and target_k_size = 6 --> a b c c b a
+ Both the 5-element and the 6-element vectors can be parameterized through
+ a 3-element representation (a, b, c).
+ """
- def __init__(
- self,
- upsampling_kernel_size: int,
- static_upsampling_kernel: bool
- ):
+ def __init__(self, target_k_size: int):
"""
Args:
- upsampling_kernel_size: Upsampling kernel size. Should be >= 4
- and a multiple of two.
- static_upsampling_kernel: If true, don't learn the upsampling kernel.
+ target_k_size: Target size of the kernel after reparameterization.
"""
- super().__init__()
- assert upsampling_kernel_size >= 4, (
- f"Upsampling kernel size should be >= 4." f"Found {upsampling_kernel_size}"
+ super().__init__()
+ self.target_k_size = target_k_size
+ self.param_size = _Parameterization_Symmetric_1d.size_param_from_target(
+ self.target_k_size
)
- assert upsampling_kernel_size % 2 == 0, (
- f"Upsampling kernel size should be even." f"Found {upsampling_kernel_size}"
+ def forward(self, x: Tensor) -> Tensor:
+ """Return a longer, symmetric vector by concatenating x with a flipped
+ version of itself.
+
+ Args:
+ x (Tensor): [N] tensor.
+
+ Returns:
+ Tensor: [2N] or [2N - 1] tensor, depending on self.target_k_size
+ """
+
+ # torch.fliplr requires a tensor with at least 2 dimensions
+ x_reversed = torch.fliplr(x.view(1, -1)).view(-1)
+
+ kernel = torch.cat(
+ [
+ x,
+ # a b c c b a if target_k_size is even, a b c b a if it is odd
+ x_reversed[self.target_k_size % 2 :],
+ ],
)
- self.upsampling_kernel_size = upsampling_kernel_size
- self.static_upsampling_kernel = static_upsampling_kernel
+ return kernel
+
+
+ @classmethod
+ def size_param_from_target(cls, target_k_size: int) -> int:
+ """Return the size of the appropriate parameterization of a
+ symmetric tensor with target_k_size elements. For instance:
+
+ target_k_size = 6 ; parameterization size = 3 e.g. (a b c c b a)
+
+ target_k_size = 7 ; parameterization size = 4 e.g. (a b c d c b a)
+
+ Args:
+ target_k_size (int): Size of the actual symmetric 1D kernel.
+
+ Returns:
+ int: Size of the underlying parameterization.
+ """
+ # For a kernel of size target_k_size = 2N, we need N values
+ # e.g. 3 params a b c to parameterize a b c c b a.
+ # For a kernel of size target_k_size = 2N + 1, we need N + 1 values
+ # e.g. 4 params a b c d to parameterize a b c d c b a.
+ return (target_k_size + 1) // 2
+
+
+
+class UpsamplingSeparableSymmetricConv2d(nn.Module):
+ """
+ A conv2D which has a separable and symmetric *odd* kernel.
+
+ Separable means that the 2D kernel :math:`\\mathbf{w}_{2D}` can be expressed
+ as the outer product of a 1D kernel :math:`\\mathbf{w}_{1D}` with itself:
+
+ .. math::
+
+ \\mathbf{w}_{2D} = \\mathbf{w}_{1D} \\otimes \\mathbf{w}_{1D}.
+
+ The 1D kernel :math:`\\mathbf{w}_{1D}` is also symmetric. That is, the 1D
+ kernel is something like
+ :math:`\\mathbf{w}_{1D} = \\left(a\\ b\\ c\\ b\\ a\\right).`
+
+ The symmetric constraint is obtained through the module
+ ``_Parameterization_Symmetric_1d``. The separable constraint is obtained by
+ applying the 1D kernel twice.
+ """
+ def __init__(self, kernel_size: int):
+ """
+ kernel_size: Size of the kernel :math:`\\mathbf{w}_{1D}` e.g. 7 to
+ obtain a symmetrical, separable 7x7 filter. Must be odd!
+ """
+ super().__init__()
+
+ assert (
+ kernel_size % 2 == 1
+ ), f"Upsampling kernel size must be odd, found {kernel_size}."
+
+ self.target_k_size = kernel_size
+ self.param_size = _Parameterization_Symmetric_1d.size_param_from_target(
+ self.target_k_size
+ )
# -------- Instantiate empty parameters, set by the initialize function
self.weight = nn.Parameter(
- torch.empty(1, 1, upsampling_kernel_size, upsampling_kernel_size),
- requires_grad=True,
+ torch.empty(self.param_size), requires_grad=True
)
- self.bias = nn.Parameter(torch.empty((1)), requires_grad=True)
+
+ self.bias = nn.Parameter(torch.empty(1), requires_grad=True)
self.initialize_parameters()
# -------- Instantiate empty parameters, set by the initialize function
- # Keep initial weights if required by the self.static_upsampling kernel flag
- if self.static_upsampling_kernel:
- # register_buffer for automatic device management. We set persistent to false
- # to simply use the "automatically move to device" function, without
- # considering non_zero_pixel_ctx_index as a parameters (i.e. returned
- # by self.parameters())
- self.register_buffer("static_kernel", self.weight.data.clone(), persistent=False)
- else:
- self.static_kernel = None
+ # Each time we call .weight, we'll call the forward of
+ # _Parameterization_Symmetric_1d to get a symmetric kernel.
+ parametrize.register_parametrization(
+ self,
+ "weight",
+ _Parameterization_Symmetric_1d(target_k_size=self.target_k_size),
+ # Unsafe because we change the data dimension, from N to 2N - 1
+ unsafe=True,
+ )
def initialize_parameters(self) -> None:
"""
- Initialize **in-place ** the weights and the biases of the transposed
- convolution layer performing the upsampling.
+ Initialize the weights and the biases of the separable symmetric
+ convolution layer used as a pre-concatenation filter.
- - Biases are always set to zero.
+ * Biases are always set to zero.
- - Weights are set to a (padded) bicubic kernel if kernel size is at
- least 8. If kernel size is greater than or equal to 4, weights are
- set to a (padded) bilinear kernel.
+ * Weights are set to :math:`(0,\\ 0,\\ 0,\\ \\ldots, 1)` so that when the
+ symmetric reparameterization is applied a Dirac kernel is obtained e.g.
+ :math:`(0,\\ 0,\\ 0,\\ \\ldots, 1, \\ldots, 0,\\ 0,\\ 0)`.
"""
- # -------- bias is always set to zero (and in fact never ever used)
+ # Zero everywhere except for the last coef
+ w = torch.zeros_like(self.weight)
+ w[-1] = 1
+ self.weight = nn.Parameter(w, requires_grad=True)
+
self.bias = nn.Parameter(torch.zeros_like(self.bias), requires_grad=True)
- # -------- Weights are initialized to bicubic or bilinear
- # adapted filter size
- K = self.upsampling_kernel_size
- self.upsampling_padding = (K // 2, K // 2, K // 2, K // 2)
- self.upsampling_crop = (3 * K - 2) // 2
+ def forward(self, x: Tensor) -> Tensor:
+ """Perform a "normal" 2D convolution, except that the underlying kernel
+ is both separable & symmetrical. The actual implementation of the forward
+ depends on ``self.training``.
+
+ If we're training, we use a non-separable implementation. That is, we
+ first compute the 2D kernel through an outer product and then use a
+ single 2D convolution. This is more stable.
- if K < 8:
- kernel_init = UpsamplingConvTranspose2d.kernel_bilinear
+ If we're not training, we use two successive 1D convolutions.
+
+ .. warning::
+
+ There is a residual connexion in the forward.
+
+ Args:
+ x: [B, 1, H, W] tensor to be filtered. Must have only
+ one channel.
+
+ Returns:
+ Tensor: Filtered tensor [B, 1, H, W].
+ """
+ k = self.weight.size()[0]
+ weight = self.weight.view(1, -1)
+ padding = k // 2
+
+ # Train using non-separable (more stable)
+ if self.training:
+ # Kronecker product of (1 k) & (k 1) --> (k, k).
+ # Then, two dummy dimensions are added to be compliant with conv2d
+ # (k, k) --> (1, 1, k, k).
+ kernel_2d = torch.kron(weight, weight.T).view((1, 1, k, k))
+
+ # ! Note the residual connexion!
+ return F.conv2d(x, kernel_2d, bias=None, stride=1, padding=padding) + x
+
+ # Test through separable (less complex, for the flop counter)
else:
- kernel_init = UpsamplingConvTranspose2d.kernel_bicubic
-
- # pad initial filter according to desired kernel size
- tmpad = (K - kernel_init.size()[0]) // 2
- upsampling_kernel = F.pad(
- kernel_init.clone().detach(),
- (tmpad, tmpad, tmpad, tmpad),
- mode="constant",
- value=0.0,
+ yw = F.conv2d(x, weight.view((1, 1, 1, k)), padding=(0, padding))
+
+ # ! Note the residual connexion!
+ return F.conv2d(yw, weight.view((1, 1, k, 1)), padding=(padding, 0)) + x
+
+
+class UpsamplingSeparableSymmetricConvTranspose2d(nn.Module):
+ """
+ A TransposedConv2D which has a separable and symmetric *even* kernel.
+
+ Separable means that the 2D kernel :math:`\\mathbf{w}_{2D}` can be expressed
+ as the outer product of a 1D kernel :math:`\\mathbf{w}_{1D}` with itself:
+
+ .. math::
+
+ \\mathbf{w}_{2D} = \\mathbf{w}_{1D} \\otimes \\mathbf{w}_{1D}.
+
+ The 1D kernel :math:`\\mathbf{w}_{1D}` is also symmetric. That is, the 1D
+ kernel is something like
+ :math:`\\mathbf{w}_{1D} = \\left(a\\ b\\ c\\ c\\ b\\ a\\right).`
+
+ The symmetric constraint is obtained through the module
+ ``_Parameterization_Symmetric_1d``. The separable constraint is obtained by
+ applying the 1D kernel twice.
+ """
+
+ def __init__(self, kernel_size: int):
+ """
+ Args:
+ kernel_size: Upsampling kernel size. Shall be even and >= 4.
+ """
+ super().__init__()
+
+ assert kernel_size >= 4 and not kernel_size % 2, (
+ f"Upsampling kernel size shall be even and ≥4. Found {kernel_size}"
)
- # 4D kernel to be compatible with transpose convolution
- upsampling_kernel = rearrange(upsampling_kernel, "k_h k_w -> 1 1 k_h k_w")
- self.weight = nn.Parameter(upsampling_kernel, requires_grad=True)
+ self.target_k_size = kernel_size
+ self.param_size = _Parameterization_Symmetric_1d.size_param_from_target(
+ self.target_k_size
+ )
+
+ # -------- Instantiate empty parameters, set by the initialize function
+ self.weight = nn.Parameter(
+ torch.empty(self.param_size), requires_grad=True
+ )
+
+ self.bias = nn.Parameter(torch.empty(1), requires_grad=True)
+ self.initialize_parameters()
+ # -------- Instantiate empty parameters, set by the initialize function
+
+ # Each time we call .weight, we'll call the forward of
+ # _Parameterization_Symmetric_1d to get a symmetric kernel.
+ parametrize.register_parametrization(
+ self,
+ "weight",
+ _Parameterization_Symmetric_1d(target_k_size=self.target_k_size),
+ # Unsafe because we change the data dimension, from N to 2N
+ unsafe=True,
+ )
+
+ def initialize_parameters(self) -> None:
+ """Initialize the parameters of a
+ ``UpsamplingSeparableSymmetricConvTranspose2d`` layer.
+
+ * Biases are always set to zero.
+
+ * Weights are initialized as a (possibly padded) bilinear filter when
+ ``target_k_size`` is 4 or 6, otherwise a bicubic filter is used.
+ """
+ # For a target kernel size of 4 or 6, we use a bilinear kernel as the
+ # initialization. For bigger kernels, a bicubic kernel is used. In both
+ # cases we just initialize the left half of the kernel since these
+ # filters are symmetrical.
+ if self.target_k_size < 8:
+ kernel_core = torch.tensor([1.0 / 4.0, 3.0 / 4.0])
+ else:
+ kernel_core = torch.tensor([0.0351562, 0.1054687, -0.2617187, -0.8789063])
+
+ # If target_k_size = 6, then param_size = 3 while kernel_core has 2 elements.
+ # Thus we need to add zero_pad = 1 zero to the left of the kernel.
+ zero_pad = self.param_size - kernel_core.size()[0]
+ w = torch.zeros_like(self.weight)
+ w[zero_pad:] = kernel_core
+ self.weight = nn.Parameter(w, requires_grad=True)
+
+ self.bias = nn.Parameter(torch.zeros_like(self.bias), requires_grad=True)
def forward(self, x: Tensor) -> Tensor:
"""Perform the spatial upsampling (with scale 2) of an input with a
- single channel.
+ single channel. Note that the upsampling filter is both symmetrical and
+ separable. The actual implementation of the forward depends on
+ ``self.training``.
+
+ If we're training, we use a non-separable implementation. That is, we
+ first compute the 2D kernel through an outer product and then use a
+ single 2D convolution. This is more stable.
+
+ If we're not training, we use two successive 1D convolutions.
Args:
x: Single channel input with shape :math:`(B, 1, H, W)`
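Reviewer note: the symmetric reparameterization introduced in this hunk can be checked without PyTorch. The sketch below re-implements the logic of `_Parameterization_Symmetric_1d` on plain Python lists. The function name `symmetrize` is hypothetical, and this illustrates the mirroring rule only; it is not the module itself.

```python
def size_param_from_target(target_k_size: int) -> int:
    """Number of free parameters for a symmetric kernel of the given size."""
    return (target_k_size + 1) // 2


def symmetrize(params: list[float], target_k_size: int) -> list[float]:
    """Mirror `params` to build a symmetric kernel with target_k_size taps."""
    assert len(params) == size_param_from_target(target_k_size)
    mirrored = params[::-1]
    # Odd target size: skip the first mirrored tap so the centre is not repeated.
    return params + mirrored[target_k_size % 2:]


# Odd size: Dirac initialization of the pre-concatenation filter.
print(symmetrize([0.0, 0.0, 1.0], 5))
# Even size: bilinear initialization of the upsampling filter.
print(symmetrize([0.25, 0.75], 4))
# Even size with zero padding on the left (target size 6).
print(symmetrize([0.0, 0.25, 0.75], 6))
```

With N parameters the resulting kernel has 2N taps for an even target size and 2N - 1 taps for an odd one. The outer product of the 4-tap bilinear vector with itself gives back the 4x4 `kernel_bilinear` table removed by this diff (for instance 0.25 * 0.25 = 0.0625 in the corners).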
@@ -139,23 +305,47 @@ def forward(self, x: Tensor) -> Tensor:
Returns:
Upsampled version of the input with shape :math:`(B, 1, 2H, 2W)`
"""
- upsampling_weight = (
- self.static_kernel if self.static_upsampling_kernel else self.weight
- )
- x_pad = F.pad(x, self.upsampling_padding, mode="replicate")
- y_conv = F.conv_transpose2d(x_pad, upsampling_weight, stride=2)
+ k = self.target_k_size # kernel size
+ P0 = k // 2 # could be 0 or k//2 as in legacy implementation
+ C = 2 * P0 - 1 + k // 2 # crop side border k - 1 + k//2 (k=4, C=5 k=8, C=11)
+
+ weight = self.weight.view(1, -1)
+
+ if self.training: # training using non-separable (more stable)
+ kernel_2d = (torch.kron(weight, weight.T).view((1, 1, k, k)))
+
+ x_pad = F.pad(x, (P0, P0, P0, P0), mode="replicate")
+ yc = F.conv_transpose2d(x_pad, kernel_2d, stride=2)
+
+ # crop to remove padding in convolution
+ H, W = yc.size()[-2:]
+ y = yc[
+ :,
+ :,
+ C : H - C,
+ C : W - C,
+ ]
+
+ else: # testing through separable (less complex)
+ # horizontal filtering
+ x_pad = F.pad(x, (P0, P0, 0, 0), mode="replicate")
+ yc = F.conv_transpose2d(x_pad, weight.view((1, 1, 1, k)), stride=(1, 2))
+ W = yc.size()[-1]
+ y = yc[
+ :,
+ :,
+ :,
+ C : W - C,
+ ]
- # crop to remove padding in convolution
- H, W = y_conv.size()[-2:]
- results = y_conv[
- :,
- :,
- self.upsampling_crop : H - self.upsampling_crop,
- self.upsampling_crop : W - self.upsampling_crop,
- ]
+ # vertical filtering
+ x_pad = F.pad(y, (0, 0, P0, P0), mode="replicate")
+ yc = F.conv_transpose2d(x_pad, weight.view((1, 1, k, 1)), stride=(2, 1))
+ H = yc.size()[-2]
+ y = yc[:, :, C : H - C, :]
- return results
+ return y
class Upsampling(nn.Module):
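Reviewer note: the padding and cropping constants in the new `forward` (`P0` and `C`) are chosen so that each spatial dimension is exactly doubled. The sketch below replays that size bookkeeping in one dimension. It assumes the usual output length of a transposed convolution without padding, `(L_in - 1) * stride + k`. The function name `upsampled_length` is hypothetical.

```python
def upsampled_length(h: int, k: int) -> int:
    """Length obtained after replicate padding, stride-2 transposed conv and crop."""
    assert k >= 4 and k % 2 == 0, "kernel size must be even and >= 4"
    p0 = k // 2                    # replicate padding on each side (P0 in the diff)
    crop = 2 * p0 - 1 + k // 2     # samples cropped on each side (C in the diff)
    padded = h + 2 * p0
    conv_out = (padded - 1) * 2 + k  # transposed conv, stride 2, no padding
    return conv_out - 2 * crop


for k in (4, 6, 8):
    print(k, [upsampled_length(h, k) for h in (1, 5, 16)])
```

Expanding the expression gives `2h + k - 2 * (k // 2)`, which equals `2h` for every even `k`. This is also why the kernel size of the upsampling filter has to be even.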
@@ -176,42 +366,103 @@ class Upsampling(nn.Module):
\hat{\mathbf{z}} \\in \\mathbb{R}^{C \\times H \\times W} \\text {
and } C = \\sum_i C_i.
- The upsampling relies on a single custom transpose convolution
- ``UpsamplingConvTranspose2d`` performing a 2x upsampling of a 1-channel
- input. This transpose convolution is called over and over to upsampling
- each channel of each resolution until they reach the required :math:`H
- \\times W` dimensions.
+ For a toy example with 3 latent grids (``--n_ft_per_res=1,1,1``), the
+ overall diagram of the upsampling is as follows.
+
+ .. code::
+
+           +---------+
+     y0 -> | TConv2d | -----+
+           +---------+      |
+                            v
+           +--------+    +-----+    +---------+
+     y1 -> | Conv2d | -> | cat | -> | TConv2d | -----+
+           +--------+    +-----+    +---------+      |
+                                                     v
+                                    +--------+    +-----+
+     y2 --------------------------> | Conv2d | -> | cat | -> dense
+                                    +--------+    +-----+
+
+ Where ``y0`` has the smallest resolution, ``y1`` has twice the resolution of
+ ``y0``, etc.
+
+ There are two different sets of filters:
+
+ * The TConvs filters actually perform the x2 upsampling. They are
+ referred to as upsampling filters. Implemented using
+ ``UpsamplingSeparableSymmetricConvTranspose2d``.
+
+ * The Convs filters pre-process the signal prior to concatenation. They
+ are referred to as pre-concatenation filters. Implemented using
+ ``UpsamplingSeparableSymmetricConv2d``.
+
+ Kernel sizes for the upsampling and pre-concatenation filters are modified
+ through the ``--ups_k_size`` and ``--ups_preconcat_k_size`` arguments.
+
+ Each upsampling filter and each pre-concatenation filter is different. They
+ are all separable and symmetrical.
- The kernel of the ``UpsamplingConvTranspose2d`` depending on the value
- of the flag ``static_upsampling_kernel``. In either case, the kernel
- initialization is based on well-known bilinear or bicubic kernel
- depending on the requested ``upsampling_kernel_size``:
+ Upsampling convolutions are initialized with a bilinear or bicubic kernel
+ depending on the requested ``ups_k_size``:
- * If ``upsampling_kernel_size >= 4 and upsampling_kernel_size < 8``, a
+ * If ``ups_k_size >= 4 and ups_k_size < 8``, a
bilinear kernel (with zero padding if necessary) is used an
initialization.
- * If ``upsampling_kernel_size >= 8``, a bicubic kernel (with zero padding if
+ * If ``ups_k_size >= 8``, a bicubic kernel (with zero padding if
necessary) is used an initialization.
- .. warning::
+ Pre-concatenation convolutions are initialized with a Dirac kernel.
- The ``upsampling_kernel_size`` must be at least 4 and a multiple of 2.
- """
+ .. warning::
+
+ * The ``ups_k_size`` must be at least 4 and a multiple of 2.
- def __init__(self, upsampling_kernel_size: int, static_upsampling_kernel: bool):
+ * The ``ups_preconcat_k_size`` must be odd.
+ """
+ def __init__(
+ self,
+ ups_k_size: int,
+ ups_preconcat_k_size: int,
+ n_ups_kernel: int,
+ n_ups_preconcat_kernel: int,
+ ):
"""
Args:
- upsampling_kernel_size: Upsampling kernel size. Should be bigger or
- equal to 4 and a multiple of two.
- static_upsampling_kernel: If true, don't learn the upsampling
- kernel.
+ ups_k_size: Upsampling (TransposedConv) kernel size. Should be
+ even and >= 4.
+ ups_preconcat_k_size: Pre-concatenation kernel size. Should be odd.
+ n_ups_kernel: Number of different upsampling kernels. Usually it is
+ set to the number of latents minus one (because the full
+ resolution latent is not upsampled). It can also be set to one to
+ share the same kernel across all variables.
+ n_ups_preconcat_kernel: Number of different pre-concatenation
+ filters. Usually it is set to the number of latents minus one
+ (because the smallest resolution is not filtered prior to
+ concatenation). It can also be set to one to share the same
+ kernel across all variables.
"""
super().__init__()
- self.conv_transpose2d = UpsamplingConvTranspose2d(
- upsampling_kernel_size, static_upsampling_kernel
+ # number of kernels for the lower and higher branches
+ self.n_ups_kernel = n_ups_kernel
+ self.n_ups_preconcat_kernel = n_ups_preconcat_kernel
+
+ # Upsampling kernels = transpose conv2d
+ self.conv_transpose2ds = nn.ModuleList(
+ [
+ UpsamplingSeparableSymmetricConvTranspose2d(ups_k_size)
+ for _ in range(n_ups_kernel)
+ ]
+ )
+
+ # Pre concatenation filters = conv2d
+ self.conv2ds = nn.ModuleList(
+ [
+ UpsamplingSeparableSymmetricConv2d(ups_preconcat_k_size)
+ for _ in range(self.n_ups_preconcat_kernel)
+ ]
)
def forward(self, decoder_side_latent: List[Tensor]) -> Tensor:
@@ -231,14 +482,19 @@ def forward(self, decoder_side_latent: List[Tensor]) -> Tensor:
# so that the same convolution is applied independently on the batch dimension.
latent_reversed = list(reversed(decoder_side_latent))
upsampled_latent = latent_reversed[0] # start from smallest
- for target_tensor in latent_reversed[1:]:
+
+ for idx, target_tensor in enumerate(latent_reversed[1:]):
# Our goal is to upsample to the same resolution as target_tensor
x = rearrange(upsampled_latent, "b c h w -> (b c) 1 h w")
- x = self.conv_transpose2d(x)
+ x = self.conv_transpose2ds[idx % self.n_ups_kernel](x)
+
x = rearrange(x, "(b c) 1 h w -> b c h w", b=upsampled_latent.shape[0])
# Crop to comply with higher resolution feature maps size before concatenation
x = x[:, :, : target_tensor.shape[-2], : target_tensor.shape[-1]]
- upsampled_latent = torch.cat((target_tensor, x), dim=1)
+
+ high_branch = self.conv2ds[idx % self.n_ups_preconcat_kernel](target_tensor)
+ upsampled_latent = torch.cat((high_branch, x), dim=1)
+
return upsampled_latent
def get_param(self) -> OrderedDict[str, Tensor]:
@@ -260,4 +516,7 @@ def set_param(self, param: OrderedDict[str, Tensor]):
def reinitialize_parameters(self) -> None:
"""Re-initialize **in place** the parameters of the upsampling."""
- self.conv_transpose2d.initialize_parameters()
+ for i in range(len(self.conv_transpose2ds)):
+ self.conv_transpose2ds[i].initialize_parameters()
+ for i in range(len(self.conv2ds)):
+ self.conv2ds[i].initialize_parameters()
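
Note on the kernel sharing in ``forward()``: the ``idx % self.n_ups_kernel`` and ``idx % self.n_ups_preconcat_kernel`` indexing is what allows either one kernel per upsampling step or a single kernel shared by all steps. The sketch below only mirrors that indexing in plain Python. ``kernel_schedule`` is a hypothetical helper written for illustration, not part of the code base, and it runs no convolution.

```python
from typing import List, Tuple


def kernel_schedule(
    n_latents: int, n_ups_kernel: int, n_ups_preconcat_kernel: int
) -> List[Tuple[int, int]]:
    """Return (upsampling kernel index, pre-concat kernel index) per step.

    Hypothetical helper: it reproduces the ``idx % n`` selection used in
    forward(), nothing more.
    """
    # One upsampling step per latent above the smallest resolution.
    n_steps = n_latents - 1
    return [
        (idx % n_ups_kernel, idx % n_ups_preconcat_kernel)
        for idx in range(n_steps)
    ]


# 7 latents with one dedicated kernel per step on both branches.
print(kernel_schedule(7, 6, 6))
# 7 latents with a single kernel shared across all steps.
print(kernel_schedule(7, 1, 1))
```

With ``n_ups_kernel = 1`` every step selects index 0, which matches the "share the same kernel across all variables" case described in the docstring.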
diff --git a/coolchic/enc/component/frame.py b/coolchic/enc/component/frame.py
index 9b69c3dd..da3a77f4 100644
--- a/coolchic/enc/component/frame.py
+++ b/coolchic/enc/component/frame.py
@@ -7,7 +7,7 @@
# Authors: see CONTRIBUTORS.md
-"""A frame encoder is composed of a CoolChicEncoder and a InterCodingModule."""
+"""A frame encoder is composed of a CoolChicEncoder and an InterCodingModule."""
import typing
from dataclasses import dataclass, field
@@ -24,16 +24,11 @@
POSSIBLE_QUANTIZER_TYPE,
)
from enc.component.intercoding import InterCodingModule
-from torch import Tensor, nn
-from enc.utils.codingstructure import (
- FRAME_DATA_TYPE,
- FRAME_TYPE,
- POSSIBLE_BITDEPTH,
- DictTensorYUV,
- convert_444_to_420,
-)
+from enc.io.format.data_type import FRAME_DATA_TYPE, POSSIBLE_BITDEPTH
+from enc.io.format.yuv import DictTensorYUV, convert_444_to_420, yuv_dict_clamp
+from enc.utils.codingstructure import FRAME_TYPE
from enc.utils.misc import POSSIBLE_DEVICE
-from enc.utils.yuv import yuv_dict_clamp
+from torch import Tensor, nn
@dataclass
@@ -274,15 +269,15 @@ def save(self) -> BytesIO:
}
if self.coolchic_encoder.full_precision_param is not None:
- data_to_save["coolchic_full_precision_param"] = self.coolchic_encoder.full_precision_param
+ data_to_save["coolchic_full_precision_param"] = (
+ self.coolchic_encoder.full_precision_param
+ )
torch.save(data_to_save, buffer)
- # for k, v in self.coolchic_encoder.get_param().items():
- # print(f"{k:>30}: {v.abs().sum().item()}")
-
return buffer
+
def load_frame_encoder(raw_bytes: BytesIO) -> FrameEncoder:
"""From already loaded raw bytes, load & return a CoolChicEncoder
@@ -295,7 +290,7 @@ def load_frame_encoder(raw_bytes: BytesIO) -> FrameEncoder:
"""
# Reset the stream position to the beginning of the BytesIO object & load it
raw_bytes.seek(0)
- loaded_data = torch.load(raw_bytes, map_location="cpu")
+ loaded_data = torch.load(raw_bytes, map_location="cpu", weights_only=False)
# Create a frame encoder from the stored parameters
frame_encoder = FrameEncoder(
@@ -311,9 +306,13 @@ def load_frame_encoder(raw_bytes: BytesIO) -> FrameEncoder:
# Check if coolchic_nn_expgol_cnt is present in loaded data for backward
# compatibility. Not meant to stay very long.
if "coolchic_nn_expgol_cnt" in loaded_data:
- frame_encoder.coolchic_encoder.nn_expgol_cnt = loaded_data["coolchic_nn_expgol_cnt"]
+ frame_encoder.coolchic_encoder.nn_expgol_cnt = loaded_data[
+ "coolchic_nn_expgol_cnt"
+ ]
if "coolchic_full_precision_param" in loaded_data:
- frame_encoder.coolchic_encoder.full_precision_param = loaded_data["coolchic_full_precision_param"]
+ frame_encoder.coolchic_encoder.full_precision_param = loaded_data[
+ "coolchic_full_precision_param"
+ ]
return frame_encoder
diff --git a/coolchic/enc/component/video.py b/coolchic/enc/component/video.py
index 7affdf96..6588ddc9 100644
--- a/coolchic/enc/component/video.py
+++ b/coolchic/enc/component/video.py
@@ -23,7 +23,7 @@
from enc.training.warmup import warmup
from enc.utils.codingstructure import CodingStructure, Frame, FrameData
from enc.utils.misc import POSSIBLE_DEVICE, TrainingExitCode, is_job_over, mem_info
-from enc.utils.yuv import load_frame_data_from_file
+from enc.io.io import load_frame_data_from_file
class VideoEncoder():
@@ -168,6 +168,8 @@ def encode(
+ "-" * 80
)
+ print("\n" + frame.data.to_string() + "\n")
+
# ----- Set the parameters for the frame
frame_encoder_manager = copy.deepcopy(
self.shared_frame_encoder_manager
@@ -226,6 +228,8 @@ def encode(
f_out.write(str(list_candidates[0].coolchic_encoder) + "\n\n")
f_out.write(list_candidates[0].coolchic_encoder.str_complexity() + "\n")
+ print(list_candidates[0].coolchic_encoder.pretty_string() + "\n\n")
+
# Use warm-up to find the best initialization among the list
# of candidates parameters.
frame_encoder = warmup(
@@ -502,7 +506,7 @@ def load_video_encoder(load_path: str) -> VideoEncoder:
"""
print(f"Loading a video encoder from {load_path}")
- raw_data = torch.load(load_path, map_location="cpu")
+ raw_data = torch.load(load_path, map_location="cpu", weights_only=False)
# Calling the VideoEncoder constructor automatically reload the
# original frames.
diff --git a/coolchic/enc/io/format/data_type.py b/coolchic/enc/io/format/data_type.py
new file mode 100644
index 00000000..32c36c3b
--- /dev/null
+++ b/coolchic/enc/io/format/data_type.py
@@ -0,0 +1,14 @@
+# Software Name: Cool-Chic
+# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange
+# SPDX-License-Identifier: BSD 3-Clause "New"
+#
+# This software is distributed under the BSD-3-Clause license.
+#
+# Authors: see CONTRIBUTORS.md
+
+
+from typing import Literal
+
+
+FRAME_DATA_TYPE = Literal["rgb", "yuv420", "yuv444"]
+POSSIBLE_BITDEPTH = Literal[8, 9, 10, 11, 12, 13, 14, 15, 16]
diff --git a/coolchic/enc/io/format/png.py b/coolchic/enc/io/format/png.py
new file mode 100644
index 00000000..d7332046
--- /dev/null
+++ b/coolchic/enc/io/format/png.py
@@ -0,0 +1,50 @@
+# Software Name: Cool-Chic
+# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange
+# SPDX-License-Identifier: BSD 3-Clause "New"
+#
+# This software is distributed under the BSD-3-Clause license.
+#
+# Authors: see CONTRIBUTORS.md
+
+
+import os
+from typing import Tuple
+
+from einops import rearrange
+import torch
+from enc.io.format.data_type import POSSIBLE_BITDEPTH
+from PIL import Image
+from torch import Tensor
+from torchvision.transforms.functional import to_pil_image, to_tensor
+
+
+def read_png(file_path: str) -> Tuple[Tensor, POSSIBLE_BITDEPTH]:
+ """Read a PNG file
+
+ Args:
+ file_path: Path of the png file to read.
+
+ Returns:
+ Image data [1, 3, H, W] in [0., 1.] and its bitdepth.
+ """
+ assert os.path.isfile(file_path), f"No file found at {file_path}."
+
+ data = to_tensor(Image.open(file_path))
+ data = rearrange(data, "c h w -> 1 c h w")
+
+ # Bitdepth is always 8 when we read PNG through PIL?
+ bitdepth = 8
+
+ return data, bitdepth
+
+
+@torch.no_grad()
+def write_png(data: Tensor, file_path: str) -> None:
+ """Save an image into a PNG file.
+
+ Args:
+ data: Image to be saved, shape [1, 3, H, W]
+ file_path: Where to save the PNG files
+ """
+ data = rearrange(data, "1 c h w -> c h w", c=3)
+ to_pil_image(data).save(file_path)
diff --git a/coolchic/enc/io/format/ppm.py b/coolchic/enc/io/format/ppm.py
new file mode 100644
index 00000000..88e236e7
--- /dev/null
+++ b/coolchic/enc/io/format/ppm.py
@@ -0,0 +1,205 @@
+# Software Name: Cool-Chic
+# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange
+# SPDX-License-Identifier: BSD 3-Clause "New"
+#
+# This software is distributed under the BSD-3-Clause license.
+#
+# Authors: see CONTRIBUTORS.md
+
+
+import math
+import os
+from typing import Tuple
+
+import numpy as np
+import torch
+from torch import Tensor
+
+from enc.io.format.data_type import POSSIBLE_BITDEPTH
+
+
+def _skip_one_byte(data: bytearray) -> bytearray:
+ """Skip one byte in a byte array and return the array
+ with its first byte removed.
+
+ Args:
+ data: Input byte array. Length is N
+
+ Returns:
+ bytearray: Output byte array. Length is N - 1
+ """
+ return data[1:]
+
+
+def _read_int_until_blank(data: bytearray) -> Tuple[int, bytearray]:
+ """Parse an ASCII int until running into one of the space characters.
+ Also return the input byte array where the bytes corresponding to
+ the int are skipped. The space character itself is left in place.
+
+ 123\ncds# --> Return the value 123 and a byte array containing \ncds#
+
+
+ As defined by the ANSI standard C isspace(s) returns True for
+ Horizontal tab (HT), line feed (LF), vertical tabulation (VT),
+ form feed (FF), carriage return (CR) and white space.
+
+ Note: this function may fail if the bytes collected up to the space
+ character do not represent ascii numbers.
+
+ Args:
+ data: Input data.
+
+ Returns:
+ Parsed int + input data where we've removed the bytes corresponding to
+ the int value. The space character is still the first byte.
+ """
+ # As defined by the ANSI standard C isspace(s)
+ # ASCII code for Horizontal tab (HT), line feed (LF),
+ # vertical tabulation (VT), form feed (FF), carriage return (CR)
+ # and white space
+ _BLANKS_ASCII = [9, 10, 11, 12, 13, 32]
+
+ ptr_end = 0
+ while data[ptr_end] not in _BLANKS_ASCII:
+ ptr_end += 1
+
+ value = int(data[:ptr_end].decode("utf-8"))
+ data = data[ptr_end:]
+ return value, data
+
+
+def _16bits_byte_swap(data: Tensor) -> Tensor:
+ """Invert the bytes composing a 2-byte value. The actual data type
+ of the tensor is not important but it must contain values in
+ [0, 2 ** 16 - 1]
+
+ For instance:
+
+ 1111 1111 0000 0010 ==> 0000 0010 1111 1111
+ \______/ \______/ \______/ \______/
+ MSB LSB LSB MSB
+
+ Args:
+ data: Tensor to be swapped.
+
+ Returns:
+ Swapped tensor.
+ """
+
+ msb = data // 2**8
+ lsb = data % 2**8
+ swapped_data = lsb * 2**8 + msb
+ return swapped_data
+
+
+def read_ppm(file_path: str) -> Tuple[Tensor, POSSIBLE_BITDEPTH]:
+ """Read a PPM file,
+ and return a torch tensor [1, 3, H, W] containing the data.
+ The returned tensor is in [0., 1.] so its bitdepth is also returned.
+
+ .. attention::
+
+ We don't filter out comments inside PPM files...
+
+ Args:
+ file_path: Path of the ppm file to read.
+
+ Returns:
+ Image data [1, 3, H, W] in [0., 1.] and its bitdepth.
+ """
+ assert os.path.isfile(file_path), f"No file found at {file_path}."
+
+ data = open(file_path, "rb").read()
+ magic_number = data[:2].decode("utf-8")
+ data = data[2:]
+ assert magic_number == "P6", (
+ "Invalid file format. PPM file should start with P6. " f"Found {magic_number}."
+ )
+
+ # Parse the header
+ width, data = _read_int_until_blank(_skip_one_byte(data))
+ height, data = _read_int_until_blank(_skip_one_byte(data))
+ max_val, data = _read_int_until_blank(_skip_one_byte(data))
+ data = _skip_one_byte(data)
+
+ n_bytes_per_val = 1 if max_val <= 255 else 2
+ bitdepth = int(math.log2(max_val + 1))
+
+ raw_value = torch.from_numpy(
+ np.frombuffer(
+ data,
+ count=3 * width * height,
+ dtype=np.uint8 if n_bytes_per_val == 1 else np.uint16,
+ )
+ )
+
+ # Re-arrange the values from R1 G1 B1 R2 G2 B2 to a usual [B, C, H, W] array
+ img = torch.empty(
+ (1, 3, height, width),
+ dtype=torch.float32,
+ )
+
+ for i in range(3):
+ img[:, i, :, :] = raw_value[i::3].view(1, height, width)
+
+ # In a PPM file 2-byte value (e.g. 257) is represented as
+ # 1111 1111 0000 0010
+ # \______/ \______/
+ # MSB LSB
+ # We want to invert these two bytes here to have a usual binary value
+ # 0000 0010 1111 1111
+ if n_bytes_per_val == 2:
+ img = _16bits_byte_swap(img)
+
+ # Normalize in [0. 1.]
+ img = img / (2**bitdepth - 1)
+
+ return img, bitdepth
+
+
+@torch.no_grad()
+def write_ppm(
+ data: Tensor, bitdepth: POSSIBLE_BITDEPTH, file_path: str, norm: bool = True
+) -> None:
+ """Save an image x into a PPM file.
+
+ Args:
+ data: Image to be saved
+ bitdepth: Bitdepth, should be in
+ ``[8, 9, 10, 11, 12, 13, 14, 15, 16]``.
+ file_path: Where to save the PPM file
+ norm: True to multiply the data by 2 ** bitdepth - 1. Defaults to True.
+ """
+ # Remove all first dimensions of size 1
+ c, h, w = data.size()[-3:]
+ data = data.view((c, h, w))
+
+ max_val = 2**bitdepth - 1
+ n_bytes_per_val = 1 if max_val <= 255 else 2
+ header = f"P6\n{w} {h}\n{max_val}\n"
+
+ if norm:
+ data = torch.round(data * (2**bitdepth - 1))
+
+ # In a PPM file 2-byte value (e.g. 257) is represented as
+ # 1111 1111 0000 0010
+ # \______/ \______/
+ # MSB LSB
+ # We want to invert these two bytes here to have a usual binary value
+ # 0000 0010 1111 1111
+ if n_bytes_per_val == 2:
+ data = _16bits_byte_swap(data)
+
+ # Format data as expected by the PPM file.
+ flat_data = torch.empty(
+ (c * h * w), dtype=torch.uint8 if max_val <= 255 else torch.uint16
+ )
+ for i in range(c):
+ flat_data[i::3] = data[i, :, :].flatten()
+
+ # Write once the header as a string then the data as binary bytes
+ with open(file_path, "w") as f_out:
+ f_out.write(header)
+ with open(file_path, "ab") as f_out:
+ f_out.write(np.memmap.tobytes(flat_data.numpy()))
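
Note on the PPM header handling: ``read_ppm()`` alternates ``_skip_one_byte()`` and ``_read_int_until_blank()`` to walk through the ``P6 <width> <height> <max_val>`` header, then swaps the two bytes of each 16-bit sample. The sketch below is a torch-free illustration of the same two steps. ``parse_ppm_header`` and ``swap16`` are hypothetical names introduced here, and, like the code in the diff, the sketch assumes a single blank between fields and no comments in the header.

```python
from typing import Tuple

# Same blank characters as _BLANKS_ASCII in the diff (C isspace()).
BLANKS = (9, 10, 11, 12, 13, 32)


def read_int_until_blank(data: bytes, start: int) -> Tuple[int, int]:
    """Parse an ASCII int at ``start``; return (value, index of the blank)."""
    end = start
    while data[end] not in BLANKS:
        end += 1
    return int(data[start:end].decode("ascii")), end


def parse_ppm_header(data: bytes) -> Tuple[int, int, int, int]:
    """Return (width, height, max_val, offset of the first sample byte)."""
    assert data[:2] == b"P6", "Only binary PPM (P6) is handled here."
    pos = 2
    values = []
    for _ in range(3):
        # Skip the single blank preceding each field, then parse the field.
        value, pos = read_int_until_blank(data, pos + 1)
        values.append(value)
    width, height, max_val = values
    # Skip the blank that follows max_val to reach the sample data.
    return width, height, max_val, pos + 1


def swap16(value: int) -> int:
    """Swap the two bytes of a value in [0, 2 ** 16 - 1]."""
    return (value % 2**8) * 2**8 + value // 2**8


header = b"P6\n4 2\n255\n" + bytes(3 * 4 * 2)
print(parse_ppm_header(header))
print(hex(swap16(0xFF02)))
```

Applying ``swap16`` twice gives back the original value, which is why the same helper can serve both the reading and the writing path.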
diff --git a/coolchic/enc/io/format/yuv.py b/coolchic/enc/io/format/yuv.py
new file mode 100644
index 00000000..15d17ad7
--- /dev/null
+++ b/coolchic/enc/io/format/yuv.py
@@ -0,0 +1,311 @@
+# Software Name: Cool-Chic
+# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange
+# SPDX-License-Identifier: BSD 3-Clause "New"
+#
+# This software is distributed under the BSD-3-Clause license.
+#
+# Authors: see CONTRIBUTORS.md
+
+
+import os
+from typing import TypedDict, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from enc.io.format.data_type import FRAME_DATA_TYPE, POSSIBLE_BITDEPTH
+from enc.utils.misc import POSSIBLE_DEVICE
+from torch import Tensor
+
+
+class DictTensorYUV(TypedDict):
+ """``TypedDict`` representing a YUV420 frame.
+
+ .. hint::
+
+ ``torch.jit`` requires I/O of modules to be either ``Tensor``, ``List``
+ or ``Dict``. So we don't use a python dataclass here and rely on
+ ``TypedDict`` instead.
+
+ Args:
+ y (Tensor): Tensor with shape :math:`([B, 1, H, W])`.
+ u (Tensor): Tensor with shape :math:`([B, 1, \\frac{H}{2}, \\frac{W}{2}])`.
+ v (Tensor): Tensor with shape :math:`([B, 1, \\frac{H}{2}, \\frac{W}{2}])`.
+ """
+
+ y: Tensor
+ u: Tensor
+ v: Tensor
+
+
+def read_yuv(
+ file_path: str,
+ frame_idx: int,
+ frame_data_type: FRAME_DATA_TYPE,
+ bit_depth: POSSIBLE_BITDEPTH,
+) -> Union[DictTensorYUV, Tensor]:
+ """From a file_path /a/b/c.yuv, read the desired frame_index
+ and return a dictionary of tensor containing the YUV values:
+
+ .. code:: none
+
+ {
+ 'y': [1, 1, H, W],
+ 'u': [1, 1, H / S, W / S],
+ 'v': [1, 1, H / S, W / S],
+ }
+
+ ``S`` is either 1 (444 sampling) or 2 (420). The YUV values are in [0., 1.]
+
+ Args:
+ file_path: Absolute path of the video to load
+ frame_idx: Index of the frame to load, starting at 0.
+ frame_data_type: chroma sampling (420,444)
+ bit_depth: Number of bits per component (8 or 10 bits).
+
+ Returns:
+ For 420, return a dict of tensors with YUV values of shape [1, 1, H, W].
+ For 444 return a [1, 3, H, W] tensor.
+ """
+
+ # Parse height and width from the file_path
+ w, h = [
+ int(tmp_str)
+ for tmp_str in os.path.basename(file_path).split(".")[0].split("_")[1].split("x")
+ ]
+
+ if frame_data_type == "yuv420":
+ w_uv, h_uv = [int(x / 2) for x in [w, h]]
+ else:
+ w_uv, h_uv = w, h
+
+ # Switch between 8 bit file and 10 bit
+ byte_per_value = 1 if bit_depth == 8 else 2
+
+ # We only handle YUV420 for now
+ n_val_y = h * w
+ n_val_uv = h_uv * w_uv
+ n_val_per_frame = n_val_y + 2 * n_val_uv
+
+ n_bytes_y = n_val_y * byte_per_value
+ n_bytes_uv = n_val_uv * byte_per_value
+ n_bytes_per_frame = n_bytes_y + 2 * n_bytes_uv
+
+ # Read the required frame and put it in a 1d tensor
+ raw_video = torch.tensor(
+ np.memmap(
+ file_path,
+ mode="r",
+ shape=n_val_per_frame,
+ offset=n_bytes_per_frame * frame_idx,
+ dtype=np.uint8 if bit_depth == 8 else np.uint16,
+ ).astype(np.float32)
+ )
+
+ # Read the different values from raw video and store them inside y, u and v
+ ptr = 0
+ y = raw_video[ptr : ptr + n_val_y].view(1, 1, h, w)
+ ptr += n_val_y
+ u = raw_video[ptr : ptr + n_val_uv].view(1, 1, h_uv, w_uv)
+ ptr += n_val_uv
+ v = raw_video[ptr : ptr + n_val_uv].view(1, 1, h_uv, w_uv)
+
+ # PyTorch expects data in [0., 1.]; normalize by 2 ** bit_depth - 1
+ norm_factor = 2**bit_depth - 1
+
+ if frame_data_type == "yuv420":
+ video = DictTensorYUV(y=y / norm_factor, u=u / norm_factor, v=v / norm_factor)
+ else:
+ video = torch.cat([y, u, v], dim=1) / norm_factor
+
+ return video
+
+
+@torch.no_grad()
+def write_yuv(
+ data: Union[Tensor, DictTensorYUV],
+ bitdepth: POSSIBLE_BITDEPTH,
+ frame_data_type: FRAME_DATA_TYPE,
+ file_path: str,
+ norm: bool = True,
+) -> None:
+ """Store a YUV frame as a YUV file named file_path. If norm is True, the
+ video data is expected to be in [0., 1.] so we multiply it by
+ 2 ** bitdepth - 1. Otherwise we leave it as is.
+
+ Args:
+ data: Data to save
+ bitdepth: Bitdepth, should be in ``[8, 9, 10, 11, 12, 13, 14, 15, 16]``.
+ frame_data_type: Data type, either ``"yuv420"`` or ``"yuv444"``.
+ file_path: Absolute path of the file where the YUV is saved.
+ norm: True to multiply the data by 2 ** bitdepth - 1.
+ Defaults to True.
+ """
+ assert frame_data_type in ["yuv420", "yuv444"], (
+ f"Found incorrect datatype in write_yuv() function: {frame_data_type}. "
+ 'Data type should be "yuv420" or "yuv444".'
+ )
+
+ if frame_data_type == "yuv420":
+ raw_data = torch.cat([channels.flatten() for _, channels in data.items()])
+ else:
+ raw_data = data.flatten()
+
+ if norm:
+ raw_data = raw_data * (2**bitdepth - 1)
+
+ dtype = np.uint8 if bitdepth == 8 else np.uint16
+
+ # Round the values and cast them to uint8 or uint16 tensor
+ raw_data = torch.round(raw_data).cpu().numpy().astype(dtype)
+
+ # # We need to add a p with yuv444 otherwise YUView thinks it's "YUV444 8-bit packed"
+ # file_path = (
+ # f"{file_path}_{w}x{h}_{DUMMY_FRAMERATE}fps_{frame_data_type}p_{bitdepth}b.yuv"
+ # )
+
+ # Write this to the desired file_path
+ np.memmap.tofile(raw_data, file_path)
+
+
+def rgb2yuv(rgb: Tensor) -> Tensor:
+ """Convert a 4D RGB tensor [1, 3, H, W] into a 4D YUV444 tensor [1, 3, H, W].
+ The RGB and YUV values are in the range [0, 255]
+
+ Args:
+ rgb: 4D RGB tensor to convert in [0. 255.]
+
+ Returns:
+ The resulting YUV444 tensor in [0. 255.]
+ """
+ assert (
+ len(rgb.size()) == 4
+ ), f"rgb2yuv input must be a 4D tensor [B, 3, H, W]. Data size: {rgb.size()}"
+ assert (
+ rgb.size()[1] == 3
+ ), f"rgb2yuv input must have 3 channels. Data size: {rgb.size()}"
+
+ # Split the [1, 3, H, W] into 3 [1, 1, H, W] tensors
+ r, g, b = rgb.split(1, dim=1)
+
+ # Compute the different channels
+ y = torch.round(0.299 * r + 0.587 * g + 0.114 * b)
+ u = torch.round(-0.1687 * r - 0.3313 * g + 0.5 * b + 128)
+ v = torch.round(0.5 * r - 0.4187 * g - 0.0813 * b + 128)
+
+ # Concatenate them into the resulting yuv 4D tensor.
+ yuv = torch.cat((y, u, v), dim=1)
+ return yuv
+
+
+def yuv2rgb(yuv: Tensor):
+ """Convert a 4D YUV tensor [1, 3, H, W] into a 4D RGB tensor [1, 3, H, W].
+ The RGB and YUV values are in the range [0, 255]
+
+ Args:
+ yuv: 4D YUV444 tensor to convert in [0. 255.]
+
+ Returns:
+ The resulting RGB tensor in [0. 255.]
+ """
+ assert (
+ len(yuv.size()) == 4
+ ), f"yuv2rgb input must be a 4D tensor [B, 3, H, W]. Data size: {yuv.size()}"
+ assert (
+ yuv.size()[1] == 3
+ ), f"yuv2rgb input must have 3 channels. Data size: {yuv.size()}"
+
+ y, u, v = yuv.split(1, dim=1)
+ r = (
+ 1.0 * y
+ + -0.000007154783816076815 * u
+ + 1.4019975662231445 * v
+ - 179.45477266423404
+ )
+ g = 1.0 * y + -0.3441331386566162 * u + -0.7141380310058594 * v + 135.45870971679688
+ b = (
+ 1.0 * y
+ + 1.7720025777816772 * u
+ + 0.00001542569043522235 * v
+ - 226.8183044444304
+ )
+ rgb = torch.cat((r, g, b), dim=1)
+ return rgb
+
+
+def yuv_dict_clamp(yuv: DictTensorYUV, min_val: float, max_val: float) -> DictTensorYUV:
+ """Clamp the y, u & v tensor.
+
+ Args:
+ yuv: The data to clamp
+ min_val: Minimum value for the clamp
+ max_val: Maximum value for the clamp
+
+ Returns:
+ The clamped data
+
+ """
+ clamped_yuv = DictTensorYUV(
+ y=yuv.get("y").clamp(min_val, max_val),
+ u=yuv.get("u").clamp(min_val, max_val),
+ v=yuv.get("v").clamp(min_val, max_val),
+ )
+ return clamped_yuv
+
+
+def yuv_dict_to_device(yuv: DictTensorYUV, device: POSSIBLE_DEVICE) -> DictTensorYUV:
+ """Send a ``DictTensor`` to a device.
+
+ Args:
+ yuv: Data to be sent to a device.
+ device: The requested device
+
+ Returns:
+ Data on the appropriate device.
+ """
+ return DictTensorYUV(
+ y=yuv.get("y").to(device), u=yuv.get("u").to(device), v=yuv.get("v").to(device)
+ )
+
+
+def convert_444_to_420(yuv444: Tensor) -> DictTensorYUV:
+ """From a 4D YUV 444 tensor :math:`(B, 3, H, W)`, return a
+ ``DictTensorYUV``. The U and V tensors are down sampled using a nearest
+ neighbor downsampling.
+
+ Args:
+ yuv444: YUV444 data :math:`(B, 3, H, W)`
+
+ Returns:
+ YUV420 dictionary of 4D tensors
+ """
+ assert yuv444.dim() == 4, f"Number of dimension should be 4, found {yuv444.dim()}"
+
+ b, c, h, w = yuv444.size()
+ assert c == 3, f"Number of channel should be 3, found {c}"
+
+ # No need to downsample y channel but it should remain a 4D tensor
+ y = yuv444[:, 0, :, :].view(b, 1, h, w)
+
+ # Downsample U and V channels together
+ uv = F.interpolate(yuv444[:, 1:3, :, :], scale_factor=(0.5, 0.5), mode="nearest")
+ u, v = uv.split(1, dim=1)
+
+ yuv420 = DictTensorYUV(y=y, u=u, v=v)
+ return yuv420
+
+
+def convert_420_to_444(yuv420: DictTensorYUV) -> Tensor:
+ """Convert a DictTensorYUV to a 4D tensor :math:`(B, 3, H, W)`.
+ The U and V tensors are up sampled using a nearest neighbor upsampling.
+
+ Args:
+ yuv420: YUV420 dictionary of 4D tensor
+
+ Returns:
+ YUV444 Tensor :math:`(B, 3, H, W)`
+ """
+ u = F.interpolate(yuv420.get("u"), scale_factor=(2, 2))
+ v = F.interpolate(yuv420.get("v"), scale_factor=(2, 2))
+ yuv444 = torch.cat((yuv420.get("y"), u, v), dim=1)
+ return yuv444
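
Note on the raw YUV layout assumed by ``read_yuv()``: each frame stores the full Y plane followed by the U and V planes, and the requested frame is reached through a byte offset of ``n_bytes_per_frame * frame_idx``. The sketch below reproduces that arithmetic in plain Python. ``yuv_frame_layout`` is a hypothetical helper introduced for illustration, and it assumes planar storage with two bytes per sample for any bit depth above 8.

```python
from typing import Dict


def yuv_frame_layout(
    width: int, height: int, frame_data_type: str, bit_depth: int, frame_idx: int
) -> Dict[str, int]:
    """Sizes and byte offset of one frame inside a planar .yuv file."""
    # 4:2:0 halves both chroma dimensions; 4:4:4 keeps them at full size.
    if frame_data_type == "yuv420":
        w_uv, h_uv = width // 2, height // 2
    else:
        w_uv, h_uv = width, height

    bytes_per_value = 1 if bit_depth == 8 else 2
    n_val_y = width * height
    n_val_uv = w_uv * h_uv
    n_bytes_per_frame = (n_val_y + 2 * n_val_uv) * bytes_per_value

    return {
        "n_val_y": n_val_y,
        "n_val_uv": n_val_uv,
        "n_bytes_per_frame": n_bytes_per_frame,
        "offset": n_bytes_per_frame * frame_idx,
    }


# Third frame (index 2) of a 1920x1080 8-bit 4:2:0 sequence.
print(yuv_frame_layout(1920, 1080, "yuv420", 8, 2))
```

For 4:2:0 content a frame therefore holds 1.5 samples per pixel, against 3 samples per pixel for 4:4:4.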
diff --git a/coolchic/enc/io/io.py b/coolchic/enc/io/io.py
new file mode 100644
index 00000000..c46ff0b0
--- /dev/null
+++ b/coolchic/enc/io/io.py
@@ -0,0 +1,39 @@
+from enc.utils.codingstructure import FrameData
+from enc.io.format.data_type import FRAME_DATA_TYPE, POSSIBLE_BITDEPTH
+from enc.io.format.ppm import read_ppm
+from enc.io.format.yuv import read_yuv
+from enc.io.format.png import read_png
+
+
+def load_frame_data_from_file(file_path: str, idx_display_order: int) -> FrameData:
+ """Load the idx_display_order-th frame from a .yuv, .png or .ppm file. For
+ the latter two, idx_display_order must be equal to 0 as there is only one
+ frame in such files.
+
+ Args:
+ file_path (str): Absolute path of the file from which the frame is loaded.
+ idx_display_order (int): Index of the frame in display order
+
+ Returns:
+ FrameData: The loaded frame, wrapped as a FrameData object.
+ """
+ POSSIBLE_EXT = [".yuv", ".png", ".ppm"]
+ assert file_path[-4:] in POSSIBLE_EXT, (
+ "The function load_frame_data_from_file() expects a file ending with "
+ f"{POSSIBLE_EXT}. Found {file_path}"
+ )
+
+ if file_path.endswith(".yuv"):
+ # ! We only consider yuv420 and 444 planar
+ bitdepth: POSSIBLE_BITDEPTH = 8 if "_8b" in file_path else 10
+ frame_data_type: FRAME_DATA_TYPE = "yuv420" if "420" in file_path else "yuv444"
+ data = read_yuv(file_path, idx_display_order, frame_data_type, bitdepth)
+
+ elif file_path.endswith(".png"):
+ frame_data_type: FRAME_DATA_TYPE = "rgb"
+ data, bitdepth = read_png(file_path)
+
+ elif file_path.endswith(".ppm"):
+ frame_data_type: FRAME_DATA_TYPE = "rgb"
+ data, bitdepth = read_ppm(file_path)
+
+ return FrameData(bitdepth, frame_data_type, data)
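
Note on the file-name conventions used by ``load_frame_data_from_file()``: for ``.yuv`` inputs the bit depth and chroma format are inferred from substrings of the path (``_8b`` and ``420``), whereas PNG and PPM inputs are always RGB and carry their own bit depth. The sketch below isolates that dispatch without reading any file. ``describe_input`` and the example file names are hypothetical.

```python
from typing import Optional, Tuple


def describe_input(file_path: str) -> Tuple[str, Optional[int]]:
    """Return (frame_data_type, bitdepth) inferred from the file name only.

    The bitdepth is None for PNG and PPM since it is read from the file.
    """
    if file_path.endswith(".yuv"):
        # Same substring rules as in the diff: anything without "_8b" is
        # treated as 10-bit, anything without "420" as yuv444.
        bitdepth = 8 if "_8b" in file_path else 10
        frame_data_type = "yuv420" if "420" in file_path else "yuv444"
        return frame_data_type, bitdepth
    if file_path.endswith((".png", ".ppm")):
        return "rgb", None
    raise ValueError(f"Unsupported extension: {file_path}")


print(describe_input("clip_1920x1080_30fps_yuv420p_8b.yuv"))
print(describe_input("photo.png"))
```

Because these are substring tests, a 12-bit ``.yuv`` file would be treated as 10-bit here, and a ``420`` appearing anywhere in the path selects the 4:2:0 branch.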
diff --git a/coolchic/enc/utils/presets.py b/coolchic/enc/training/presets.py
similarity index 99%
rename from coolchic/enc/utils/presets.py
rename to coolchic/enc/training/presets.py
index 653f3ea5..8e818eb5 100644
--- a/coolchic/enc/utils/presets.py
+++ b/coolchic/enc/training/presets.py
@@ -398,3 +398,4 @@ def __init__(self, start_lr: float = 1e-2, n_itr_per_phase: int = 100000):
"c3x": PresetC3x,
"debug": PresetDebug,
}
+
diff --git a/coolchic/enc/training/quantizemodel.py b/coolchic/enc/training/quantizemodel.py
index a009cc99..6aa58a9e 100644
--- a/coolchic/enc/training/quantizemodel.py
+++ b/coolchic/enc/training/quantizemodel.py
@@ -51,9 +51,9 @@ def _quantize_parameters(
sent_param = torch.round(v / current_q_step)
if sent_param.abs().max() > MAX_AC_MAX_VAL:
- print(
- f"Sent param {k} exceed MAX_AC_MAX_VAL! Q step {current_q_step} too small."
- )
+ #print(
+ # f"Sent param {k} exceed MAX_AC_MAX_VAL! Q step {current_q_step} too small."
+ #)
return None
q_param[k] = sent_param * current_q_step
@@ -187,7 +187,7 @@ def quantize_model(
# to obtain the sent latent.
current_sent_param = (parameter_value / current_q_step.get(weight_or_bias)).view(-1)
- if parameter_name.endswith(weight_or_bias):
+ if weight_or_bias in parameter_name:
sent_param.append(current_sent_param)
# Integer, sent parameters
diff --git a/coolchic/enc/training/train.py b/coolchic/enc/training/train.py
index ce1b0345..5c06b6d7 100644
--- a/coolchic/enc/training/train.py
+++ b/coolchic/enc/training/train.py
@@ -23,7 +23,7 @@
from enc.training.loss import loss_function
from enc.training.test import test
from enc.utils.codingstructure import Frame
-from enc.utils.presets import MODULE_TO_OPTIMIZE
+from enc.training.presets import MODULE_TO_OPTIMIZE
# Custom scheduling function for the soft rounding temperature and the noise parameter
@@ -306,7 +306,7 @@ def train(
"patience": (patience - cnt + cnt_record) // frequency_validation,
"q_type": f"{quantizer_type:12s}",
"sr_temp": f"{cur_softround_temperature:.5f}",
- "n_type": f"{quantizer_noise_type:20s}",
+ "n_type": f"{quantizer_noise_type:12s}",
"noise": f"{cur_noise_parameter:.2f}",
"record": log_new_record,
}
diff --git a/coolchic/enc/utils/codingstructure.py b/coolchic/enc/utils/codingstructure.py
index 81d983a6..d649d923 100644
--- a/coolchic/enc/utils/codingstructure.py
+++ b/coolchic/enc/utils/codingstructure.py
@@ -10,14 +10,16 @@
import math
from dataclasses import dataclass, field
-from typing import List, Literal, Optional, Tuple, TypedDict, Union
+from typing import List, Literal, Optional, Tuple, Union
-import torch
-import torch.nn.functional as F
from torch import Tensor
from enc.utils.misc import POSSIBLE_DEVICE
+from enc.io.format.data_type import FRAME_DATA_TYPE, POSSIBLE_BITDEPTH
+from enc.io.format.yuv import DictTensorYUV, convert_420_to_444, yuv_dict_to_device
+
+
# The different frame types:
# - I frames have no reference (intra)
# - P frames have 1 single (past) reference
@@ -44,90 +46,6 @@
# between two I-frames. First GOP P-period is 1, while second P period is 4
# which is the distance of the P-frame prediction.
#
-FRAME_DATA_TYPE = Literal["rgb", "yuv420", "yuv444"]
-POSSIBLE_BITDEPTH = Literal[8, 10]
-
-
-class DictTensorYUV(TypedDict):
- """``TypedDict`` representing a YUV420 frame..
-
- .. hint::
-
- ``torch.jit`` requires I/O of modules to be either ``Tensor``, ``List``
- or ``Dict``. So we don't use a python dataclass here and rely on
- ``TypedDict`` instead.
-
- Args:
- y (Tensor): :math:`([B, 1, H, W])`.
- u (Tensor): :math:`([B, 1, \\frac{H}{2}, \\frac{W}{2}])`.
- v (Tensor): :math:`([B, 1, \\frac{H}{2}, \\frac{W}{2}])`.
- """
-
- y: Tensor
- u: Tensor
- v: Tensor
-
-
-def yuv_dict_to_device(yuv: DictTensorYUV, device: POSSIBLE_DEVICE) -> DictTensorYUV:
- """Send a ``DictTensor`` to a device.
-
- Args:
- yuv: Data to be sent to a device.
- device: The requested device
-
- Returns:
- Data on the appropriate device.
- """
- return DictTensorYUV(
- y=yuv.get("y").to(device), u=yuv.get("u").to(device), v=yuv.get("v").to(device)
- )
-
-
-# ============================== YUV upsampling ============================= #
-def convert_444_to_420(yuv444: Tensor) -> DictTensorYUV:
- """From a 4D YUV 444 tensor :math:`(B, 3, H, W)`, return a
- ``DictTensorYUV``. The U and V tensors are down sampled using a nearest
- neighbor downsampling.
-
- Args:
- yuv444: YUV444 data :math:`(B, 3, H, W)`
-
- Returns:
- YUV420 dictionary of 4D tensors
- """
- assert yuv444.dim() == 4, f"Number of dimension should be 5, found {yuv444.dim()}"
-
- b, c, h, w = yuv444.size()
- assert c == 3, f"Number of channel should be 3, found {c}"
-
- # No need to downsample y channel but it should remain a 5D tensor
- y = yuv444[:, 0, :, :].view(b, 1, h, w)
-
- # Downsample U and V channels together
- uv = F.interpolate(yuv444[:, 1:3, :, :], scale_factor=(0.5, 0.5), mode="nearest")
- u, v = uv.split(1, dim=1)
-
- yuv420 = DictTensorYUV(y=y, u=u, v=v)
- return yuv420
-
-
-def convert_420_to_444(yuv420: DictTensorYUV) -> Tensor:
- """Convert a DictTensorYUV to a 4D tensor:math:`(B, 3, H, W)`.
- The U and V tensors are up sampled using a nearest neighbor upsampling
-
- Args:
- yuv420: YUV420 dictionary of 4D tensor
-
- Returns:
- YUV444 Tensor :math:`(B, 3, H, W)`
- """
- u = F.interpolate(yuv420.get("u"), scale_factor=(2, 2))
- v = F.interpolate(yuv420.get("v"), scale_factor=(2, 2))
- yuv444 = torch.cat((yuv420.get("y"), u, v), dim=1)
- return yuv444
-
-
-# ============================== YUV upsampling ============================= #
@dataclass
@@ -136,7 +54,8 @@ class FrameData:
a few additional information about its size, bitdepth of color space.
Args:
- bitdepth (POSSIBLE_BITDEPTH): Bitdepth, either ``"8"`` or ``"10"``.
+ bitdepth (POSSIBLE_BITDEPTH): Bitdepth, should be in
+ ``[8, 9, 10, 11, 12, 13, 14, 15, 16]``.
frame_data_type (FRAME_DATA_TYPE): Data type, either ``"rgb"``,
``"yuv420"``, ``"yuv444"``.
data (Union[Tensor, DictTensorYUV]): The actual RGB or YUV data
@@ -172,6 +91,15 @@ def to_device(self, device: POSSIBLE_DEVICE) -> None:
elif self.frame_data_type == "yuv420":
self.data = yuv_dict_to_device(self.data, device)
+ def to_string(self) -> str:
+ """Pretty string describing the frame data."""
+ s = "Frame data information:\n"
+ s += "-----------------------\n"
+ s += f"{'Resolution (H, W)':<26}: {self.img_size[0]}, {self.img_size[1]}\n"
+ s += f"{'Bitdepth':<26}: {self.bitdepth}\n"
+ s += f"{'Data type':<26}: {self.frame_data_type}"
+
+ return s
@dataclass
class Frame:
diff --git a/coolchic/enc/utils/manager.py b/coolchic/enc/utils/manager.py
index 49daaac9..0190d9fc 100644
--- a/coolchic/enc/utils/manager.py
+++ b/coolchic/enc/utils/manager.py
@@ -7,7 +7,7 @@
# Authors: see CONTRIBUTORS.md
from dataclasses import dataclass, field, fields
-from enc.utils.presets import AVAILABLE_PRESETS, Preset
+from enc.training.presets import AVAILABLE_PRESETS, Preset
@dataclass
diff --git a/coolchic/enc/utils/misc.py b/coolchic/enc/utils/misc.py
index f40fd555..93c536aa 100644
--- a/coolchic/enc/utils/misc.py
+++ b/coolchic/enc/utils/misc.py
@@ -154,13 +154,13 @@ def get_q_step_from_parameter_name(
Optional[float]: The quantization step associated to the parameter.
Return None if nothing is found.
"""
- if parameter_name.endswith(".weight"):
+ if ".weight" in parameter_name:
current_q_step = q_step.get("weight")
- elif parameter_name.endswith(".bias"):
+ elif ".bias" in parameter_name:
current_q_step = q_step.get("bias")
else:
print(
- 'Parameter name should end with ".weight" or ".bias" '
+ 'Parameter name should include ".weight" or ".bias" '
f"Found: {parameter_name}"
)
current_q_step = None
@@ -195,13 +195,13 @@ def measure_expgolomb_rate(
# to obtain the sent latent.
current_sent_param = (parameter_value / current_q_step).view(-1)
- if parameter_name.endswith(".weight"):
+ if ".weight" in parameter_name:
sent_param["weight"].append(current_sent_param)
- elif parameter_name.endswith(".bias"):
+ elif ".bias" in parameter_name:
sent_param["bias"].append(current_sent_param)
else:
print(
- 'Parameter name should end with ".weight" or ".bias" '
+ 'Parameter name should include ".weight" or ".bias" '
f"Found: {parameter_name}"
)
return rate_param
diff --git a/coolchic/enc/utils/parsecli.py b/coolchic/enc/utils/parsecli.py
new file mode 100644
index 00000000..fd072121
--- /dev/null
+++ b/coolchic/enc/utils/parsecli.py
@@ -0,0 +1,174 @@
+# Software Name: Cool-Chic
+# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange
+# SPDX-License-Identifier: BSD 3-Clause "New"
+#
+# This software is distributed under the BSD-3-Clause license.
+#
+# Authors: see CONTRIBUTORS.md
+
+
+import argparse
+import os
+from typing import Any, Dict, List
+
+
+# ----- Arguments related to Cool-chic parameters
+def _parse_synthesis_layers(layers_synthesis: str) -> List[str]:
+ """The layers of the synthesis are given as a comma-separated string.
+ This simply splits up the different substrings and returns them.
+
+ Args:
+ layers_synthesis (str): Command line argument for the synthesis.
+
+ Returns:
+ List[str]: List of strings where the i-th element describes the i-th
+ synthesis layer
+ """
+ parsed_layer_synth = [x for x in layers_synthesis.split(",") if x != ""]
+
+ assert parsed_layer_synth, (
+ "Synthesis should have at least one layer, found nothing. \n"
+ f"--layers_synthesis={layers_synthesis} does not work!\n"
+ "Try something like 32-1-linear-relu,X-1-linear-none,"
+ "X-3-residual-relu,X-3-residual-none"
+ )
+
+ return parsed_layer_synth
+
+
+def _parse_arm_archi(arm: str) -> Dict[str, int]:
+ """The arm is described as <dim_arm>,<n_hidden_layers_arm>.
+ Split up this string to return the value as a dict.
+
+ Args:
+ arm (str): Command line argument for the ARM.
+
+ Returns:
+ Dict[str, int]: The ARM architecture
+ """
+ assert len(arm.split(",")) == 2, f"--arm format should be X,Y." f" Found {arm}"
+
+ dim_arm, n_hidden_layers_arm = [int(x) for x in arm.split(",")]
+ arm_param = {"dim_arm": dim_arm, "n_hidden_layers_arm": n_hidden_layers_arm}
+ return arm_param
+
+
+def _parse_n_ft_per_res(n_ft_per_res: str) -> List[int]:
+ """The number of features per resolution is a comma-separated string.
+ This simply splits up the different substrings and returns them as integers.
+
+ Args:
+ n_ft_per_res (str): Something like "1,1,1,1,1,1,1" for 7 latent grids
+ with different resolution and 1 feature each.
+
+ Returns:
+ List[int]: The i-th element is the number of features for the i-th
+ latent, i.e. the latent of a resolution (H / 2^i, W / 2^i).
+ """
+
+ n_ft_per_res = [int(x) for x in n_ft_per_res.split(",") if x != ""]
+ assert set(n_ft_per_res) == {
+ 1
+ }, f"--n_ft_per_res should only contain 1. Found {n_ft_per_res}"
+ return n_ft_per_res
+
+
+def get_coolchic_param_from_args(args: argparse.Namespace) -> Dict[str, Any]:
+
+ layers_synthesis = _parse_synthesis_layers(getattr(args, "layers_synthesis"))
+ n_ft_per_res = _parse_n_ft_per_res(getattr(args, "n_ft_per_res"))
+
+ coolchic_param = {
+ "layers_synthesis": layers_synthesis,
+ "n_ft_per_res": n_ft_per_res,
+ "ups_k_size": getattr(args, "ups_k_size"),
+ "ups_preconcat_k_size": getattr(args, "ups_preconcat_k_size"),
+ }
+
+ # Add ARM parameters
+ coolchic_param.update(_parse_arm_archi(getattr(args, "arm")))
+
+ return coolchic_param
+
+
+# ----- Arguments related to the coding structure
+def _is_image(file_path: str) -> bool:
+ """Return True if the file extension is an image extension, i.e. JPEG, PNG or PPM.
+
+ Args:
+ file_path (str): Path of the file.
+
+ Returns:
+ bool: True if the file is an "image".
+ """
+
+ possible_file_extension = ["png", "jpeg", "jpg", "ppm"]
+
+ for ext in possible_file_extension:
+ if file_path.endswith(f".{ext}"):
+ return True
+
+ if file_path.endswith(f".{ext.upper()}"):
+ return True
+
+ return False
+
+def get_coding_structure_from_args(args: argparse.Namespace) -> Dict[str, Any]:
+ """Perform some checks on the argparse object used to collect the command
+ line parameters. Return a dictionary ready to be plugged into the
+ ``CodingStructure`` constructor.
+
+ Args:
+ args (argparse.Namespace): Command-line argument parser.
+
+ Returns:
+ Dict[str, Any]: Dictionary ready to be plugged into the ``CodingStructure``
+ constructor.
+ """
+ intra_period = args.intra_period
+ p_period = args.p_period
+
+ assert intra_period >= 0 and intra_period <= 255, (
+ f"Intra period should be in [0, 255]. Found {intra_period}"
+ )
+
+ assert p_period >= 0 and p_period <= 255, (
+ f"P period should be in [0, 255]. Found {p_period}"
+ )
+
+ if _is_image(args.input):
+ assert intra_period == 0 and p_period == 0, (
+ f"Encoding a PNG, JPEG or PPM image {args.input} must be done with "
+ "intra_period = 0 and p_period = 0. Found intra_period = "
+ f"{args.intra_period} and p_period = {args.p_period}"
+ )
+
+ coding_structure_config = {
+ "intra_period": intra_period,
+ "p_period": p_period,
+ "seq_name": os.path.basename(args.input).split(".")[0],
+ }
+ return coding_structure_config
+
+
+# ----- Arguments related to the frame encoder manager i.e. training preset etc.
+def get_manager_from_args(args: argparse.Namespace) -> Dict[str, Any]:
+ """Perform some checks on the argparse object used to collect the command
+ line parameters. Return a dictionary ready to be plugged into the
+ ``FrameEncoderManager`` constructor.
+
+ Args:
+ args (argparse.Namespace): Command-line argument parser.
+
+ Returns:
+ Dict[str, Any]: Dictionary ready to be plugged into the
+ ``FrameEncoderManager`` constructor.
+ """
+ frame_encoder_manager = {
+ "preset_name": args.recipe,
+ "start_lr": args.start_lr,
+ "lmbda": args.lmbda,
+ "n_loops": args.n_train_loops,
+ "n_itr": args.n_itr,
+ }
+ return frame_encoder_manager
diff --git a/coolchic/enc/utils/yuv.py b/coolchic/enc/utils/yuv.py
deleted file mode 100644
index ba29dc51..00000000
--- a/coolchic/enc/utils/yuv.py
+++ /dev/null
@@ -1,258 +0,0 @@
-# Software Name: Cool-Chic
-# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange
-# SPDX-License-Identifier: BSD 3-Clause "New"
-#
-# This software is distributed under the BSD-3-Clause license.
-#
-# Authors: see CONTRIBUTORS.md
-
-
-import os
-
-import numpy as np
-import torch
-from einops import rearrange
-from enc.utils.codingstructure import (
- FRAME_DATA_TYPE,
- POSSIBLE_BITDEPTH,
- DictTensorYUV,
- FrameData,
-)
-from PIL import Image
-from torch import Tensor
-from torchvision.transforms.functional import to_tensor
-
-
-def yuv_dict_clamp(yuv: DictTensorYUV, min_val: float, max_val: float) -> DictTensorYUV:
- """Clamp the y, u & v tensor.
-
- Args:
- yuv (DictTensorYUV): The data to clamp
- min_val (float): Minimum value for the clamp
- max_val (float): Maximum value for the clamp
-
- Returns:
- DictTensorYUV: The clamped data
-
- """
- clamped_yuv = DictTensorYUV(
- y=yuv.get("y").clamp(min_val, max_val),
- u=yuv.get("u").clamp(min_val, max_val),
- v=yuv.get("v").clamp(min_val, max_val),
- )
- return clamped_yuv
-
-
-def load_frame_data_from_file(filename: str, idx_display_order: int) -> FrameData:
- """Load the idx_display_order-th frame from a .yuv file or .png file. For the latter,
- idx_display_order must be equal to 0 as there is only one frame in a png.
-
- Args:
- filename (str): Absolute path of the file from which the frame is loaded.
- idx_display_order (int): Index of the frame in display order
-
- Returns:
- FrameData: The loaded frame, wrapped as a FrameData object.
- """
-
- if filename.endswith(".yuv"):
- # ! We only consider yuv420 and 444 planar
- bitdepth: POSSIBLE_BITDEPTH = 8 if "_8b" in filename else 10
- frame_data_type: FRAME_DATA_TYPE = "yuv420" if "420" in filename else "yuv444"
- data = read_yuv(filename, idx_display_order, frame_data_type, bitdepth)
-
- elif filename.endswith(".png"):
- bitdepth: POSSIBLE_BITDEPTH = 8
- frame_data_type: FRAME_DATA_TYPE = "rgb"
- data = to_tensor(Image.open(filename))
- data = rearrange(data, "c h w -> 1 c h w")
-
- return FrameData(bitdepth, frame_data_type, data)
-
-
-def read_yuv(filename: str, frame_idx: int, frame_data_type: FRAME_DATA_TYPE, bit_depth: POSSIBLE_BITDEPTH) -> DictTensorYUV:
- """From a filename /a/b/c.yuv, read the desired frame_index
- and return a dictionary of tensor containing the YUV values:
- {
- 'Y': [1, 1, H, W],
- 'U': [1, 1, H / S, W / S],
- 'V': [1, 1, H / S, W / S],
- }
- S is either 1 (444 sampling) or 2 (420)
- The YUV values are in [0., 1.]
-
- Args:
- filename (str): Absolute path of the video to load
- frame_idx (int): Index of the frame to load, starting at 0.
- bit depth (int):number of bits per component (8 or 10 bits).
- frame_data_type chroma sampling (420,444):
-
- Returns:
- DictTensorYUV: The YUV values (see format above) for 420.
- pytorch tensor for 444 sampling format (consistent with rgb representation)
- """
-
- # Parse height and width from the filename
- w, h = [
- int(tmp_str)
- for tmp_str in os.path.basename(filename).split(".")[0].split("_")[1].split("x")
- ]
-
- if frame_data_type == "yuv420":
- w_uv, h_uv = [int(x / 2) for x in [w, h]]
- else:
- w_uv, h_uv = w, h
-
- # Switch between 8 bit file and 10 bit
- byte_per_value = 1 if bit_depth == 8 else 2
-
- # We only handle YUV420 for now
- n_val_y = h * w
- n_val_uv = h_uv * w_uv
- n_val_per_frame = n_val_y + 2 * n_val_uv
-
- n_bytes_y = n_val_y * byte_per_value
- n_bytes_uv = n_val_uv * byte_per_value
- n_bytes_per_frame = n_bytes_y + 2 * n_bytes_uv
-
- # Read the required frame and put it in a 1d tensor
- raw_video = torch.tensor(
- np.memmap(
- filename,
- mode="r",
- shape=n_val_per_frame,
- offset=n_bytes_per_frame * frame_idx,
- dtype=np.uint16 if bit_depth == 10 else np.uint8,
- ).astype(np.float32)
- )
-
- # Read the different values from raw video and store them inside y, u and v
- ptr = 0
- y = raw_video[ptr : ptr + n_val_y].view(1, 1, h, w)
- ptr += n_val_y
- u = raw_video[ptr : ptr + n_val_uv].view(1, 1, h_uv, w_uv)
- ptr += n_val_uv
- v = raw_video[ptr : ptr + n_val_uv].view(1, 1, h_uv, w_uv)
-
- # PyTorch expect data in [0., 1.]; normalize by either 255 or 1023
- norm_factor = 2**bit_depth - 1
-
- if frame_data_type == "yuv420":
- video = DictTensorYUV(y=y / norm_factor, u=u / norm_factor, v=v / norm_factor)
- else:
- video = torch.cat([y, u, v], dim=1) / norm_factor
-
- return video
-
-
-def write_yuv(data: FrameData, filename: str, norm: bool = True) -> None:
- """Store a YUV frame as a YUV file named filename. All parameters of the YUV
- file (resolution, chroma subsampling, bitdepth) are contained in the FrameData
- object alongside the actual data. They are appended to the end of the filename
- If norm is True: the video data is expected to be in [0., 1.] so we
- multiply it by 255. Otherwise we let it as is.
-
- Args:
- data (FrameData): Data to save
- filename (str): Absolute path of the file where the YUV is saved.
- norm (bool): True to multiply the data by 2 ** bitdepth - 1.
- """
- assert data.frame_data_type in ["yuv420", "yuv444"], (
- "Found incorrect datatype in "
- f'write_yuv() function: {data.frame_data_type}. Data type should be "yuv420" or "yuv444".'
- )
-
- # Append .yuv at the end of the file to make sure it is present
- if not (filename[-4:] == ".yuv"):
- filename += ".yuv"
- # From here, there is no .yuv at the end of filename
- filename = filename[:-4]
-
- # Append spatial dimension to the filename, dummy framerate
- # and bit depth
- DUMMY_FRAMERATE = 1
- h, w = data.img_size
- # We need to add a p avec yuv444 otherwise YUView thinks its "YUV444 8-bit packed"
- filename = f"{filename}_{w}x{h}_{DUMMY_FRAMERATE}fps_{data.frame_data_type}p_{data.bitdepth}b.yuv"
-
- # Concatenate **all** channels into a 2D tensor [1.5 * H * W]
- if data.frame_data_type == "yuv420":
- raw_data = torch.cat([channels.flatten() for _, channels in data.data.items()])
- elif data.frame_data_type == "yuv444":
- raw_data = data.data.flatten()
-
- if norm:
- raw_data = raw_data * (2**data.bitdepth - 1)
-
- dtype = np.uint16 if data.bitdepth == 10 else np.uint8
-
- # Round the values and cast them to uint8 or uint16 tensor
- raw_data = torch.round(raw_data).cpu().numpy().astype(dtype)
-
- # Write this to the desired filename
- np.memmap.tofile(raw_data, filename)
-
-
-def rgb2yuv(rgb: Tensor) -> Tensor:
- """Convert a 4D RGB tensor [1, 3, H, W] into a 4D YUV444 tensor [1, 3, H, W].
- The RGB and YUV values are in the range [0, 255]
-
- Args:
- rgb (Tensor): 4D RGB tensor to convert in [0. 255.]
-
- Returns:
- Tensor: the resulting YUV444 tensor in [0. 255.]
- """
- assert (
- len(rgb.size()) == 4
- ), f"rgb2yuv input must be a 4D tensor [B, 3, H, W]. Data size: {rgb.size()}"
- assert (
- rgb.size()[1] == 3
- ), f"rgb2yuv input must have 3 channels. Data size: {rgb.size()}"
-
- # Split the [1, 3, H, W] into 3 [1, 1, H, W] tensors
- r, g, b = rgb.split(1, dim=1)
-
- # Compute the different channels
- y = torch.round(0.299 * r + 0.587 * g + 0.114 * b)
- u = torch.round(-0.1687 * r - 0.3313 * g + 0.5 * b + +128)
- v = torch.round(0.5 * r - 0.4187 * g - 0.0813 * b + 128)
-
- # Concatenate them into the resulting yuv 4D tensor.
- yuv = torch.cat((y, u, v), dim=1)
- return yuv
-
-
-def yuv2rgb(yuv: Tensor):
- """Convert a 4D YUV tensor [1, 3, H, W] into a 4D RGB tensor [1, 3, H, W].
- The RGB and YUV values are in the range [0, 255]
-
- Args:
- rgb (Tensor): 4D YUV444 tensor to convert in [0. 255.]
-
- Returns:
- Tensor: the resulting RGB tensor in [0. 255.]
- """
- assert (
- len(yuv.size()) == 4
- ), f"yuv2rgb input must be a 4D tensor [B, 3, H, W]. Data size: {yuv.size()}"
- assert (
- yuv.size()[1] == 3
- ), f"yuv2rgb input must have 3 channels. Data size: {yuv.size()}"
-
- y, u, v = yuv.split(1, dim=1)
- r = (
- 1.0 * y
- + -0.000007154783816076815 * u
- + 1.4019975662231445 * v
- - 179.45477266423404
- )
- g = 1.0 * y + -0.3441331386566162 * u + -0.7141380310058594 * v + 135.45870971679688
- b = (
- 1.0 * y
- + 1.7720025777816772 * u
- + 0.00001542569043522235 * v
- - 226.8183044444304
- )
- rgb = torch.cat((r, g, b), dim=1)
- return rgb
diff --git a/coolchic/enc/visu/__init__.py b/coolchic/enc/visu/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/coolchic/enc/visu/console.py b/coolchic/enc/visu/console.py
new file mode 100644
index 00000000..535ddf73
--- /dev/null
+++ b/coolchic/enc/visu/console.py
@@ -0,0 +1,258 @@
+# Software Name: Cool-Chic
+# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange
+# SPDX-License-Identifier: BSD 3-Clause "New"
+#
+# This software is distributed under the BSD-3-Clause license.
+#
+# Authors: see CONTRIBUTORS.md
+
+from torch import nn
+
+
+def pretty_string_ups(upsampling: nn.Module, header: str) -> str:
+ """Get a nice string ready to be printed which displays the different layers
+ of the upsampling step.
+
+ Something like:
+
+ .. code::
+
+ header
+
+ +-------------+
+ y0 -> | 8x8 TConv2d | -----+
+ +-------------+ |
+ v
+ +------------+ +-----+ +-------------+
+ y1 -> | 7x7 Conv2d | -> | cat | -> | 8x8 TConv2d | -----+
+ +------------+ +-----+ +-------------+ |
+ v
+ +------------+ +-----+ +-------------+
+ y2 ------------------------------> | 7x7 Conv2d | -> | cat | -> | 8x8 TConv2d | -> dense
+ +------------+ +-----+ +-------------+
+ Args:
+ upsampling (nn.Module): The upsampling module to print
+ header (str): A string to prepend before the layers
+
+ Returns:
+ str: A nice string presenting the upsampling.
+ """
+
+ lines = []
+
+ assert upsampling.n_ups_kernel == upsampling.n_ups_preconcat_kernel, (
+ "Textual representation of the upsampling module expects to have the "
+ "same number of upsampling and pre-concatenation kernels. Found "
+ f"n_ups_kernel = {upsampling.n_ups_kernel} and n_ups_preconcat_kernel ="
+ f" {upsampling.n_ups_preconcat_kernel}."
+ )
+
+ # Useful thing to print
+ cat_box = [
+ "+-----+",
+ "| cat |",
+ "+-----+",
+ ]
+ # ! Warning no space before short_arrow here!
+ no_space_short_arrow = "->"
+ short_arrow = " -> "
+
+ shorten_names = {
+ "ParametrizedUpsamplingSeparableSymmetricConv2d": "Conv2d",
+ "ParametrizedUpsamplingSeparableSymmetricConvTranspose2d": "TConv2d",
+ }
+
+ lateral_offset = 0
+
+ for idx_ups in range(upsampling.n_ups_kernel + 1):
+ latent_name = f"y{idx_ups:<1} "
+ mid = f"{latent_name}{'-' * lateral_offset}{no_space_short_arrow} "
+ hori_border = ["", ""]
+
+ for i in range(len(hori_border)):
+ hori_border[i] += " " * len(mid)
+
+ # First (smallest) latent does not have a pre-concatenation filter
+ if idx_ups != 0:
+ conv_lay = upsampling.conv2ds[i]
+ layer_name = shorten_names.get(
+ type(conv_lay).__name__, type(conv_lay).__name__
+ )
+ k_size = conv_lay.target_k_size
+ inside_box = f" {k_size}x{k_size} {layer_name} "
+ mid += f"|{inside_box}|{short_arrow}"
+
+ # Add a concatenation box
+ for i in range(len(hori_border)):
+ hori_border[i] += f"+{'-' * len(inside_box)}+{' ' * len(short_arrow)}"
+
+ mid += cat_box[1] + short_arrow
+ hori_border[0] += cat_box[0]
+ hori_border[1] += cat_box[2]
+
+ for i in range(len(hori_border)):
+ hori_border[i] += f"{' ' * len(short_arrow)}"
+
+ lateral_offset = len(mid) - len(latent_name) - len(no_space_short_arrow) - 1
+
+ tconv_lay = upsampling.conv_transpose2ds[i]
+ layer_name = shorten_names.get(
+ type(tconv_lay).__name__, type(tconv_lay).__name__
+ )
+ k_size = tconv_lay.target_k_size
+ inside_box = f" {k_size}x{k_size} {layer_name} "
+ mid += f"|{inside_box}|"
+
+ for i in range(len(hori_border)):
+ hori_border[i] += f"+{'-' * len(inside_box)}+"
+
+ concat_arrow = (
+ " " # Skip one space as in long arrow
+ # Account for the long arrow and half of the cat box.
+ + "-" * (len(no_space_short_arrow) + len(cat_box[0]) // 2)
+ )
+
+ # Last line is a bit specific
+ if idx_ups == upsampling.n_ups_kernel:
+ mid += f"{short_arrow}dense"
+ lines += [hori_border[0], mid, hori_border[1]]
+
+ else:
+ mid += concat_arrow + "+"
+ hori_border[1] += " " * len(concat_arrow) + "|"
+ last_line = " " * (len(mid) - 1) + "v"
+ lines += [hori_border[0], mid, hori_border[1], last_line]
+
+ return header + "\n".join(lines)
+
+
+def pretty_string_nn(
+ layers: nn.Sequential, header: str, input_str: str, output_str: str
+) -> str:
+ """Get a nice string ready to be printed which displays the different layers
+ of a neural network.
+
+ Something like:
+
+ .. code::
+
+ header
+ +--------------------------+
+ | |
+ | v
+ | +---------------+ +-----+ +------+ +---------------+
+ input_str ---> | Linear 8 -> 8 | -> | + | -> | ReLU | ---> | Linear 8 -> 2 | ---> output_str
+ +---------------+ +-----+ +------+ +---------------+
+
+ Args:
+ layers (nn.Sequential): The successive layers of the neural network.
+ header (str): A string to prepend before the layers
+ input_str (str): A string describing the input of the NN.
+ output_str (str): A string describing the output of the NN.
+
+ Returns:
+ str: A nice string presenting the neural network architecture.
+ """
+
+ lines = [
+ "", # For residual connexions
+ "", # For residual connexions
+ "", # For residual connexions
+ "", # For blocks in themselves
+ "", # For blocks in themselves
+ "", # For blocks in themselves
+ ]
+
+ # Name of the layer appears here
+ idx_mid_block = len(lines) - 2
+
+ # Horizontal borders above and below the block name
+ idx_top_block = len(lines) - 3
+ idx_bot_block = len(lines) - 1
+ top_bot_blocks = [idx_bot_block, idx_top_block]
+
+ # First three lines are for residual connexions
+ res_blocks = list(range(3))
+
+ lines[idx_mid_block] = input_str
+ for idx in top_bot_blocks + res_blocks:
+ lines[idx] += " " * len(lines[idx_mid_block])
+
+ # Useful thing to print
+ plus_box = [
+ "+-----+",
+ "| + |",
+ "+-----+",
+ ]
+ short_arrow = " -> "
+ long_arrow = " ---> "
+
+ shorten_names = {"ArmLinear": "Linear", "SynthesisConv2d": "Conv2d"}
+
+ # Print one layer after the other
+ for lay in layers:
+ is_non_linearity = (
+ isinstance(lay, nn.ReLU)
+ or isinstance(lay, nn.Identity)
+ or isinstance(lay, nn.LeakyReLU)
+ )
+
+ arrow_str = short_arrow if is_non_linearity else long_arrow
+
+ layer_name = shorten_names.get(type(lay).__name__, type(lay).__name__)
+ if layer_name == "Identity":
+ continue
+
+ inside_box = f" {layer_name} "
+ if not is_non_linearity:
+ inside_box += f"{lay.in_channels} -> {lay.out_channels} "
+
+ if hasattr(lay, "kernel_size"):
+ inside_box = f" {lay.kernel_size}x{lay.kernel_size}{inside_box}"
+
+ lines[idx_mid_block] += f"{arrow_str}|{inside_box}|"
+
+ for idx in top_bot_blocks:
+ lines[idx] += f"{' ' * len(arrow_str)}+{'-' * len(inside_box)}+"
+
+ is_residual = False
+ if hasattr(lay, "residual"):
+ is_residual = lay.residual
+
+ if is_residual:
+ half_arrow_len = len(arrow_str) // 2
+ for idx in res_blocks:
+ lines[idx] += " " * half_arrow_len
+
+ res_arrow_len = (
+ half_arrow_len
+ - 1 # The remaining half of the input arrow, minus 1
+ + len(inside_box)
+ + 2 # The entire box + 2 for the edges
+ + len(short_arrow) # The short arrow from the layer box to the add box
+ + len(plus_box[0]) // 2 # Half of the plus box.
+ )
+ lines[0] += f"+{'-' * res_arrow_len}+"
+ lines[1] += f"|{' ' * res_arrow_len}|"
+ lines[2] += f"|{' ' * res_arrow_len}v"
+
+ for idx in res_blocks:
+ lines[idx] += " " * (len(plus_box[0]) // 2)
+
+ else:
+ for idx in res_blocks:
+ lines[idx] += f"{' ' * len(arrow_str)} {' ' * len(inside_box)} "
+
+ if is_residual:
+ arrow_str = short_arrow
+ lines[idx_mid_block] += f"{arrow_str}{plus_box[1]}"
+ for idx in top_bot_blocks:
+ lines[idx] += f"{' ' * len(arrow_str)}{plus_box[0]}"
+
+ lines[3] = list(lines[3])
+ lines[3][-(res_arrow_len + half_arrow_len + 2)] = "|"
+ lines[3] = "".join(lines[3])
+
+ lines[idx_mid_block] += long_arrow + output_str
+
+ return header + "\n".join(lines)
diff --git a/coolchic/encode.py b/coolchic/encode.py
index 11bf19d2..0962fa67 100644
--- a/coolchic/encode.py
+++ b/coolchic/encode.py
@@ -1,27 +1,31 @@
- # Software Name: Cool-Chic
+# Software Name: Cool-Chic
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange
# SPDX-License-Identifier: BSD 3-Clause "New"
#
# This software is distributed under the BSD-3-Clause license.
#
-# Authors: Theo Ladune
-# Pierrick Philippe
+# Authors: see CONTRIBUTORS.md
+
import os
-import sys
-import torch
import subprocess
-import configargparse
+import sys
+import configargparse
+import torch
from enc.component.coolchic import CoolChicEncoderParameter
from enc.component.video import (
- VideoEncoder,
FrameEncoderManager,
+ VideoEncoder,
load_video_encoder,
)
from enc.utils.codingstructure import CodingStructure
-from enc.utils.misc import get_best_device
-
+from enc.utils.misc import TrainingExitCode, get_best_device
+from enc.utils.parsecli import (
+ get_coding_structure_from_args,
+ get_coolchic_param_from_args,
+ get_manager_from_args,
+)
"""
Use this file to train i.e. encode a GOP i.e. something which starts with one
@@ -30,7 +34,6 @@
"""
if __name__ == "__main__":
-
# =========================== Parse arguments =========================== #
# By increasing priority order, the arguments work as follows:
#
@@ -47,20 +50,20 @@
parser = configargparse.ArgumentParser()
# -------- These arguments are not in the configuration files
parser.add(
- "-i", "--input",
+ "-i",
+ "--input",
help="Path of the input image. Either .png (RGB444) or .yuv (YUV420)",
type=str,
)
parser.add(
- "-o", "--output",
+ "-o",
+ "--output",
help="Path of the compressed bitstream. If empty, no bitstream is written",
type=str,
default="",
)
- parser.add(
- "--workdir", help="Path of the working_directory", type=str, default="."
- )
+ parser.add("--workdir", help="Path of the working_directory", type=str, default=".")
parser.add("--lmbda", help="Rate constraint", type=float, default=1e-3)
parser.add(
"--job_duration_min",
@@ -70,13 +73,9 @@
)
# -------- Configuration files
- parser.add(
- "--enc_cfg", is_config_file=True, help="Encoder configuration file"
- )
+ parser.add("--enc_cfg", is_config_file=True, help="Encoder configuration file")
- parser.add(
- "--dec_cfg", is_config_file=True, help="Decoder configuration file"
- )
+ parser.add("--dec_cfg", is_config_file=True, help="Decoder configuration file")
# -------- These arguments are in the configuration files
@@ -94,18 +93,14 @@
default=0,
)
- parser.add(
- "--start_lr", help="Initial learning rate", type=float, default=1e-2
- )
+ parser.add("--start_lr", help="Initial learning rate", type=float, default=1e-2)
parser.add(
"--n_itr",
help="Maximum number of iterations per phase",
type=int,
default=int(1e4),
)
- parser.add(
- "--n_train_loops", help="Number of training loops", type=int, default=1
- )
+ parser.add("--n_train_loops", help="Number of training loops", type=int, default=1)
parser.add(
"--recipe",
help='Recipe type. Either "c3x" or "debug".',
@@ -118,19 +113,27 @@
"--layers_synthesis",
type=str,
default="40-1-linear-relu,X-1-linear-none,X-3-residual-relu,X-3-residual-none",
- help="Syntax example for the synthesis:"
- " 12-1-linear-relu,12-1-residual-relu,X-1-linear-relu,X-3-residual-none"
- "This is a 4 layers synthesis. Now the output layer (computing the final RGB"
- "values) must be specified i.e. a 12,12 should now be called a 12,12,3. Each layer"
- "is described using the following syntax:"
- "---. "
- " is the number of output features. If set to X, this is replaced by the"
- "number of required output features i.e. 3 for a RGB or YUV frame."
- " is the spatial dimension of the kernel. Use 1 to mimic an MLP."
- " is either 'linear' for a standard conv or 'residual' for a residual"
- " block i.e. layer(x) = x + conv(x). Can be'none' for no"
- " non-linearity, 'relu' for a ReLU"
+ help="Syntax example: "
+ " "
+ " 12-1-linear-relu,12-1-residual-relu,X-1-linear-relu,X-3-residual-none "
+ " "
+ " This is a 4-layer synthesis. Layers are separated by commas and each is "
+ "described using the following syntax: "
+ " "
+ " <output_dim>-<kernel>-<type>-<non_linearity>. "
+ " "
+ " <output_dim> is the number of output features. If set to X, this is replaced by the "
+ "number of required output features i.e. 3 for a RGB or YUV frame. "
+ " "
+ " <kernel> is the spatial dimension of the kernel. Use 1 to mimic an MLP. "
+ " "
+ " <type> is either 'linear' for a standard conv or 'residual' for a convolution "
+ "with a residual connection i.e. layer(x) = x + conv(x). "
+ " "
+ " <non_linearity> can be 'none' for no non-linearity, 'relu' for a ReLU "
+ "non-linearity. ",
)
+
parser.add(
"--arm",
type=str,
@@ -141,30 +144,37 @@
)
parser.add(
- "--n_ft_per_res",
- type=str,
- default="1,1,1,1,1,1,1",
- help="Number of feature for each latent resolution. e.g. --n_ft_per_res=1,2,2,2,3,3,3"
- " for 7 latent grids with variable resolutions.",
+ "--ups_k_size",
+ type=int,
+ default=8,
+ help="Upsampling kernel size for the transposed convolutions. "
+ "Must be even and >= 4.",
)
parser.add(
- "--upsampling_kernel_size",
- help="upsampling kernel size (≥4 and multiple of 2)",
+ "--ups_preconcat_k_size",
type=int,
- default=8,
+ default=7,
+ help="Upsampling kernel size for the pre-concatenation convolutions. "
+ "Must be odd.",
)
+
parser.add(
- "--static_upsampling_kernel",
- help="Use this flag to **not** learn the upsampling kernel",
- action="store_true",
- default=False,
+ "--n_ft_per_res",
+ type=str,
+ default="1,1,1,1,1,1,1",
+ help="Number of features for each latent resolution. e.g. "
+ " --n_ft_per_res=1,1,1,1,1,1,1 "
+ " for 7 latent grids with variable resolutions. "
+ " Parameterize the residue decoder.",
)
args = parser.parse_args()
print(args)
print("----------")
- print(parser.format_values()) # useful for logging where different settings came from
+ print(
+ parser.format_values()
+ ) # useful for logging where different settings came from
# =========================== Parse arguments =========================== #
# =========================== Parse arguments =========================== #
@@ -175,27 +185,25 @@
video_encoder = load_video_encoder(path_video_encoder)
else:
-
start_print = (
- '\n\n'
- '*----------------------------------------------------------------------------------------------------------*\n'
- '| |\n'
- '| |\n'
- '| ,gggg, |\n'
+ "\n\n"
+ "*----------------------------------------------------------------------------------------------------------*\n"
+ "| |\n"
+ "| |\n"
+ "| ,gggg, |\n"
'| ,88"""Y8b, ,dPYb, ,dPYb, |\n'
- '| d8" `Y8 IP\'`Yb IP\'`Yb |\n'
- '| d8\' 8b d8 I8 8I I8 8I gg |\n'
- '| ,8I "Y88P\' I8 8\' I8 8\' "" |\n'
- '| I8\' ,ggggg, ,ggggg, I8 dP aaaaaaaa ,gggg, I8 dPgg, gg ,gggg, |\n'
+ "| d8\" `Y8 IP'`Yb IP'`Yb |\n"
+ "| d8' 8b d8 I8 8I I8 8I gg |\n"
+ "| ,8I \"Y88P' I8 8' I8 8' \"\" |\n"
+ "| I8' ,ggggg, ,ggggg, I8 dP aaaaaaaa ,gggg, I8 dPgg, gg ,gggg, |\n"
'| d8 dP" "Y8ggg dP" "Y8ggg I8dP """""""" dP" "Yb I8dP" "8I 88 dP" "Yb |\n'
- '| Y8, i8\' ,8I i8\' ,8I I8P i8\' I8P I8 88 i8\' |\n'
- '| `Yba,,_____, ,d8, ,d8\' ,d8, ,d8\' ,d8b,_ ,d8,_ _,d8 I8,_,88,_,d8,_ _ |\n'
+ "| Y8, i8' ,8I i8' ,8I I8P i8' I8P I8 88 i8' |\n"
+ "| `Yba,,_____, ,d8, ,d8' ,d8, ,d8' ,d8b,_ ,d8,_ _,d8 I8,_,88,_,d8,_ _ |\n"
'| `"Y8888888 P"Y8888P" P"Y8888P" 8P\'"Y88 P""Y8888PP88P `Y88P""Y8P""Y8888PP |\n'
- '| |\n'
- '| |\n'
- '| version 3.3 © 2023-2024 Orange |\n'
- '*----------------------------------------------------------------------------------------------------------*\n'
-
+ "| |\n"
+ "| |\n"
+ "| version 3.4, Nov. 2024 © 2023-2024 Orange |\n"
+ "*----------------------------------------------------------------------------------------------------------*\n"
)
print(start_print)
@@ -207,72 +215,19 @@
f_out.write(str(args))
f_out.write("\n")
f_out.write("----------\n")
- f_out.write(parser.format_values()) # useful for logging where different settings came from
-
- # ----- Create coding configuration
- assert args.intra_period >= 0 and args.intra_period <= 255, (
- f"Intra period should be " f" in [0, 255]. Found {args.intra_period}"
- )
-
- assert args.p_period >= 0 and args.p_period <= 255, (
- f"P period should be " f" in [0, 255]. Found {args.p_period}"
- )
-
- is_image = (
- args.input.endswith(".png")
- or args.input.endswith(".PNG")
- or args.input.endswith(".jpeg")
- or args.input.endswith(".JPEG")
- or args.input.endswith(".jpg")
- or args.input.endswith(".JPG")
- )
-
- if is_image:
- assert args.intra_period == 0 and args.p_period == 0, (
- f"Encoding a PNG or JPEG image {args.input} must be done with "
- "intra_period = 0 and p_period = 0. Found intra_period = "
- f"{args.intra_period} and p_period = {args.p_period}"
- )
-
- coding_config = CodingStructure(
- intra_period=args.intra_period,
- p_period=args.p_period,
- seq_name=os.path.basename(args.input).split(".")[0],
- )
-
- # Parse arguments
- layers_synthesis = [x for x in args.layers_synthesis.split(",") if x != ""]
- n_ft_per_res = [int(x) for x in args.n_ft_per_res.split(",") if x != ""]
-
- assert set(n_ft_per_res) == {1}, (
- f"--n_ft_per_res should only contains 1. Found {args.n_ft_per_res}"
- )
-
- assert len(args.arm.split(",")) == 2, (
- f"--arm format should be X,Y." f" Found {args.arm}"
- )
-
- dim_arm, n_hidden_layers_arm = [int(x) for x in args.arm.split(",")]
+ f_out.write(
+ parser.format_values()
+ ) # useful for logging where different settings came from
+ # ----- Parse arguments & construct video encoder
+ coding_structure = CodingStructure(**get_coding_structure_from_args(args))
coolchic_encoder_parameter = CoolChicEncoderParameter(
- layers_synthesis=layers_synthesis,
- dim_arm=dim_arm,
- n_hidden_layers_arm=n_hidden_layers_arm,
- n_ft_per_res=n_ft_per_res,
- upsampling_kernel_size=args.upsampling_kernel_size,
- static_upsampling_kernel=args.static_upsampling_kernel,
- )
-
- frame_encoder_manager = FrameEncoderManager(
- preset_name=args.recipe,
- start_lr=args.start_lr,
- lmbda=args.lmbda,
- n_loops=args.n_train_loops,
- n_itr=args.n_itr,
+ **get_coolchic_param_from_args(args)
)
+ frame_encoder_manager = FrameEncoderManager(**get_manager_from_args(args))
video_encoder = VideoEncoder(
- coding_structure=coding_config,
+ coding_structure=coding_structure,
shared_coolchic_parameter=coolchic_encoder_parameter,
shared_frame_encoder_manager=frame_encoder_manager,
)
@@ -281,29 +236,8 @@
device = get_best_device()
print(f'{"Device":<20}: {device}')
- # # ====================== Torchscript JIT parameters ===================== #
- # # From https://github.com/pytorch/pytorch/issues/52286
- # # This is no longer the case with the with torch.jit.fuser
- # # ! This gives a significant (+25 %) speed up
- # torch._C._jit_set_profiling_executor(False)
- # torch._C._jit_set_texpr_fuser_enabled(False)
- # torch._C._jit_set_profiling_mode(False)
-
- # torch.set_float32_matmul_precision("high")
- # # ====================== Torchscript JIT parameters ===================== #
-
- if device == "cpu":
- # the number of cores is adjusted wrt to the slurm variable if exists
- n_cores = os.getenv("SLURM_JOB_CPUS_PER_NODE")
- # otherwise use the machine cpu count
- if n_cores is None:
- n_cores = os.cpu_count()
-
- n_cores = int(n_cores)
- print(f'{"CPU cores":<20}: {n_cores}')
-
- elif device == "cuda:0":
- # ! This one makes the training way faster!
+ # This makes the training faster
+ if device == "cuda:0":
torch.backends.cudnn.benchmark = True
print(f"\n{video_encoder.coding_structure.pretty_string()}\n")
@@ -319,10 +253,10 @@
video_encoder.save(video_encoder_savepath)
# Bitstream
- if args.output != "":
+ if args.output != "" and exit_code == TrainingExitCode.END:
from enc.bitstream.encode import encode_video
+
# video_encoder = load_video_encoder(video_encoder_savepath)
encode_video(video_encoder, args.output, hls_sig_blksize=16)
sys.exit(exit_code.value)
-
diff --git a/docs/source/_templates/partials/webfonts.html b/docs/source/_templates/partials/webfonts.html
new file mode 100644
index 00000000..de8d9b38
--- /dev/null
+++ b/docs/source/_templates/partials/webfonts.html
@@ -0,0 +1,11 @@
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/source/assets/clic20-pro-valid/all_complexity.png b/docs/source/assets/clic20-pro-valid/all_complexity.png
new file mode 100644
index 00000000..099c3886
Binary files /dev/null and b/docs/source/assets/clic20-pro-valid/all_complexity.png differ
diff --git a/docs/source/assets/clic20-pro-valid/concat_img.png b/docs/source/assets/clic20-pro-valid/concat_img.png
new file mode 100644
index 00000000..288959c3
Binary files /dev/null and b/docs/source/assets/clic20-pro-valid/concat_img.png differ
diff --git a/docs/source/assets/clic20-pro-valid/perf_complexity.png b/docs/source/assets/clic20-pro-valid/perf_complexity.png
index 4ab785d1..8c423198 100644
Binary files a/docs/source/assets/clic20-pro-valid/perf_complexity.png and b/docs/source/assets/clic20-pro-valid/perf_complexity.png differ
diff --git a/docs/source/assets/clic20-pro-valid/perf_decoding_time.png b/docs/source/assets/clic20-pro-valid/perf_decoding_time.png
index c5dd3ead..9a35289c 100644
Binary files a/docs/source/assets/clic20-pro-valid/perf_decoding_time.png and b/docs/source/assets/clic20-pro-valid/perf_decoding_time.png differ
diff --git a/docs/source/assets/clic20-pro-valid/rd.png b/docs/source/assets/clic20-pro-valid/rd.png
index 4b0e2e3e..e5501047 100644
Binary files a/docs/source/assets/clic20-pro-valid/rd.png and b/docs/source/assets/clic20-pro-valid/rd.png differ
diff --git a/docs/source/assets/concat_img.sh b/docs/source/assets/concat_img.sh
new file mode 100755
index 00000000..7784b886
--- /dev/null
+++ b/docs/source/assets/concat_img.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+complexity_height=700
+final_width=1000
+
+for dataset in kodak clic20-pro-valid
+do
+ # Concatenate the images horizontally and resize them if they don't have the same height
+ convert +append $dataset/perf_complexity.png $dataset/perf_decoding_time.png -resize x$complexity_height $dataset/all_complexity.png
+ convert -append $dataset/rd.png $dataset/all_complexity.png -resize $final_width $dataset/concat_img.png
+done
+
+dataset=jvet
+for class in B C D E F BCDEF
+do
+ # Concatenate the images horizontally and resize them if they don't have the same height
+ convert +append $dataset/perf_complexity_class$class.png $dataset/perf_decoding_time_class$class.png -resize x$complexity_height $dataset/all_complexity_class$class.png
+ convert -append $dataset/rd_class$class.png $dataset/all_complexity_class$class.png -resize $final_width $dataset/concat_img_class$class.png
+done
diff --git a/docs/source/assets/coolchic-logo-light.png b/docs/source/assets/coolchic-logo-light.png
index 708b49e5..697a9a49 100644
Binary files a/docs/source/assets/coolchic-logo-light.png and b/docs/source/assets/coolchic-logo-light.png differ
diff --git a/docs/source/assets/favicon.pdf b/docs/source/assets/favicon.pdf
deleted file mode 100644
index f6cdd17b..00000000
Binary files a/docs/source/assets/favicon.pdf and /dev/null differ
diff --git a/docs/source/assets/jvet/all_complexity_classB.png b/docs/source/assets/jvet/all_complexity_classB.png
new file mode 100644
index 00000000..4f4b4f12
Binary files /dev/null and b/docs/source/assets/jvet/all_complexity_classB.png differ
diff --git a/docs/source/assets/jvet/all_complexity_classBCDEF.png b/docs/source/assets/jvet/all_complexity_classBCDEF.png
new file mode 100644
index 00000000..be26ebc7
Binary files /dev/null and b/docs/source/assets/jvet/all_complexity_classBCDEF.png differ
diff --git a/docs/source/assets/jvet/all_complexity_classC.png b/docs/source/assets/jvet/all_complexity_classC.png
new file mode 100644
index 00000000..3347084d
Binary files /dev/null and b/docs/source/assets/jvet/all_complexity_classC.png differ
diff --git a/docs/source/assets/jvet/all_complexity_classD.png b/docs/source/assets/jvet/all_complexity_classD.png
new file mode 100644
index 00000000..954c1cd8
Binary files /dev/null and b/docs/source/assets/jvet/all_complexity_classD.png differ
diff --git a/docs/source/assets/jvet/all_complexity_classE.png b/docs/source/assets/jvet/all_complexity_classE.png
new file mode 100644
index 00000000..ce2570fa
Binary files /dev/null and b/docs/source/assets/jvet/all_complexity_classE.png differ
diff --git a/docs/source/assets/jvet/all_complexity_classF.png b/docs/source/assets/jvet/all_complexity_classF.png
new file mode 100644
index 00000000..18491daa
Binary files /dev/null and b/docs/source/assets/jvet/all_complexity_classF.png differ
diff --git a/docs/source/assets/jvet/concat_img_classB.png b/docs/source/assets/jvet/concat_img_classB.png
new file mode 100644
index 00000000..fbdd1654
Binary files /dev/null and b/docs/source/assets/jvet/concat_img_classB.png differ
diff --git a/docs/source/assets/jvet/concat_img_classBCDEF.png b/docs/source/assets/jvet/concat_img_classBCDEF.png
new file mode 100644
index 00000000..fb8d6c49
Binary files /dev/null and b/docs/source/assets/jvet/concat_img_classBCDEF.png differ
diff --git a/docs/source/assets/jvet/concat_img_classC.png b/docs/source/assets/jvet/concat_img_classC.png
new file mode 100644
index 00000000..d74cca1e
Binary files /dev/null and b/docs/source/assets/jvet/concat_img_classC.png differ
diff --git a/docs/source/assets/jvet/concat_img_classD.png b/docs/source/assets/jvet/concat_img_classD.png
new file mode 100644
index 00000000..882efea9
Binary files /dev/null and b/docs/source/assets/jvet/concat_img_classD.png differ
diff --git a/docs/source/assets/jvet/concat_img_classE.png b/docs/source/assets/jvet/concat_img_classE.png
new file mode 100644
index 00000000..20935e56
Binary files /dev/null and b/docs/source/assets/jvet/concat_img_classE.png differ
diff --git a/docs/source/assets/jvet/concat_img_classF.png b/docs/source/assets/jvet/concat_img_classF.png
new file mode 100644
index 00000000..534516db
Binary files /dev/null and b/docs/source/assets/jvet/concat_img_classF.png differ
diff --git a/docs/source/assets/jvet/perf_complexity_classB.png b/docs/source/assets/jvet/perf_complexity_classB.png
index 6be340b4..45b763d1 100644
Binary files a/docs/source/assets/jvet/perf_complexity_classB.png and b/docs/source/assets/jvet/perf_complexity_classB.png differ
diff --git a/docs/source/assets/jvet/perf_complexity_classBCDEF.png b/docs/source/assets/jvet/perf_complexity_classBCDEF.png
index d1ba6a33..8a900543 100644
Binary files a/docs/source/assets/jvet/perf_complexity_classBCDEF.png and b/docs/source/assets/jvet/perf_complexity_classBCDEF.png differ
diff --git a/docs/source/assets/jvet/perf_complexity_classC.png b/docs/source/assets/jvet/perf_complexity_classC.png
index 81c0fd52..ae52b65d 100644
Binary files a/docs/source/assets/jvet/perf_complexity_classC.png and b/docs/source/assets/jvet/perf_complexity_classC.png differ
diff --git a/docs/source/assets/jvet/perf_complexity_classD.png b/docs/source/assets/jvet/perf_complexity_classD.png
index 790b997d..1dc70366 100644
Binary files a/docs/source/assets/jvet/perf_complexity_classD.png and b/docs/source/assets/jvet/perf_complexity_classD.png differ
diff --git a/docs/source/assets/jvet/perf_complexity_classE.png b/docs/source/assets/jvet/perf_complexity_classE.png
index 4e9d6b92..0df7e289 100644
Binary files a/docs/source/assets/jvet/perf_complexity_classE.png and b/docs/source/assets/jvet/perf_complexity_classE.png differ
diff --git a/docs/source/assets/jvet/perf_complexity_classF.png b/docs/source/assets/jvet/perf_complexity_classF.png
index 4c0943a7..2c198731 100644
Binary files a/docs/source/assets/jvet/perf_complexity_classF.png and b/docs/source/assets/jvet/perf_complexity_classF.png differ
diff --git a/docs/source/assets/jvet/perf_decoding_time_classB.png b/docs/source/assets/jvet/perf_decoding_time_classB.png
index a0063b91..87d595de 100644
Binary files a/docs/source/assets/jvet/perf_decoding_time_classB.png and b/docs/source/assets/jvet/perf_decoding_time_classB.png differ
diff --git a/docs/source/assets/jvet/perf_decoding_time_classBCDEF.png b/docs/source/assets/jvet/perf_decoding_time_classBCDEF.png
index 6fd0daeb..003e3d6c 100644
Binary files a/docs/source/assets/jvet/perf_decoding_time_classBCDEF.png and b/docs/source/assets/jvet/perf_decoding_time_classBCDEF.png differ
diff --git a/docs/source/assets/jvet/perf_decoding_time_classC.png b/docs/source/assets/jvet/perf_decoding_time_classC.png
index 14699d8e..d50eb4f4 100644
Binary files a/docs/source/assets/jvet/perf_decoding_time_classC.png and b/docs/source/assets/jvet/perf_decoding_time_classC.png differ
diff --git a/docs/source/assets/jvet/perf_decoding_time_classD.png b/docs/source/assets/jvet/perf_decoding_time_classD.png
index d6365075..b0253a86 100644
Binary files a/docs/source/assets/jvet/perf_decoding_time_classD.png and b/docs/source/assets/jvet/perf_decoding_time_classD.png differ
diff --git a/docs/source/assets/jvet/perf_decoding_time_classE.png b/docs/source/assets/jvet/perf_decoding_time_classE.png
index 37992f72..1c95da44 100644
Binary files a/docs/source/assets/jvet/perf_decoding_time_classE.png and b/docs/source/assets/jvet/perf_decoding_time_classE.png differ
diff --git a/docs/source/assets/jvet/perf_decoding_time_classF.png b/docs/source/assets/jvet/perf_decoding_time_classF.png
index 7b3ce8b7..e4215494 100644
Binary files a/docs/source/assets/jvet/perf_decoding_time_classF.png and b/docs/source/assets/jvet/perf_decoding_time_classF.png differ
diff --git a/docs/source/assets/jvet/rd_classB.png b/docs/source/assets/jvet/rd_classB.png
index 94bac994..87106368 100644
Binary files a/docs/source/assets/jvet/rd_classB.png and b/docs/source/assets/jvet/rd_classB.png differ
diff --git a/docs/source/assets/jvet/rd_classBCDEF.png b/docs/source/assets/jvet/rd_classBCDEF.png
index 911b2725..2e53db0f 100644
Binary files a/docs/source/assets/jvet/rd_classBCDEF.png and b/docs/source/assets/jvet/rd_classBCDEF.png differ
diff --git a/docs/source/assets/jvet/rd_classC.png b/docs/source/assets/jvet/rd_classC.png
index 6097c9b7..e8e74172 100644
Binary files a/docs/source/assets/jvet/rd_classC.png and b/docs/source/assets/jvet/rd_classC.png differ
diff --git a/docs/source/assets/jvet/rd_classD.png b/docs/source/assets/jvet/rd_classD.png
index 5f244448..36bdb3ca 100644
Binary files a/docs/source/assets/jvet/rd_classD.png and b/docs/source/assets/jvet/rd_classD.png differ
diff --git a/docs/source/assets/jvet/rd_classE.png b/docs/source/assets/jvet/rd_classE.png
index 9eded54c..8ac86b3e 100644
Binary files a/docs/source/assets/jvet/rd_classE.png and b/docs/source/assets/jvet/rd_classE.png differ
diff --git a/docs/source/assets/jvet/rd_classF.png b/docs/source/assets/jvet/rd_classF.png
index dbc51d86..6e432665 100644
Binary files a/docs/source/assets/jvet/rd_classF.png and b/docs/source/assets/jvet/rd_classF.png differ
diff --git a/docs/source/assets/kodak/all_complexity.png b/docs/source/assets/kodak/all_complexity.png
new file mode 100644
index 00000000..58af51e8
Binary files /dev/null and b/docs/source/assets/kodak/all_complexity.png differ
diff --git a/docs/source/assets/kodak/concat_img.png b/docs/source/assets/kodak/concat_img.png
new file mode 100644
index 00000000..b71e96ad
Binary files /dev/null and b/docs/source/assets/kodak/concat_img.png differ
diff --git a/docs/source/assets/kodak/perf_complexity.png b/docs/source/assets/kodak/perf_complexity.png
index f9bf3d77..cba351b9 100644
Binary files a/docs/source/assets/kodak/perf_complexity.png and b/docs/source/assets/kodak/perf_complexity.png differ
diff --git a/docs/source/assets/kodak/perf_decoding_time.png b/docs/source/assets/kodak/perf_decoding_time.png
index 59697e07..8b6a6f34 100644
Binary files a/docs/source/assets/kodak/perf_decoding_time.png and b/docs/source/assets/kodak/perf_decoding_time.png differ
diff --git a/docs/source/assets/kodak/rd.png b/docs/source/assets/kodak/rd.png
index de7e3d26..508d15bd 100644
Binary files a/docs/source/assets/kodak/rd.png and b/docs/source/assets/kodak/rd.png differ
diff --git a/docs/source/assets/logo_concat.png b/docs/source/assets/logo_concat.png
new file mode 100644
index 00000000..2beac2fe
Binary files /dev/null and b/docs/source/assets/logo_concat.png differ
diff --git a/docs/source/assets/rd-image-clic20-validpro.png b/docs/source/assets/rd-image-clic20-validpro.png
deleted file mode 100644
index 4b74f68a..00000000
Binary files a/docs/source/assets/rd-image-clic20-validpro.png and /dev/null differ
diff --git a/docs/source/assets/rd-image-jvet.png b/docs/source/assets/rd-image-jvet.png
deleted file mode 100644
index a9963b0f..00000000
Binary files a/docs/source/assets/rd-image-jvet.png and /dev/null differ
diff --git a/docs/source/assets/rd-video-ldp-clic24-validsubset.png b/docs/source/assets/rd-video-ldp-clic24-validsubset.png
deleted file mode 100644
index dbeb0e2e..00000000
Binary files a/docs/source/assets/rd-video-ldp-clic24-validsubset.png and /dev/null differ
diff --git a/docs/source/assets/rd-video-ra-clic24-validsubset.png b/docs/source/assets/rd-video-ra-clic24-validsubset.png
deleted file mode 100644
index 544135ac..00000000
Binary files a/docs/source/assets/rd-video-ra-clic24-validsubset.png and /dev/null differ
diff --git a/docs/source/code_documentation/encoder/component/core/index.rst b/docs/source/code_documentation/encoder/component/core/index.rst
index 71644ace..9d63a1f9 100644
--- a/docs/source/code_documentation/encoder/component/core/index.rst
+++ b/docs/source/code_documentation/encoder/component/core/index.rst
@@ -8,7 +8,7 @@ how these modules are related.
.. toctree::
- ARM
- Quantizer
- Upsampling
- Synthesis
+ arm
+ quantizer
+ upsampling
+ synthesis
diff --git a/docs/source/code_documentation/encoder/component/core/upsampling.rst b/docs/source/code_documentation/encoder/component/core/upsampling.rst
index 1f908985..5e0f192a 100644
--- a/docs/source/code_documentation/encoder/component/core/upsampling.rst
+++ b/docs/source/code_documentation/encoder/component/core/upsampling.rst
@@ -9,5 +9,8 @@ Upsampling
.. autoclass:: Upsampling
:members:
-.. autoclass:: UpsamplingConvTranspose2d
+.. autoclass:: UpsamplingSeparableSymmetricConvTranspose2d
+ :members:
+
+.. autoclass:: UpsamplingSeparableSymmetricConv2d
:members:
diff --git a/docs/source/code_documentation/encoder/component/index.rst b/docs/source/code_documentation/encoder/component/index.rst
index 6eca23d0..0a18f68e 100644
--- a/docs/source/code_documentation/encoder/component/index.rst
+++ b/docs/source/code_documentation/encoder/component/index.rst
@@ -3,25 +3,26 @@ Component
Components gather all the modules of the enc.
- * ``VideoEncoder`` is used to compress a video *i.e.* one intra frame
- followed by zero or more inter frames. It contains one or more
- ``FrameEncoder``.
+ * ``video.py`` implements the class ``VideoEncoder`` which is used to compress a
+ video *i.e.* one intra frame followed by zero or more inter frames. It
+ contains one or more instances of ``FrameEncoder``.
- * ``FrameEncoder`` is used to compress a frame. It contains one
- ``CoolChicEncoder``.
+ * ``frame.py`` implements the class ``FrameEncoder``, which is used to compress a
+ frame. To do so, it relies on one ``CoolChicEncoder``.
- * ``CoolChicEncoder`` is the main coding engine of the codec. It is composed
- of latent grids, an auto-regressive module, an upsampling and a synthesis.
+ * ``coolchic.py`` implements the class ``CoolChicEncoder``, which is the main
+ coding engine of the codec. It is composed of latent grids, an
+ auto-regressive module, an upsampling and a synthesis.
- * ``Core`` contains all the modules stated above and required by a
+ * ``core/`` contains all the modules stated above and required by a
``CoolChicEncoder``
.. toctree::
- VideoEncoder