Skip to content

Commit

Permalink
Add patch support
Browse files Browse the repository at this point in the history
  • Loading branch information
ppizarror committed Aug 13, 2024
1 parent 6818cb5 commit 31751c0
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 43 deletions.
4 changes: 2 additions & 2 deletions MLStructFP_benchmarks/utils/_fp_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def save_list(fn: str, image_list: List['np.ndarray']) -> None:
return
np.savez_compressed(fn, data=np.array(image_list, dtype='uint8')) # .npz

save_list(f'{path}_binary', self._gen._patch_binary)
save_list(f'{path}_photo', self._gen._patch_photo)
save_list(f'{path}_binary', self._gen._gen_binary.patches)
save_list(f'{path}_photo', self._gen._gen_photo.patches)
self._processed_floor.clear()
self._gen.clear()
19 changes: 7 additions & 12 deletions MLStructFP_benchmarks/utils/_fp_patch_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ class FloorPatchGenerator(object):
_image_size: int
_img_size: int
_min_binary_area: float
_patch_binary: List['np.ndarray']
_patch_photo: List['np.ndarray']
_patch_size: float
_test_ignored_patches: List[int]
_test_last_added: int
Expand Down Expand Up @@ -118,8 +116,6 @@ def __init__(
self._gen_binary = RectBinaryImage(image_size_px=image_size)
self._gen_photo = RectFloorPhoto(image_size_px=image_size, empty_color=0)
self._min_binary_area = min_binary_area
self._patch_binary = []
self._patch_photo = []
self._patch_size = patch_size
self._test_ignored_patches = []
self._test_last_added = 0
Expand Down Expand Up @@ -181,8 +177,8 @@ def clear(self) -> 'FloorPatchGenerator':
:return: Self
"""
self._patch_binary.clear()
self._patch_photo.clear()
self._gen_binary.patches.clear()
self._gen_photo.patches.clear()
self._test_ignored_patches.clear()
self._test_last_added = 0
gc.collect()
Expand Down Expand Up @@ -227,14 +223,13 @@ def process(self, floor: 'Floor') -> 'FloorPatchGenerator':
self._test_ignored_patches.append(n)
continue

self._patch_binary.append(patch_b)
self._patch_photo.append(patch_p)
self._gen_binary.patches.append(patch_b)
self._gen_photo.patches.append(patch_p)
added += 1

self._test_last_added = added
self._gen_binary.close()
self._gen_photo.close()
self._gen_binary.restore_plot()

return self

Expand All @@ -246,8 +241,8 @@ def plot_patch(self, idx: int, inverse: bool = False) -> None:
:param inverse: If true, plot inverted colors (white as background)
"""
plt.figure(dpi=DEFAULT_PLOT_DPI)
photo = self._patch_photo[idx]
binary = self._patch_binary[idx]
photo = self._gen_photo.patches[idx]
binary = self._gen_binary.patches[idx]
if inverse:
if not self._bw:
photo = 255 - photo
Expand All @@ -265,7 +260,7 @@ def plot_photo(self, idx: int, inverse: bool = False) -> None:
:param inverse: If true, plot inverted colors (white as background)
"""
plt.figure(dpi=DEFAULT_PLOT_DPI)
photo = self._patch_photo[idx]
photo = self._gen_photo.patches[idx]
if inverse:
if not self._bw:
photo = 255 - photo
Expand Down
9 changes: 5 additions & 4 deletions create_data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@
],
"source": [
"db = DbLoader('../MLSTRUCT-FP/dataset/fp.json')\n",
"db.tabulate(limit=5) # Total: 165"
"db.tabulate(limit=5) # Total: 165"
]
},
{
Expand Down Expand Up @@ -205,7 +205,8 @@
}
],
"source": [
"patchgen = FloorPatchGenerator(image_size=256, patch_size=5, bw=False, delta_x=[-0.25, 0, 0.25], delta_y=[-0.25, 0, 0.25])\n",
"patchgen = FloorPatchGenerator(image_size=256, patch_size=5, bw=False, delta_x=[-0.25, 0, 0.25],\n",
" delta_y=[-0.25, 0, 0.25])\n",
"patchgen.plot_patches(db.floors[0])"
]
},
Expand Down Expand Up @@ -258,7 +259,7 @@
"source": [
"# Create 256x256 px images, using patches of 5x5m, in black and white. With an offset of 25%. NO rotation\n",
"dbexport = FPDatasetGenerator(image_size=256, patch_size=5, bw=True, delta_x=[-0.25, 0, 0.25], delta_y=[-0.25, 0, 0.25])\n",
"_ = dbexport.process_dataset(db=db, path='.data_patches/', rotation_angles=(0, ), num_thread_processes=8)"
"_ = dbexport.process_dataset(db=db, path='.data_patches/', rotation_angles=(0,), num_thread_processes=8)"
]
},
{
Expand All @@ -277,7 +278,7 @@
"outputs": [],
"source": [
"dfp = DataFloorPhoto(path='.data_patches/')\n",
"dfp.assemble_train_test(0.7) # 70% train, 30% test\n",
"dfp.assemble_train_test(0.7) # 70% train, 30% test\n",
"\n",
"# Save the session\n",
"dfp.save_session('.session/no_rot_256_50')"
Expand Down
6 changes: 4 additions & 2 deletions fp_unet.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
"\n",
"# Set seed for reproducibility\n",
"from numpy.random import seed\n",
"\n",
"seed(1)"
]
},
Expand Down Expand Up @@ -448,6 +449,7 @@
" print(f'IoU: {iou_metric(img_true, model.predict_image(img_in)):.2f}')\n",
" model.plot.plot_predict(img_in, img_true, threshold=False, title=title)\n",
"\n",
"\n",
"def plot_floor(image_size: int, patch_size: float, floor) -> None:\n",
" patchgen = FloorPatchGenerator(image_size=image_size, patch_size=patch_size, bw=True)\n",
" patchgen.plot_patches(floor, photo=1, patches=False, rect=False, axis=False,\n",
Expand All @@ -462,7 +464,7 @@
"\n",
"# noinspection PyProtectedMember\n",
"def plot_full(plan_idx: int, patch_size: float):\n",
" plan_id = data._split[1][plan_idx] # Consider only test floor plans\n",
" plan_id = data._split[1][plan_idx] # Consider only test floor plans\n",
" for f in db.floors:\n",
" if f.id == plan_id:\n",
" plot_floor(image_size=model._image_shape[0], patch_size=patch_size, floor=f)"
Expand Down Expand Up @@ -563,7 +565,7 @@
"ious = []\n",
"for i in tqdm(range(len(test_data['photo']))):\n",
" ious.append(iou_metric(test_data['binary'][i], model.predict_image(test_data['photo'][i])))\n",
"print(sum(ious)/len(ious))\n",
"print(sum(ious) / len(ious))\n",
"\n",
"plt.figure(figsize=(6, 6), dpi=250)\n",
"plt.hist(ious, bins=20, edgecolor='dimgrey', linewidth=0.75, weights=np.ones(len(ious)) / len(ious))\n",
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
'Keras <= 2.3.1',
'keras_tqdm <= 2.0.1',
'matplotlib <= 3.5.3',
'MLStructFP >= 0.6.0',
'MLStructFP >= 0.6.1',
'numpy <= 1.18.5',
'Pillow >= 10.4.0',
'plotly >= 5.23.0',
Expand Down
2 changes: 1 addition & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_fp_patch_generator(self) -> None:
# Test plots
patchgen.plot_patches(floor)
patchgen.plot_patch(0)
self.assertEqual(len(patchgen._patch_photo), 22)
self.assertEqual(len(patchgen._gen_photo.patches), 22)

def test_fp_db_generator(self) -> None:
"""
Expand Down
57 changes: 36 additions & 21 deletions vectorization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
"import os\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
"from PIL import Image \n",
"from PIL import Image\n",
"from IPython.display import clear_output\n",
"from matplotlib.pyplot import imshow \n",
"from matplotlib.pyplot import imshow\n",
"import matplotlib.pyplot as plt\n",
"import PIL\n",
"import torch\n",
"\n",
"import warnings\n",
"\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"from itertools import islice\n",
Expand All @@ -32,11 +32,12 @@
"\n",
"# Import dvec (https://github.com/Vahe1994/Deep-Vectorization-of-Technical-Drawings)\n",
"import sys\n",
"\n",
"sys.path.append('./dvec/')\n",
"\n",
"from util_files.rendering.cairo import render,render_with_skeleton\n",
"from util_files.data.graphics_primitives import PT_LINE, PT_CBEZIER, PT_QBEZIER\n",
"import util_files.loss_functions.supervised as supervised_loss \n",
"from util_files.rendering.cairo import render, render_with_skeleton\n",
"from util_files.data.graphics_primitives import PT_LINE, PT_CBEZIER, PT_QBEZIER\n",
"import util_files.loss_functions.supervised as supervised_loss\n",
"from util_files.optimization.optimizer.scheduled_optimizer import ScheduledOptimizer\n",
"import util_files.dataloading as dataloading\n",
"from vectorization import load_model\n",
Expand All @@ -47,11 +48,15 @@
"from util_files.evaluation_utils import primitive_to_path_and_crop\n",
"\n",
"# Arguments generator\n",
"import sys; sys.argv=['']; del sys\n",
"import sys;\n",
"\n",
"sys.argv = [''];\n",
"del sys\n",
"import argparse\n",
"\n",
"# Steps\n",
"from merging.utils.merging_functions import tensor_vector_graph_numpy,assemble_vector_patches_lines,save_svg,clip_to_box\n",
"from merging.utils.merging_functions import tensor_vector_graph_numpy, assemble_vector_patches_lines, save_svg, \\\n",
" clip_to_box\n",
"from refinement.our_refinement.refinement_for_lines import render_optimization_hard\n",
"from merging.merging_for_lines import postprocess"
]
Expand Down Expand Up @@ -89,6 +94,7 @@
" parser.add_argument('--max_distance_to_connect', type=int, default=15, help='max_distance_to_connect in pixel')\n",
" return parser.parse_args()\n",
"\n",
"\n",
"def preprocess_image(image):\n",
" patch_height, patch_width = image.shape[1:3]\n",
" image = torch.as_tensor(image).type(torch.float32).reshape(-1, patch_height, patch_width) / 255\n",
Expand All @@ -100,7 +106,8 @@
" _ys = torch.from_numpy(_ys)[None]\n",
" return torch.stack([image, _xs * mask, _ys * mask], dim=1)\n",
"\n",
"def read_data(options, image_type = 'RGB'):\n",
"\n",
"def read_data(options, image_type='RGB'):\n",
" train_transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" ])\n",
Expand All @@ -109,28 +116,30 @@
" image_names = os.listdir(options.data_dir)\n",
" print(image_names)\n",
" for image_name in image_names:\n",
" if (image_name[-4:] != 'jpeg' and image_name[-3:] != 'png' and image_name[-3:] != 'jpg') or image_name[0]=='.':\n",
" if (image_name[-4:] != 'jpeg' and image_name[-3:] != 'png' and image_name[-3:] != 'jpg') or image_name[\n",
" 0] == '.':\n",
" print(image_name[-4:])\n",
" continue\n",
" print(os.getcwd())\n",
" img = train_transform(Image.open(options.data_dir + image_name).convert(image_type))\n",
" print(img.shape)\n",
" img_t = torch.zeros(img.shape[0], img.shape[1] + (32 - img.shape[1] % 32),\n",
" img.shape[2] + (32 - img.shape[2] % 32))\n",
" img.shape[2] + (32 - img.shape[2] % 32))\n",
" img_t[:, :img.shape[1], :img.shape[2]] = img\n",
" dataset.append(img_t)\n",
" options.image_name = image_names\n",
" else:\n",
" img = train_transform(Image.open(options.data_dir + options.image_name).convert(image_type))\n",
" print(img.shape)\n",
" img_t = torch.zeros(img.shape[0], img.shape[1] + (32 - img.shape[1] % 32),\n",
" img.shape[2] + (32 - img.shape[2] % 32))\n",
" img.shape[2] + (32 - img.shape[2] % 32))\n",
" img_t[:, :img.shape[1], :img.shape[2]] = img\n",
" dataset.append(img_t)\n",
" options.image_name = [options.image_name]\n",
"\n",
" return dataset\n",
"\n",
"\n",
"def split_to_patches(rgb, patch_size, overlap=0):\n",
" \"\"\"Separates the input into square patches of specified size.\n",
"\n",
Expand Down Expand Up @@ -194,7 +203,7 @@
"source": [
"device = torch.device('cuda:{}'.format(0))\n",
"prefetch_data = True\n",
"batches_completed_in_epoch=0\n",
"batches_completed_in_epoch = 0\n",
"epoch_size = 20000\n",
"curve_count = 10\n",
"model_type = 'model'\n",
Expand All @@ -204,6 +213,7 @@
"# Load/create model \n",
"model = load_model(model_json_param).to(device)\n",
"\n",
"\n",
"# Load weights\n",
"def serialize(checkpoint):\n",
" model_state_dict = checkpoint['model_state_dict']\n",
Expand All @@ -218,6 +228,7 @@
" del model_state_dict[k]\n",
" return checkpoint\n",
"\n",
"\n",
"checkpoint = serialize(torch.load('dvec/vectorization/models/weights/model_lines.weights'))\n",
"model.load_state_dict(checkpoint['model_state_dict'])"
]
Expand Down Expand Up @@ -325,13 +336,17 @@
" if it_batches > patch_images.shape[0]:\n",
" it_batches = patch_images.shape[0]\n",
" with torch.no_grad():\n",
" if(it_start==0):\n",
" patches_vector = model(patch_images[it_start:it_batches].cuda().float(), options.model_output_count).detach().cpu().numpy()\n",
" if (it_start == 0):\n",
" patches_vector = model(patch_images[it_start:it_batches].cuda().float(),\n",
" options.model_output_count).detach().cpu().numpy()\n",
" else:\n",
" patches_vector = np.concatenate((patches_vector,model(patch_images[it_start:it_batches].cuda().float(), options.model_output_count).detach().cpu().numpy()),axis=0)\n",
" patches_vector = np.concatenate((patches_vector, model(patch_images[it_start:it_batches].cuda().float(),\n",
" options.model_output_count).detach().cpu().numpy()),\n",
" axis=0)\n",
"patches_vector = torch.tensor(patches_vector) * 64\n",
"rendered_image = save_svg(tensor_vector_graph_numpy(torch.tensor(patches_vector), patches_offsets, options), image_tensor.shape[1:],\n",
" options.image_name[0], options.output_dir + 'model_output/')\n",
"rendered_image = save_svg(tensor_vector_graph_numpy(torch.tensor(patches_vector), patches_offsets, options),\n",
" image_tensor.shape[1:],\n",
" options.image_name[0], options.output_dir + 'model_output/')\n",
"plt.figure(dpi=250)\n",
"plt.imshow(rendered_image, 'gray')\n",
"plt.axis('off')\n",
Expand Down Expand Up @@ -405,8 +420,8 @@
"source": [
"# Refinement\n",
"vector_after_opt = render_optimization_hard(patches_rgb, patches_vector, device, options, options.image_name[0])\n",
"rendered_image_opt= save_svg(tensor_vector_graph_numpy(vector_after_opt, patches_offsets, options), image.shape[1:],\n",
" options.image_name[0], options.output_dir + 'diff_rendering_output/')\n",
"rendered_image_opt = save_svg(tensor_vector_graph_numpy(vector_after_opt, patches_offsets, options), image.shape[1:],\n",
" options.image_name[0], options.output_dir + 'diff_rendering_output/')\n",
"plt.figure(dpi=250)\n",
"plt.imshow(rendered_image_opt, 'gray')\n",
"plt.axis('off')"
Expand Down Expand Up @@ -483,7 +498,7 @@
"bounding_boxes = np.array([0, image.shape[2], 0, image.shape[1]])\n",
"primitives = merging_result\n",
"paths = list(filter(None, map(primitive_to_path_and_crop,\n",
" zip(primitives[:,:5], bounding_boxes[None].repeat(primitives.shape[0],axis=0)))))\n",
" zip(primitives[:, :5], bounding_boxes[None].repeat(primitives.shape[0], axis=0)))))\n",
"width = Pixels(image.shape[2])\n",
"height = Pixels(image.shape[1])\n",
"view_size = width, height\n",
Expand Down

0 comments on commit 31751c0

Please sign in to comment.