Original paper: http://arxiv.org/abs/1809.02736
This repository is the implementation of the training and evaluation code for the model above. It was written with the intention of extending this repository through further research.
I wrote this with reference to the following two codebases.
First, install PyTorch, the CUDA toolkit, and cuDNN in versions that match your GPU(s). Then install the libraries listed in `requirements.txt` with the command `pip install -r requirements.txt`.
Create the dataset directory structure shown below:
├─data
│ ├─train
│ ├─test
│ └─val
└ ...
I used Flicker2W, DIV2K, and CLIC2020 for training. (The Flicker2W dataset alone is sufficient to train the model.)
- The Flicker2W dataset can be found in liujiaheng's repository
- 'Train data (HR images)' in DIV2K
- 'Training Dataset P' & 'Training Dataset M' in CLIC2020
Data pre-processing to remove JPEG compression artifacts is performed automatically at training time by the customized `Dataset` class in `basic.py`.
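As a rough illustration, here is a minimal sketch of a PyTorch `Dataset` that applies one common recipe for suppressing JPEG artifacts: randomly downsampling large images before cropping a training patch. The class name, file-matching pattern, and downsampling policy are illustrative assumptions, not the actual API of `basic.py`:

```python
import random
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class TrainImageDataset(Dataset):
    """Illustrative sketch only; see basic.py for the real implementation."""

    def __init__(self, root, patch_size=256):
        self.paths = sorted(Path(root).glob("*.jpg")) + sorted(Path(root).glob("*.png"))
        self.patch_size = patch_size
        self.crop = transforms.RandomCrop(patch_size, pad_if_needed=True)
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        # Randomly downsample large images: resampling attenuates the
        # high-frequency blocking/ringing artifacts left by JPEG compression.
        if min(img.size) >= 2 * self.patch_size and random.random() < 0.5:
            w, h = img.size
            img = img.resize((w // 2, h // 2), Image.BICUBIC)
        # Random-crop a fixed-size training patch.
        return self.to_tensor(self.crop(img))
```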
For evaluation, I used the 24 images of the Kodak24 dataset.
For validation, you can use any dataset, and validation is optional. (Using about 50 images from the training set is also a good idea.)
The model is a mean-scale hyperprior image compression model that uses a GMM (Gaussian Mixture Model) as the entropy model instead of the GSM (Gaussian Scale Mixture) in J. Ballé's paper.
The model has 8 quality levels, set by the hyperparameter lambda, which controls the trade-off between distortion and bit-rate.
I used lambda = [64, 128, 256, 512, 1024, 2048, 4096, 8192] for the 8 different models.
The 4 low-quality models use convolution layers with N=192 and M=192 channels; the 4 high-quality models use N=192 and M=320.
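For reference, a GMM entropy model evaluates the probability mass of each quantized latent under a mixture of Gaussians whose weights, means, and scales are predicted per element. The following sketch shows that computation in PyTorch (tensor shapes and function names are my assumptions, not this repository's actual interface):

```python
import torch


def gmm_likelihood(y_hat, weights, means, scales, eps=1e-9):
    """Likelihood of quantized latents under a K-component Gaussian mixture.

    y_hat:   (B, M, H, W)    quantized latents
    weights: (B, K, M, H, W) mixture weights, softmax-normalized over K
    means:   (B, K, M, H, W) per-component means
    scales:  (B, K, M, H, W) per-component standard deviations (> 0)
    """
    y = y_hat.unsqueeze(1)  # broadcast over the K mixture components
    dist = torch.distributions.Normal(means, scales)
    # Probability mass of the quantization bin [y - 0.5, y + 0.5].
    p = dist.cdf(y + 0.5) - dist.cdf(y - 0.5)
    # Mix over components and clamp for numerical stability.
    return (weights * p).sum(dim=1).clamp(min=eps)


# Rate term of the loss: expected bits = -log2(likelihood), summed over elements.
# bits = -torch.log2(gmm_likelihood(y_hat, w, mu, sigma)).sum()
```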
You can train the model with the command `python train.py` at the root directory; `train.py` creates a `Solver` instance and calls its `train` method.
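The entry point is therefore roughly equivalent to the following sketch (the `Solver(config)` constructor signature is an assumption; check the actual files):

```python
# train.py, in essence (sketch; the actual file may differ)
from solver import Solver
import config

if __name__ == "__main__":
    solver = Solver(config)  # assumed constructor signature
    solver.train()
```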
Before running it, you have to modify `config.py` to suit your purpose.
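For example, a configuration for the highest-quality model might look like the sketch below. All field names here are hypothetical assumptions; check the actual `config.py` for the real ones:

```python
# Hypothetical config.py sketch; the real field names may differ.
lmbda = 8192  # R-D trade-off: one of [64, 128, 256, 512, 1024, 2048, 4096, 8192]
N = 192       # channels of the analysis/synthesis transforms
M = 320       # latent channels: 320 for the 4 high-rate models, 192 for low-rate
total_steps = 1_400_000  # total training batches
train_dir = "data/train"
val_dir = "data/val"
test_dir = "data/test"
```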
To train the 8 different models, first train the highest-quality (quality 8) model, then fine-tune the other models from it.
Total training steps (batches): 1400K, with a step learning-rate schedule of [1e-4, 5e-5, 1e-5, 5e-6, 1e-6] that drops at steps [1100K, 1300K, 1350K, 1400K]. (This schedule is implemented in the method `train` in `solver.py`.)
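That schedule corresponds to a step function like the following sketch (the actual implementation lives in `Solver.train`; the function name here is mine):

```python
# Step learning-rate schedule described above: each rate applies until
# its milestone; the final rate (1e-6) applies from the last milestone on.
MILESTONES = [1_100_000, 1_300_000, 1_350_000, 1_400_000]
LRS = [1e-4, 5e-5, 1e-5, 5e-6, 1e-6]


def learning_rate(step):
    for milestone, lr in zip(MILESTONES, LRS):
        if step < milestone:
            return lr
    return LRS[-1]
```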
For fine-tuning, I used the highest-quality model's pre-trained weights from the 900K-step checkpoint.
The mismatch in the number of channels between the high-rate and low-rate models is handled for fine-tuning by the model-loader method in `solver.py`.
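A loader of that kind typically copies only the tensors whose shapes match, as in this sketch (function name and details are assumptions; the real logic is in `solver.py`):

```python
import torch


def load_partial_weights(model, ckpt_path):
    """Load pre-trained weights, skipping tensors whose shapes differ
    (e.g. initializing an M=192 low-rate model from an M=320 checkpoint)."""
    pretrained = torch.load(ckpt_path, map_location="cpu")
    own = model.state_dict()
    matched = {k: v for k, v in pretrained.items()
               if k in own and v.shape == own[k].shape}
    own.update(matched)
    model.load_state_dict(own)
```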
You can test the model with the command `python test.py` at the root directory; `test.py` creates a `Solver` instance and calls its `test` method.
The test result is saved in `result\test.txt`.
Before running it, you have to modify `config.py` to suit your purpose.