From 3050ef1a969588e272b0d6e71b3fa4c0c61eead2 Mon Sep 17 00:00:00 2001 From: Aleksey Morozov <36787333+amrzv@users.noreply.github.com> Date: Thu, 12 Jan 2023 10:10:01 +0200 Subject: [PATCH 1/2] Fixed yaml loading --- test.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test.ipynb b/test.ipynb index 5ea1944..4ba2bb7 100755 --- a/test.ipynb +++ b/test.ipynb @@ -102,7 +102,7 @@ "if not os.path.exists(opts.out_path):\n", " os.makedirs(opts.out_path)\n", "\n", - "config = yaml.load(open('./configs/' + opts.config + '.yaml', 'r'))\n", + "config = yaml.safe_load(open('./configs/' + opts.config + '.yaml', 'r'))\n", "img_size = (config['input_w'], config['input_h'])\n", "\n", "# Initialize trainer\n", From 19096130021624d2a2860e313f28cedfdced7f75 Mon Sep 17 00:00:00 2001 From: Aleksey Morozov <36787333+amrzv@users.noreply.github.com> Date: Thu, 12 Jan 2023 10:20:23 +0200 Subject: [PATCH 2/2] Deleted unused imports --- test.ipynb | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/test.ipynb b/test.ipynb index 4ba2bb7..a938818 100755 --- a/test.ipynb +++ b/test.ipynb @@ -55,9 +55,6 @@ "import os\n", "import numpy as np\n", "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "import torch.utils.data as data\n", "import yaml\n", "\n", "import matplotlib\n", @@ -71,10 +68,8 @@ "from PIL import Image\n", "from torchvision import transforms, utils\n", "\n", - "from datasets import *\n", - "from nets import *\n", - "from functions import *\n", - "from trainer import *" + "from functions import clip_img\n", + "from trainer import Trainer" ], "execution_count": 0, "outputs": [] @@ -110,23 +105,23 @@ "device = torch.device('cuda')\n", "trainer.to(device)\n", "\n", - "# Load pretrained model \n", + "# Load pretrained model\n", "if opts.checkpoint:\n", " trainer.load_checkpoint(opts.checkpoint)\n", "else:\n", " trainer.load_checkpoint(log_dir + 'checkpoint')\n", "\n", + "\n", "def preprocess(img_name):\n", " resize = transforms.Compose([\n", " transforms.Resize(img_size),\n", " transforms.ToTensor()\n", " ])\n", - " normalize = transforms.Normalize(mean=[0.48501961, 0.45795686, 0.40760392], std=[1,1,1])\n", + " normalize = transforms.Normalize(mean=[0.48501961, 0.45795686, 0.40760392], std=[1, 1, 1])\n", " img_pil = Image.open(opts.img_path + img_name)\n", - " img_np = np.array(img_pil)\n", " img = resize(img_pil)\n", " if img.size(0) == 1:\n", - " img = torch.cat((img, img, img), dim = 0)\n", + " img = torch.cat((img, img, img), dim=0)\n", " img = normalize(img)\n", " return img" ], @@ -157,14 +152,14 @@ " image_A = image_A.unsqueeze(0).to(device)\n", "\n", " age_modif = torch.tensor(target_age).unsqueeze(0).to(device)\n", - " image_A_modif = trainer.test_eval(image_A, age_modif, target_age=target_age, hist_trans=True) \n", + " image_A_modif = trainer.test_eval(image_A, age_modif, target_age=target_age, hist_trans=True)\n", " utils.save_image(clip_img(image_A_modif), opts.out_path + img_name.split('.')[0] + '_age_' + str(target_age) + '.jpg')\n", "\n", " # Plot manipulated image\n", " img_out = np.array(Image.open(opts.out_path + img_name.split('.')[0] + '_age_' + str(target_age) + '.jpg'))\n", " plt.axis('off')\n", " plt.imshow(img_out)\n", - " plt.show() " + " plt.show()" ], "execution_count": 0, "outputs": []