Mean-Scale Hyperprior Image Compression Model (w/o Context model)

Original paper: http://arxiv.org/abs/1809.02736

Introduction

This repository contains the training and evaluation code for the model above. It was written as groundwork for extending this repository through further research.

I wrote this with reference to two existing implementations.


Environment

First, install PyTorch, the CUDA toolkit, and cuDNN in versions that match your GPU(s).

Then, install the libraries in requirements.txt with the command `pip install -r requirements.txt`.

Dataset

Create the dataset directory structure below:

```
├─data
│  ├─train
│  ├─test
│  └─val
└─ ...
```
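For example, these directories can be created with a few lines of Python (a trivial sketch; any method works):

```python
import os

# Create the data/train, data/test, and data/val directories.
for split in ("train", "test", "val"):
    os.makedirs(os.path.join("data", split), exist_ok=True)
```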

I used Flicker2W, DIV2K, and CLIC2020 for training. (The Flicker2W dataset alone is sufficient to train the model.)

Data pre-processing to remove JPEG compression artifacts is performed automatically during training by the customized Dataset class in basic.py.
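The exact basic.py is not reproduced here; the sketch below shows one common way to attenuate JPEG block artifacts, downsampling each image before the random crop. The class name, the scale factor, and the crop size are all assumptions, not the repository's actual code:

```python
import os
import random

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class ArtifactReducedDataset(Dataset):
    """Hypothetical stand-in for the Dataset class in basic.py."""

    def __init__(self, root, crop_size=256, scale=0.5):
        self.paths = sorted(os.path.join(root, f) for f in os.listdir(root))
        self.crop_size = crop_size
        self.scale = scale  # assumed downsampling factor
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        # Downsampling suppresses high-frequency JPEG block artifacts.
        # (Assumes source images are large enough to crop afterwards.)
        w, h = img.size
        img = img.resize((int(w * self.scale), int(h * self.scale)), Image.BICUBIC)
        # Random crop to the training patch size.
        x = random.randint(0, img.width - self.crop_size)
        y = random.randint(0, img.height - self.crop_size)
        img = img.crop((x, y, x + self.crop_size, y + self.crop_size))
        return self.to_tensor(img)
```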

For evaluation, I used the 24 images of the Kodak24 dataset.

For validation, you can use any dataset, and validation is optional. (Using about 50 images from the training set is also a good idea.)

Model

The model is the mean-scale hyperprior image compression model: the entropy model is a conditional Gaussian with predicted mean and scale, replacing the zero-mean GSM (Gaussian Scale Mixture) used in J. Ballé's earlier hyperprior paper.
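Concretely, the rate for each quantized latent is the negative log-probability of its unit-width quantization bin under the conditional Gaussian predicted by the hyper-decoder. A minimal sketch of that rate term (assuming rounding-based quantization, which is standard for this family of models):

```python
import torch


def gaussian_rate_bits(y_hat, mu, sigma, eps=1e-9):
    """Rate (in bits) of rounded latents y_hat under N(mu, sigma^2)."""
    normal = torch.distributions.Normal(mu, sigma)
    # Probability mass of the quantization bin [y_hat - 0.5, y_hat + 0.5].
    likelihood = normal.cdf(y_hat + 0.5) - normal.cdf(y_hat - 0.5)
    # -log2(p), clamped to avoid log(0).
    return -torch.log2(likelihood.clamp(min=eps))
```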

The model has 8 quality levels, set by the hyperparameter lambda, which controls the trade-off between distortion and bit-rate.

I used lambda = [64, 128, 256, 512, 1024, 2048, 4096, 8192] for the 8 different models.

The 4 low-quality models use convolution layers with channel counts N=192, M=192; the 4 high-quality models use N=192, M=320.
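As a hypothetical config.py fragment (the actual file layout may differ), the lambda-to-channel mapping amounts to:

```python
# Quality levels and their channel widths (N, M).
LAMBDAS = [64, 128, 256, 512, 1024, 2048, 4096, 8192]

def channels_for(lmbda):
    # The 4 lowest lambdas use M=192; the 4 highest use M=320.
    return (192, 192) if lmbda <= 512 else (192, 320)
```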

Training

You can train the model with the command `python train.py` at the root directory; train.py creates a Solver instance and calls its train method.

Before that, modify config.py to suit your purpose.
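In outline, the entry point described above amounts to something like the following (the constructor arguments are an assumption; see the actual train.py):

```python
# Sketch of train.py's described behavior, not the file itself.
from config import config   # hypothetical: however config.py exposes its settings
from solver import Solver

if __name__ == "__main__":
    solver = Solver(config)  # builds the model, data loaders, and optimizer
    solver.train()           # runs the training loop described below
```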

To train the 8 different models, first train the highest-quality (lambda = 8192) model, then fine-tune the other models from it.

Total training steps (batches): 1400K. Training starts at learning rate 1e-4 and drops to 5e-5, 1e-5, 5e-6, and 1e-6 at steps 1100K, 1300K, 1350K, and 1400K, respectively. (The schedule is implemented in the train method in solver.py.)
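A minimal sketch of that schedule (the repository implements it inside Solver.train; this is just one equivalent way to write it):

```python
from bisect import bisect_right

RATES = [1e-4, 5e-5, 1e-5, 5e-6, 1e-6]
MILESTONES = [1_100_000, 1_300_000, 1_350_000, 1_400_000]

def lr_at(step):
    # Rate i applies until MILESTONES[i]; the last rate applies afterwards.
    return RATES[bisect_right(MILESTONES, step)]

def set_lr(optimizer, step):
    # Apply the scheduled rate to every parameter group.
    for group in optimizer.param_groups:
        group["lr"] = lr_at(step)
```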

For fine-tuning, I used the highest-quality model's pre-trained weights at 900K steps.

The mismatch in channel counts between the high-rate and low-rate models is handled during fine-tuning by the model-loader method in solver.py.
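That loader is not reproduced verbatim here; a common way to handle the mismatch, and my assumption of what it does, is to copy only the tensors whose names and shapes match:

```python
import torch


def load_matching_weights(model, ckpt_path):
    """Copy only shape-compatible tensors, so a low-rate model (M=192)
    can be initialized from the high-rate checkpoint (M=320)."""
    state = torch.load(ckpt_path, map_location="cpu")
    own = model.state_dict()
    compatible = {k: v for k, v in state.items()
                  if k in own and v.shape == own[k].shape}
    own.update(compatible)
    model.load_state_dict(own)
```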

Evaluation

You can test the model with the command `python test.py` at the root directory; test.py creates a Solver instance and calls its test method.

The test result is saved in result\test.txt

Before that, modify config.py to suit your purpose.
