diff --git a/test.ipynb b/test.ipynb index 5ea1944..a938818 100755 --- a/test.ipynb +++ b/test.ipynb @@ -55,9 +55,6 @@ "import os\n", "import numpy as np\n", "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "import torch.utils.data as data\n", "import yaml\n", "\n", "import matplotlib\n", @@ -71,10 +68,8 @@ "from PIL import Image\n", "from torchvision import transforms, utils\n", "\n", - "from datasets import *\n", - "from nets import *\n", - "from functions import *\n", - "from trainer import *" + "from functions import clip_img\n", + "from trainer import Trainer" ], "execution_count": 0, "outputs": [] @@ -102,7 +97,7 @@ "if not os.path.exists(opts.out_path):\n", " os.makedirs(opts.out_path)\n", "\n", - "config = yaml.load(open('./configs/' + opts.config + '.yaml', 'r'))\n", + "config = yaml.safe_load(open('./configs/' + opts.config + '.yaml', 'r'))\n", "img_size = (config['input_w'], config['input_h'])\n", "\n", "# Initialize trainer\n", @@ -110,23 +105,23 @@ "device = torch.device('cuda')\n", "trainer.to(device)\n", "\n", - "# Load pretrained model \n", + "# Load pretrained model\n", "if opts.checkpoint:\n", " trainer.load_checkpoint(opts.checkpoint)\n", "else:\n", " trainer.load_checkpoint(log_dir + 'checkpoint')\n", "\n", + "\n", "def preprocess(img_name):\n", " resize = transforms.Compose([\n", " transforms.Resize(img_size),\n", " transforms.ToTensor()\n", " ])\n", - " normalize = transforms.Normalize(mean=[0.48501961, 0.45795686, 0.40760392], std=[1,1,1])\n", + " normalize = transforms.Normalize(mean=[0.48501961, 0.45795686, 0.40760392], std=[1, 1, 1])\n", " img_pil = Image.open(opts.img_path + img_name)\n", - " img_np = np.array(img_pil)\n", " img = resize(img_pil)\n", " if img.size(0) == 1:\n", - " img = torch.cat((img, img, img), dim = 0)\n", + " img = torch.cat((img, img, img), dim=0)\n", " img = normalize(img)\n", " return img" ], @@ -157,14 +152,14 @@ " image_A = image_A.unsqueeze(0).to(device)\n", "\n", " age_modif = torch.tensor(target_age).unsqueeze(0).to(device)\n", - " image_A_modif = trainer.test_eval(image_A, age_modif, target_age=target_age, hist_trans=True) \n", + " image_A_modif = trainer.test_eval(image_A, age_modif, target_age=target_age, hist_trans=True)\n", " utils.save_image(clip_img(image_A_modif), opts.out_path + img_name.split('.')[0] + '_age_' + str(target_age) + '.jpg')\n", "\n", " # Plot manipulated image\n", " img_out = np.array(Image.open(opts.out_path + img_name.split('.')[0] + '_age_' + str(target_age) + '.jpg'))\n", " plt.axis('off')\n", " plt.imshow(img_out)\n", - " plt.show() " + " plt.show()" ], "execution_count": 0, "outputs": []