add training #19

Open · wants to merge 1 commit into main
File renamed without changes.
205 changes: 205 additions & 0 deletions dataset/training.py
@@ -0,0 +1,205 @@
import json
import math
from enum import Enum
from typing import Optional, List, Tuple

import numpy as np
import megengine as mge
import megengine.random
import megengine.functional as F

from pydantic import BaseModel
from megfile import SmartPath, smart_load_from
from megengine.data.dataset import Dataset


class BayerPattern(str, Enum):  # the str mixin must precede Enum
    RGGB = "RGGB"
    BGGR = "BGGR"
    GRBG = "GRBG"
    GBRG = "GBRG"


class RawImageItem(BaseModel):
    path: str
    width: int
    height: int
    black_level: int
    white_level: int = 65535
    bayer_pattern: BayerPattern
    g_mean_01: float


class NoiseProfile(BaseModel):
    K: Tuple[float, float]
    B: Tuple[float, float, float]
    value_scale: float = 959.0


class DataAugOptions(BaseModel):
    noise_profile: NoiseProfile
    camera_value_scale: float = 959.0
    iso_range: Tuple[float, float]
    anchor_iso: float = 1600.0
    output_shape: Tuple[int, int] = (512, 512)  # 512x512x4
    target_brightness_range: Tuple[float, float] = (0.02, 0.5)
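# Note: DataAugOptions is the schema for the JSON file passed to train.py via
# --data-aug-config (loaded there with DataAugOptions.parse_file); an example
# config is sketched at the end of this diff.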


class CleanRawImages(Dataset):

    def __init__(self, *, index_file: Optional[str] = None, data_dir: Optional[SmartPath] = None, opts: DataAugOptions):
        """
        Args:
        - data_dir: a directory that contains "index.json" and the raw images
        - index_file: the absolute path to the index file
        """
        super().__init__()

        assert not (index_file is None and data_dir is None)

        if data_dir is None:
            index_file = SmartPath(index_file)
        else:
            assert index_file is None
            index_file = data_dir / "index.json"

        self.opts = opts
        self.filelist: List[RawImageItem] = []
        with index_file.open() as f:
            items = [RawImageItem.parse_obj(x) for x in json.load(f)]
            for item in items:
                if data_dir is not None:
                    item.path = str(data_dir / item.path)
                self.filelist.append(item)

    def __len__(self):
        return len(self.filelist)

    def random_flip_and_crop(self, img: np.ndarray, src_bayer_pattern: BayerPattern) -> np.ndarray:
        """
        Randomly flip and crop a Bayer-patterned image, normalizing the
        Bayer pattern to RGGB.
        """

        flip_ud = np.random.rand() > 0.5
        flip_lr = np.random.rand() > 0.5

        # Shifting the crop origin by one pixel changes which CFA color sits
        # at the top-left of each 2x2 tile; the offsets below move the origin
        # so the crop reads as RGGB. A flip mirrors the tile along that axis,
        # so the corresponding offset is toggled as well.
        if src_bayer_pattern == BayerPattern.RGGB:
            crop_x_offset, crop_y_offset = 0, 0
        elif src_bayer_pattern == BayerPattern.GBRG:
            crop_x_offset, crop_y_offset = 0, 1
        elif src_bayer_pattern == BayerPattern.GRBG:
            crop_x_offset, crop_y_offset = 1, 0
        elif src_bayer_pattern == BayerPattern.BGGR:
            crop_x_offset, crop_y_offset = 1, 1

        if flip_lr:
            crop_x_offset = (crop_x_offset + 1) % 2
        if flip_ud:
            crop_y_offset = (crop_y_offset + 1) % 2

        H0, W0 = img.shape
        tH, tW = self.opts.output_shape

        x0, y0 = np.random.randint(0, W0 - tW), np.random.randint(0, H0 - tH)
        # snap to even coordinates, then apply the pattern-normalizing offset
        x0, y0 = x0 // 2 * 2 + crop_x_offset, y0 // 2 * 2 + crop_y_offset

        img_crop = img[y0:y0+tH, x0:x0+tW]
        if flip_lr:
            img_crop = np.flip(img_crop, axis=1)
        if flip_ud:
            img_crop = np.flip(img_crop, axis=0)

        return img_crop

    def __getitem__(self, index: int):
        item = self.filelist[index]
        buf = smart_load_from(item.path)
        # smart_load_from returns a file-like object that may not be backed by
        # a real file descriptor, so decode its bytes instead of np.fromfile
        rawimg = np.frombuffer(buf.read(), dtype=np.uint16).reshape((item.height, item.width))
        # random crop to output size
        rawimg = self.random_flip_and_crop(rawimg, item.bayer_pattern).astype(np.float32)

        raw01 = (rawimg - item.black_level) / (item.white_level - item.black_level)
        H, W = raw01.shape
        # pixel shuffle to RGGB image
        rggb01 = raw01.reshape(H//2, 2, W//2, 2).transpose(0, 2, 1, 3).reshape(H//2, W//2, 4)
        return rggb01, np.array(item.g_mean_01)


class NoiseProfileFunc:

    def __init__(self, noise_profile: NoiseProfile):
        # polynomials mapping ISO to the shot-noise gain K and the read-noise
        # variance B, calibrated at noise_profile.value_scale
        self.polyK = np.poly1d(noise_profile.K)
        self.polyB = np.poly1d(noise_profile.B)
        self.value_scale = noise_profile.value_scale

    def __call__(self, iso, value_scale=959.0):
        # rescale from the calibration range to the requested one:
        # k is a gain (linear in the scale), b is a variance (quadratic)
        r = value_scale / self.value_scale
        k = self.polyK(iso) * r
        b = self.polyB(iso) * r * r

        return k, b


class DataAug:

    def __init__(self, opts: DataAugOptions):
        self.opts = opts
        self.noise_func = NoiseProfileFunc(opts.noise_profile)

    def transform(self, batch_img01: np.ndarray, batch_g_mean: np.ndarray) -> Tuple[mge.Tensor, mge.Tensor, mge.Tensor]:
        """
        Args:
        - batch_img01: images in [-black/camera_value_scale, 1.0]
        - batch_g_mean: per-image green-channel means in [0, 1]

        Returns:
        - noisy images
        - ground-truth images
        - cvt_k (the per-image k-sigma gain)
        """

        batch_imgs = mge.tensor(batch_img01) * self.opts.camera_value_scale
        batch_gt = self.brightness_aug(batch_imgs, batch_g_mean)
        batch_imgs, batch_iso = self.add_noise(batch_gt)
        cvt_k, cvt_b = self.k_sigma(batch_iso)
        # reshape for broadcasting over the NHWC batch
        cvt_k = cvt_k.reshape(-1, 1, 1, 1)
        cvt_b = cvt_b.reshape(-1, 1, 1, 1)

        batch_imgs = batch_imgs * cvt_k + cvt_b
        batch_gt = batch_gt * cvt_k + cvt_b
        return (batch_imgs, batch_gt, cvt_k)

    def k_sigma(self, iso: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        # k-sigma transform: affinely map a frame with noise parameters
        # (k, sigma) so its noise statistics match those at the anchor ISO
        k, sigma = self.noise_func(iso, value_scale=self.opts.camera_value_scale)
        k_a, sigma_a = self.noise_func(self.opts.anchor_iso, value_scale=self.opts.camera_value_scale)

        cvt_k = k_a / k
        cvt_b = (sigma / (k ** 2) - sigma_a / (k_a ** 2)) * k_a

        return cvt_k, cvt_b

    def brightness_aug(self, img_batch: mge.Tensor, orig_gmean: np.ndarray) -> mge.Tensor:
        low, high = self.opts.target_brightness_range
        N = len(orig_gmean)
        # sample a target green mean log-uniformly, then scale each frame
        # towards it; the gain is clipped to at most 1.0, so frames are only
        # ever darkened
        btarget = np.exp(np.random.uniform(np.log(low), np.log(high), size=(N,)))
        s = np.clip(btarget / orig_gmean, 0.01, 1.0)
        return img_batch * s.reshape(-1, 1, 1, 1)

    def add_noise(self, img: mge.Tensor) -> Tuple[mge.Tensor, np.ndarray]:
        """
        Args:
        - img: [-black, camera_value_scale]

        Returns:
        - noisy_img
        - iso
        """

        N = img.shape[0]
        isos = np.random.uniform(*self.opts.iso_range, size=(N,))
        k, b = self.noise_func(isos, value_scale=self.opts.camera_value_scale)
        k = k.reshape(-1, 1, 1, 1)
        b = b.reshape(-1, 1, 1, 1)

        # Poisson-Gaussian model: shot noise is Poisson in the photon domain
        # (img / k, which must be non-negative), scaled back by k; read noise
        # is additive Gaussian with variance b
        shot_noisy = megengine.random.poisson(F.clip(img / k, lower=0.0)) * k
        read_noisy = megengine.random.normal(size=img.shape) * np.sqrt(b)
        noisy = shot_noisy + read_noisy
        noisy = F.round(noisy)

        return noisy, isos
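
For reference, a minimal sketch of how the pieces above compose (the config path, data directory, and batch-of-one call are illustrative assumptions, not part of this diff):

```python
# Sketch: load one sample with CleanRawImages and push it through DataAug.
from pathlib import Path
from dataset.training import CleanRawImages, DataAug, DataAugOptions

opts = DataAugOptions.parse_file("configs/data_aug.json")  # hypothetical path
ds = CleanRawImages(data_dir=Path("./data"), opts=opts)    # expects ./data/index.json
aug = DataAug(opts)

rggb01, g_mean = ds[0]  # (H/2, W/2, 4) RGGB planes in [0, 1], plus the green mean
noisy, gt, norm_k = aug.transform(rggb01[None], g_mean[None])  # batch of one
```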
2 changes: 1 addition & 1 deletion run_benchmark.py
@@ -12,7 +12,7 @@

from models.net_mge import Network
from utils import RawUtils
-from benchmark import BenchmarkLoader, RawMeta
+from dataset.benchmark import BenchmarkLoader, RawMeta


class KSigma:
110 changes: 110 additions & 0 deletions train.py
@@ -0,0 +1,110 @@
#!/usr/bin/env python3
import os
import argparse
from pathlib import Path

import numpy as np
import megengine as mge
import megengine.optimizer
import megengine.functional as F
from megengine.data import DataLoader, RandomSampler
from megengine.autodiff import GradManager

from tqdm import tqdm
from loguru import logger

from models.net_mge import Network, get_loss_l1
from dataset.training import CleanRawImages, DataAug, DataAugOptions


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-aug-config', type=Path)
    parser.add_argument('--data-dir', type=Path)
    parser.add_argument('--batch-size', default=1, type=int)
    parser.add_argument('--ckp-dir', default=Path('./checkpoints'), type=Path)
    parser.add_argument('--learning-rate', dest='lr', default=1e-3, type=float)
    parser.add_argument('--num-epoch', default=4000, type=int)

    args = parser.parse_args()

    # Configure logger
    logger.configure(handlers=[dict(
        sink=lambda msg: tqdm.write(msg, end=''),
        format="[<green>{time:YYYY-MM-DD HH:mm:ss}</green>] [<level>{level}</level>] {message}",
        colorize=True
    )])

    # Create model
    net = Network()
    # Create optimizer
    optimizer = megengine.optimizer.Adam(net.parameters(), lr=args.lr)
    # Create GradManager
    gm = GradManager().attach(net.parameters())

    aug_opts = DataAugOptions.parse_file(args.data_aug_config)
    train_aug = DataAug(aug_opts)
    train_ds = CleanRawImages(data_dir=args.data_dir, opts=aug_opts)
    train_loader = DataLoader(train_ds, sampler=RandomSampler(train_ds, batch_size=args.batch_size, drop_last=True))

    # learning rate scheduler
    def adjust_learning_rate(opt, epoch, step):
        M = len(train_ds) // args.batch_size
        T = M * 100
        Th = T // 2

        # # warm up
        # if base_lr > 2e-3 and step < T:
        #     return 1e-4

        # base factor: linear decay over the first 3000 epochs, then fixed steps
        if epoch < 3000:
            f = 1 - step / (M * 3000)
        elif epoch < 5000:
            f = 0.2
        else:
            f = 0.1

        # cyclic factor: a triangular wave with a period of 100 epochs,
        # ramping 0 -> 1 over the first half of each cycle and back down
        # to 0 over the second half
        t = step % T
        if t < Th:
            f2 = t / Th
        else:
            f2 = 2 - (t / Th)

        lr = f * f2 * args.lr

        for pgroup in opt.param_groups:
            pgroup["lr"] = lr

        return lr

    # train step
    def train_step(img, gt, norm_k):
        with gm:
            pred = net(img)
            loss = get_loss_l1(pred, gt, norm_k)
            gm.backward(loss)
            optimizer.step().clear_grad()
        return loss

    # train loop
    global_step = 0
    for epoch in range(args.num_epoch):
        for bidx, (imgs, g_means) in enumerate(tqdm(train_loader, dynamic_ncols=True)):
            imgs, gt, norm_k = train_aug.transform(imgs, g_means)
            lr = adjust_learning_rate(optimizer, epoch, global_step)
            loss = train_step(imgs, gt, norm_k)

            if global_step % 100 == 0:
                logger.info(f"clock: {epoch},{bidx}, loss: {loss.item()}, lr: {lr}")

            global_step += 1

        # save checkpoint
        if epoch % 100 == 0:
            os.makedirs(args.ckp_dir, exist_ok=True)
            mge.save(net.state_dict(), args.ckp_dir / f"epoch_{epoch}.pkl")


if __name__ == "__main__":
main()
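
For completeness, a sketch of the JSON consumed by `--data-aug-config`. The field names follow the `DataAugOptions` model above; the numeric values are illustrative placeholders, not calibrated ones:

```python
# Write an illustrative data-aug config and validate it against the schema.
import json
from dataset.training import DataAugOptions

cfg = {
    "noise_profile": {
        "K": [5e-4, 1e-2],        # poly1d coefficients: k(iso) = 5e-4*iso + 1e-2
        "B": [1e-9, 1e-6, 1e-2],  # poly1d coefficients for the variance b(iso)
        "value_scale": 959.0,
    },
    "camera_value_scale": 959.0,
    "iso_range": [100.0, 6400.0],
    "anchor_iso": 1600.0,
    "output_shape": [512, 512],
    "target_brightness_range": [0.02, 0.5],
}

opts = DataAugOptions.parse_obj(cfg)  # raises if the shape is wrong
with open("configs/data_aug.json", "w") as f:  # hypothetical path
    json.dump(cfg, f, indent=2)
```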