From 6bca5440ed036c6b40a82b0ba0d5991159f144b3 Mon Sep 17 00:00:00 2001 From: Boyang Wang Date: Sat, 23 Mar 2024 23:56:49 -0400 Subject: [PATCH] docs: little update --- test_code/inference.py | 9 +++++++-- train_code/train.py | 1 + 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/test_code/inference.py b/test_code/inference.py index 244f5ec..8a904f1 100644 --- a/test_code/inference.py +++ b/test_code/inference.py @@ -2,6 +2,7 @@ This is file is to execute the inference for a single image or a folder input ''' import argparse +import time import os, sys, cv2, shutil, warnings import torch from torchvision.transforms import ToTensor @@ -76,6 +77,7 @@ def super_resolve_img(generator, input_path, output_path=None, weight_dtype=torc # Sample Command # 4x GRL (Default): python test_code/inference.py --model GRL --scale 4 --weight_path pretrained/4x_APISR_GRL_GAN_generator.pth + # 4x RRDB: python test_code/inference.py --model RRDB --scale 4 --weight_path pretrained/4x_APISR_RRDB_GAN_generator.pth # 2x RRDB: python test_code/inference.py --model RRDB --scale 2 --weight_path pretrained/2x_APISR_RRDB_GAN_generator.pth @@ -116,8 +118,9 @@ def super_resolve_img(generator, input_path, output_path=None, weight_dtype=torc elif model == "RRDB": generator = load_rrdb(weight_path, scale=scale) # Can be any size generator = generator.to(dtype=weight_dtype) - + + start = time.time() # Take the input path and do inference if os.path.isdir(store_dir): # If the input is a directory, we will iterate it for filename in sorted(os.listdir(input_dir)): @@ -131,7 +134,9 @@ def super_resolve_img(generator, input_path, output_path=None, weight_dtype=torc output_path = os.path.join(store_dir, filename+"_"+str(scale)+"x.png") # In default, we will automatically use crop to match 4x size super_resolve_img(generator, input_dir, output_path, weight_dtype, crop_for_4x=True) - + end = time.time() + + print("Total inference time spent is ", end-start) diff --git a/train_code/train.py b/train_code/train.py index 9496790..cff6e01 100644 --- a/train_code/train.py +++ b/train_code/train.py @@ -40,6 +40,7 @@ def parse_args(): if not args.auto_resume_closest and not args.auto_resume_best: # Restart tensorboard (delete all things under ./runs) + print("We will remove the log of tensorboard.") if os.path.exists("./runs"): storage_manage() shutil.rmtree("./runs")