diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index 67e97a10..864e7911 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -6,8 +6,8 @@
"context":".."
},
- "workspaceFolder": "/home/ws/uavf_2024",
- "workspaceMount": "source=${localWorkspaceFolder},target=/home/ws/uavf_2024,type=bind",
+ "workspaceFolder": "/home/ws/libuavf_2024",
+ "workspaceMount": "source=${localWorkspaceFolder},target=/home/ws/libuavf_2024,type=bind",
"containerEnv": {
// For x86_64
"DISPLAY": "unix:0",
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8caa5b04..bfdc19a9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.5)
-project(uavf_2024)
+project(libuavf_2024)
# Default to C99
@@ -43,7 +43,7 @@ set(srv_files
"srv/TakePicture.srv"
)
-rosidl_generate_interfaces(uavf_2024
+rosidl_generate_interfaces(libuavf_2024
${msg_files}
${srv_files}
)
@@ -55,15 +55,11 @@ install(PROGRAMS
scripts/demo_dropzone_planner.py
scripts/trajectory_planner_node.py
scripts/waypoint_tracker_node.py
+ scripts/demo_imaging_node.py
scripts/imaging_node.py
scripts/mock_imaging_node.py
DESTINATION lib/${PROJECT_NAME}
)
-install(DIRECTORY
- uavf_2024/gnc
- uavf_2024/imaging
- DESTINATION lib/uavf_2024/libuavf_2024
-)
install(DIRECTORY include/${PROJECT_NAME}/
DESTINATION include/${PROJECT_NAME}
)
diff --git a/Dockerfile b/Dockerfile
index 8ff45d00..f6253016 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -159,6 +159,8 @@ RUN usermod -a -G dialout qgc
# Sourcing script at runtime
COPY .devcontainer/bashrc_setup.sh /usr/local/bin/bashrc_setup.sh
RUN chmod 777 /usr/local/bin/bashrc_setup.sh
-RUN /usr/local/bin/bashrc_setup.sh
+#CMD ["/usr/local/bin/bashrc_setup.sh"]
+# TOTAL HACK, just doing this to avoid an entire Docker rebuild
+RUN mv /home/ws/uavf_2024 /home/ws/libuavf_2024
CMD ["/bin/bash"]
diff --git a/README.md b/README.md
index f17bf6d2..3f566359 100644
--- a/README.md
+++ b/README.md
@@ -51,7 +51,7 @@ Do this AFTER doing `pip install -e .` If you do that after, it'll overwrite the
sudo apt-get install libjpeg-dev zlib1g-dev libpython3-dev libopenblas-dev libavcodec-dev libavformat-dev libswscale-dev
git clone --branch v0.16.1 https://github.com/pytorch/vision torchvision
cd torchvision
-export BUILD_VERSION = 0.16.1
+export BUILD_VERSION=0.16.1
python3 setup.py install --user
```
diff --git a/include/uavf_2024/README.md b/include/libuavf_2024/README.md
similarity index 100%
rename from include/uavf_2024/README.md
rename to include/libuavf_2024/README.md
diff --git a/include/uavf_2024/gnc/README.md b/include/libuavf_2024/gnc/README.md
similarity index 100%
rename from include/uavf_2024/gnc/README.md
rename to include/libuavf_2024/gnc/README.md
diff --git a/include/uavf_2024/imaging/README.md b/include/libuavf_2024/imaging/README.md
similarity index 100%
rename from include/uavf_2024/imaging/README.md
rename to include/libuavf_2024/imaging/README.md
diff --git a/launches/imaging_demo.launch b/launches/imaging_demo.launch
new file mode 100644
index 00000000..ca829736
--- /dev/null
+++ b/launches/imaging_demo.launch
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/msg/TargetDetection.msg b/msg/TargetDetection.msg
index 4a6d1543..a78aa6ae 100644
--- a/msg/TargetDetection.msg
+++ b/msg/TargetDetection.msg
@@ -4,7 +4,7 @@ float64 x
float64 y
float64 z
-float64[8] shape_conf
+float64[13] shape_conf
# should sum to 1
# indices in order:
# circle, semicircle, quarter circle, triangle, rectangle, pentagon, star, cross
diff --git a/package.xml b/package.xml
index 35d28a9b..73cc1dc5 100644
--- a/package.xml
+++ b/package.xml
@@ -1,7 +1,7 @@
- uavf_2024
+ libuavf_2024
0.0.0
TODO: Package description
Herpderk
diff --git a/scripts/camera_test_node.py b/scripts/camera_test_node.py
new file mode 100644
index 00000000..2636903a
--- /dev/null
+++ b/scripts/camera_test_node.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+"""Manual camera check: centers the gimbal, then prints its attitude about ten times a second."""
+
+import rclpy
+from rclpy.node import Node
+from uavf_2024.imaging import Camera
+from time import sleep
+
+
+class CameraTestNode(Node):
+    """ROS 2 node that connects to the camera and polls the gimbal attitude."""
+
+    def __init__(self) -> None:
+        # Dedicated node name so this test does not collide with the real
+        # 'imaging_node' when both are running on the same ROS graph.
+        super().__init__('camera_test_node')
+        self.camera = Camera()
+        self.camera.setAbsoluteZoom(1)
+        self.camera.cam.requestAbsolutePosition(0, 0)
+        # Presumably gives the gimbal time to reach the requested position.
+        sleep(2)
+
+    def loop(self) -> None:
+        """Print the gimbal attitude every 0.1 s until rclpy is shut down."""
+        while rclpy.ok():
+            print(self.camera.cam.getAttitude())
+            sleep(1/10)
+
+
+def main(args=None) -> None:
+    """Start the node, run the polling loop, and release the camera on exit."""
+    print('Starting camera test node...')
+    rclpy.init(args=args)
+    node = CameraTestNode()
+    try:
+        node.loop()
+    except KeyboardInterrupt:
+        # Ctrl-C is the expected way to stop this manual test.
+        pass
+    finally:
+        node.camera.disconnect()
+        node.destroy_node()
+        if rclpy.ok():
+            rclpy.shutdown()
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/scripts/demo_commander_node.py b/scripts/demo_commander_node.py
index 07835494..4704c0ba 100755
--- a/scripts/demo_commander_node.py
+++ b/scripts/demo_commander_node.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
-from libuavf_2024.gnc.commander_node import CommanderNode
+from uavf_2024.gnc.commander_node import CommanderNode
import mavros_msgs.msg
import mavros_msgs.srv
import rclpy
@@ -9,7 +9,7 @@
from threading import Thread
import sys
-# Command to run: ros2 run uavf_2024 demo_commander_node.py /home/ws/uavf_2024/uavf_2024/gnc/data/TEST_MISSION /home/ws/uavf_2024/uavf_2024/gnc/data/AIRDROP_BOUNDARY /home/ws/uavf_2024/uavf_2024/gnc/data/PAYLOAD_LIST 12 9
+# Command to run: ros2 run libuavf_2024 demo_commander_node.py /home/ws/libuavf_2024/uavf_2024/gnc/data/TEST_MISSION /home/ws/libuavf_2024/uavf_2024/gnc/data/AIRDROP_BOUNDARY /home/ws/libuavf_2024/uavf_2024/gnc/data/PAYLOAD_LIST 12 9
if __name__ == '__main__':
rclpy.init()
diff --git a/scripts/demo_dropzone_planner.py b/scripts/demo_dropzone_planner.py
index 4cd21703..5a391761 100644
--- a/scripts/demo_dropzone_planner.py
+++ b/scripts/demo_dropzone_planner.py
@@ -1,14 +1,14 @@
#!/usr/bin/env python3
-from libuavf_2024.gnc.util import read_gps
-from libuavf_2024.gnc.dropzone_planner import DropzonePlanner
-from libuavf_2024.gnc.commander_node import CommanderNode
+from uavf_2024.gnc.util import read_gps
+from uavf_2024.gnc.dropzone_planner import DropzonePlanner
+from uavf_2024.gnc.commander_node import CommanderNode
from threading import Thread
import rclpy
import argparse
-# example: ros2 run uavf_2024 demo_dropzone_planner.py /home/ws/uavf_2024/uavf_2024/gnc/data/TEST_MISSION /home/ws/uavf_2024/uavf_2024/gnc/data/AIRDROP_BOUNDARY 0 0 0 0 12 9
+# example: ros2 run libuavf_2024 demo_dropzone_planner.py /home/ws/libuavf_2024/uavf_2024/gnc/data/TEST_MISSION /home/ws/libuavf_2024/uavf_2024/gnc/data/AIRDROP_BOUNDARY 0 0 0 0 12 9
if __name__ == '__main__':
rclpy.init()
diff --git a/scripts/demo_imaging_node.py b/scripts/demo_imaging_node.py
new file mode 100644
index 00000000..27546cb3
--- /dev/null
+++ b/scripts/demo_imaging_node.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python3
+"""Demo client that calls 'imaging_service' once and logs the detections."""
+
+from libuavf_2024.srv import TakePicture
+import rclpy
+from rclpy.node import Node
+from time import sleep
+
+
+class DemoImagingClient(Node):
+    """Client node that issues a single TakePicture request on construction."""
+
+    def __init__(self):
+        super().__init__('demo_imaging_client')
+        self.get_logger().info("Initializing Client")
+        self.cli = self.create_client(TakePicture, 'imaging_service')
+        # Block until 'imaging_service' is available, logging once per second.
+        while not self.cli.wait_for_service(timeout_sec=1.0):
+            self.get_logger().info('service not available, waiting again...')
+        self.get_logger().info("Finished initializing client")
+        self.req = TakePicture.Request()
+        sleep(5)
+        res = self.send_request()
+        if res is None:
+            # result() is None when the future never completed, e.g. the
+            # context was shut down while waiting for the response.
+            self.get_logger().error("No response received from imaging_service")
+        else:
+            self.get_logger().info(str(res.detections))
+
+    def send_request(self):
+        """Send the stored request and return the response, or None if none arrived."""
+        self.get_logger().info("Sending request")
+        self.future = self.cli.call_async(self.req)
+        rclpy.spin_until_future_complete(self, self.future)
+        return self.future.result()
+
+
+# Command to run: ros2 run libuavf_2024 demo_imaging_node.py
+
+if __name__ == '__main__':
+    print('Starting client node...')
+    rclpy.init()
+    node = DemoImagingClient()
+    try:
+        rclpy.spin(node)
+    except KeyboardInterrupt:
+        pass
+    finally:
+        node.destroy_node()
+        if rclpy.ok():
+            rclpy.shutdown()
\ No newline at end of file
diff --git a/scripts/imaging_node.py b/scripts/imaging_node.py
index f5567d90..ccba8510 100755
--- a/scripts/imaging_node.py
+++ b/scripts/imaging_node.py
@@ -2,28 +2,57 @@
import rclpy
from rclpy.node import Node
-from uavf_2024.msg import TargetDetection
-from uavf_2024.srv import TakePicture
+from libuavf_2024.msg import TargetDetection
+from libuavf_2024.srv import TakePicture
+from uavf_2024.imaging import Camera, ImageProcessor, Localizer
import numpy as np
+from time import strftime, time
class ImagingNode(Node):
def __init__(self) -> None:
super().__init__('imaging_node')
self.imaging_service = self.create_service(TakePicture, 'imaging_service', self.imaging_callback)
-
+ self.camera = Camera()
+ self.camera.setAbsoluteZoom(1)
+ self.image_processor = ImageProcessor(f'logs/{strftime("%m-%d %H:%M")}')
+ self.localizer = Localizer(30, (1920, 1080))
+ self.get_logger().info("Finished initializing imaging node")
+
def imaging_callback(self, request, response: list[TargetDetection]):
- response.detections = [
- TargetDetection(
- timestamp = 69420,
- x = 1.0,
- y = 2.0,
- z = 3.0,
- shape_conf = np.zeros(8).tolist(),
- letter_conf = np.zeros(36).tolist(),
- shape_color_conf = np.zeros(8).tolist(),
- letter_color_conf = np.zeros(8).tolist(),
+ self.get_logger().info("Received Request")
+ self.camera.request_center()
+ self.camera.request_autofocus()
+ img = self.camera.take_picture()
+ timestamp = time()
+
+ self.get_logger().info("Picture taken")
+
+ detections = self.image_processor.process_image(img)
+
+ self.get_logger().info("Images processed")
+
+ cam_pose = np.array([0,0,0,0,0,0])
+ preds_3d = [self.localizer.prediction_to_coords(d, cam_pose) for d in detections]
+
+ self.get_logger().info("Localization finished")
+
+ response.detections = []
+
+ for i, p in enumerate(preds_3d):
+ t = TargetDetection(
+ timestamp = int(timestamp*1000),
+ x = p.position[0],
+ y = p.position[1],
+ z = p.position[2],
+ shape_conf = p.description.shape_probs.tolist(),
+ letter_conf = p.description.letter_probs.tolist(),
+ shape_color_conf = p.description.shape_col_probs.tolist(),
+ letter_color_conf = p.description.letter_col_probs.tolist()
)
- ]
+
+ response.detections.append(t)
+
+ self.get_logger().info("Returning Response")
return response
def main(args=None) -> None:
diff --git a/scripts/mock_imaging_node.py b/scripts/mock_imaging_node.py
index c562d64f..7274cb4c 100644
--- a/scripts/mock_imaging_node.py
+++ b/scripts/mock_imaging_node.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
-# example usage: ros2 run uavf_2024 mock_imaging_node.py /home/ws/uavf_2024/uavf_2024/gnc/data/AIRDROP_BOUNDARY 12 9
+# example usage: ros2 run libuavf_2024 mock_imaging_node.py /home/ws/libuavf_2024/uavf_2024/gnc/data/AIRDROP_BOUNDARY 12 9
# generates and mocks 5 unique targets
@@ -11,11 +11,11 @@
import rclpy
from rclpy.node import Node
-from uavf_2024.msg import TargetDetection
-from uavf_2024.srv import TakePicture
+from libuavf_2024.msg import TargetDetection
+from libuavf_2024.srv import TakePicture
from geometry_msgs.msg import PoseStamped
from rclpy.qos import QoSProfile, ReliabilityPolicy, DurabilityPolicy, HistoryPolicy
-from libuavf_2024.gnc.util import read_gps, convert_delta_gps_to_local_m
+from uavf_2024.gnc.util import read_gps, convert_delta_gps_to_local_m
from scipy.spatial.transform import Rotation as R
import numpy as np
import argparse
diff --git a/sim_instructions.md b/sim_instructions.md
index 02cf6a5e..89e79a8c 100644
--- a/sim_instructions.md
+++ b/sim_instructions.md
@@ -39,12 +39,12 @@ cd /home/ws && colcon build --merge-install && source install/setup.bash
Launch the mock imaging node:
```
-ros2 run uavf_2024 mock_imaging_node.py /home/ws/uavf_2024/uavf_2024/gnc/data/AIRDROP_BOUNDARY 12 9
+ros2 run libuavf_2024 mock_imaging_node.py /home/ws/libuavf_2024/uavf_2024/gnc/data/AIRDROP_BOUNDARY 12 9
```
Launch the demo commander node:
```
-ros2 run uavf_2024 demo_commander_node.py /home/ws/uavf_2024/uavf_2024/gnc/data/TEST_MISSION /home/ws/uavf_2024/uavf_2024/gnc/data/AIRDROP_BOUNDARY /home/ws/uavf_2024/uavf_2024/gnc/data/PAYLOAD_LIST 12 9
+ros2 run libuavf_2024 demo_commander_node.py /home/ws/libuavf_2024/uavf_2024/gnc/data/TEST_MISSION /home/ws/libuavf_2024/uavf_2024/gnc/data/AIRDROP_BOUNDARY /home/ws/libuavf_2024/uavf_2024/gnc/data/PAYLOAD_LIST 12 9
```
This will execute one lap of the mission in SITL.
diff --git a/siyi_sdk b/siyi_sdk
index 9c6df7e5..9c69a5c1 160000
--- a/siyi_sdk
+++ b/siyi_sdk
@@ -1 +1 @@
-Subproject commit 9c6df7e5950e2165897f7f7b3642eaf784e6bcf0
+Subproject commit 9c69a5c1f73ef4d6b61b249f68be3d1c66e5052b
diff --git a/srv/TakePicture.srv b/srv/TakePicture.srv
index 78ac3d8d..a4b17cba 100644
--- a/srv/TakePicture.srv
+++ b/srv/TakePicture.srv
@@ -1,2 +1,2 @@
---
-uavf_2024/TargetDetection[] detections
\ No newline at end of file
+libuavf_2024/TargetDetection[] detections
\ No newline at end of file
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/imaging/__init__.py b/tests/imaging/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/imaging/image_processor_tests.py b/tests/imaging/image_processor_tests.py
index 7a46336f..33aa9218 100644
--- a/tests/imaging/image_processor_tests.py
+++ b/tests/imaging/image_processor_tests.py
@@ -132,8 +132,11 @@ def test_runs_without_crashing(self):
@profiler
def test_benchmark_fullsize_images(self):
- image_processor = ImageProcessor()
- sample_input = Image.from_file(f"{CURRENT_FILE_PATH}/imaging_data/fullsize_dataset/images/image0.png")
+ image_processor = ImageProcessor(
+ shape_batch_size=20,
+ letter_batch_size=30
+ )
+ sample_input = Image.from_file(f"{CURRENT_FILE_PATH}/imaging_data/fullsize_dataset/images/5k.png")
times = []
N_runs = 10
for i in tqdm(range(N_runs)):
@@ -143,8 +146,8 @@ def test_benchmark_fullsize_images(self):
times.append(elapsed)
print(f"Fullsize image benchmarks (average of {N_runs} runs):")
print(f"Avg: {np.mean(times)}, StdDev: {np.std(times)}")
- lstats = profiler.get_stats()
- line_profiler.show_text(lstats.timings, lstats.unit)
+ # lstats = profiler.get_stats()
+ # line_profiler.show_text(lstats.timings, lstats.unit)
def test_no_duplicates(self):
# Given 5 identified bounding boxes, removes duplicate bounding box using nms such that there are 4 bounding boxes left
diff --git a/tests/imaging/imaging_data/fullsize_dataset/images/1080p.png b/tests/imaging/imaging_data/fullsize_dataset/images/1080p.png
new file mode 100644
index 00000000..531eacd6
Binary files /dev/null and b/tests/imaging/imaging_data/fullsize_dataset/images/1080p.png differ
diff --git a/tests/imaging/imaging_data/fullsize_dataset/images/5k.png b/tests/imaging/imaging_data/fullsize_dataset/images/5k.png
new file mode 100644
index 00000000..88abaf88
Binary files /dev/null and b/tests/imaging/imaging_data/fullsize_dataset/images/5k.png differ
diff --git a/tests/imaging/imaging_data/fullsize_dataset/images/image0.png b/tests/imaging/imaging_data/fullsize_dataset/images/image0.png
deleted file mode 100644
index 42d21942..00000000
Binary files a/tests/imaging/imaging_data/fullsize_dataset/images/image0.png and /dev/null differ
diff --git a/uavf_2024/gnc/commander_node.py b/uavf_2024/gnc/commander_node.py
index 9d0314f8..cf5e77a2 100644
--- a/uavf_2024/gnc/commander_node.py
+++ b/uavf_2024/gnc/commander_node.py
@@ -6,9 +6,9 @@
from rclpy.qos import QoSProfile, ReliabilityPolicy, DurabilityPolicy, HistoryPolicy
import sensor_msgs.msg
import geometry_msgs.msg
-import uavf_2024.srv
-from libuavf_2024.gnc.util import read_gps, convert_delta_gps_to_local_m, convert_local_m_to_delta_gps, calculate_turn_angles_deg, read_payload_list
-from libuavf_2024.gnc.dropzone_planner import DropzonePlanner
+import libuavf_2024.srv
+from uavf_2024.gnc.util import read_gps, convert_delta_gps_to_local_m, convert_local_m_to_delta_gps, calculate_turn_angles_deg, read_payload_list
+from uavf_2024.gnc.dropzone_planner import DropzonePlanner
from scipy.spatial.transform import Rotation as R
import time
diff --git a/uavf_2024/gnc/dropzone_planner.py b/uavf_2024/gnc/dropzone_planner.py
index cd20f120..ac7c1109 100644
--- a/uavf_2024/gnc/dropzone_planner.py
+++ b/uavf_2024/gnc/dropzone_planner.py
@@ -1,5 +1,5 @@
from typing import List, Tuple
-from libuavf_2024.gnc.util import convert_local_m_to_delta_gps
+from uavf_2024.gnc.util import convert_local_m_to_delta_gps
import numpy as np
import math
diff --git a/uavf_2024/gnc/util.py b/uavf_2024/gnc/util.py
index c685ef10..b2d202b0 100644
--- a/uavf_2024/gnc/util.py
+++ b/uavf_2024/gnc/util.py
@@ -1,4 +1,4 @@
-from libuavf_2024.gnc.payload import Payload
+from uavf_2024.gnc.payload import Payload
from geographiclib.geodesic import Geodesic
import numpy as np
diff --git a/uavf_2024/imaging/__init__.py b/uavf_2024/imaging/__init__.py
index 949be51d..c288c773 100644
--- a/uavf_2024/imaging/__init__.py
+++ b/uavf_2024/imaging/__init__.py
@@ -1,2 +1,6 @@
import line_profiler
-profiler = line_profiler.LineProfiler()
\ No newline at end of file
+profiler = line_profiler.LineProfiler()
+
+from .image_processor import ImageProcessor
+from .localizer import Localizer
+from .camera import Camera
\ No newline at end of file
diff --git a/uavf_2024/imaging/camera.py b/uavf_2024/imaging/camera.py
index 717eeab8..4ae29341 100644
--- a/uavf_2024/imaging/camera.py
+++ b/uavf_2024/imaging/camera.py
@@ -1,6 +1,7 @@
import numpy as np
from time import sleep
from siyi_sdk import SIYISTREAM,SIYISDK
+from uavf_2024.imaging.imaging_types import Image, HWC
import matplotlib.image
class Camera:
@@ -9,15 +10,25 @@ def __init__(self):
self.stream = SIYISTREAM(server_ip = "192.168.144.25", port = 8554,debug=False)
self.stream.connect()
self.cam.connect()
+ #self.cam.requestLockMode()
- def take_picture(self) -> np.ndarray:
+ def take_picture(self) -> Image:
'''
- Returns picture as ndarray with shape (3, width, height)
+        Returns the frame as an Image in HWC layout, i.e. shape (height,width,3)
'''
pic = self.stream.get_frame()
- return pic
+ return Image(pic, HWC)
# return np.random.rand(3, 3840, 2160)
+
+ def request_center(self):
+ return self.cam.requestAbsolutePosition(0, 0)
+
+ def request_autofocus(self):
+ return self.cam.requestAutoFocus()
+
+    def setAbsoluteZoom(self, zoom_level: float):  # forwards zoom_level to the SDK and returns the SDK call's result
+        return self.cam.setAbsoluteZoom(zoom_level)
def disconnect(self):
self.stream.disconnect()
@@ -26,5 +37,5 @@ def disconnect(self):
if __name__ == "__main__":
cam = Camera()
out = cam.take_picture()
- matplotlib.image.imsave("sample_frame.png",out)
+ matplotlib.image.imsave("sample_frame.png",out.get_array().transpose(2,1,0))
cam.disconnect()
\ No newline at end of file
diff --git a/uavf_2024/imaging/image_processor.py b/uavf_2024/imaging/image_processor.py
index b1e4a55e..ce166707 100644
--- a/uavf_2024/imaging/image_processor.py
+++ b/uavf_2024/imaging/image_processor.py
@@ -57,9 +57,12 @@ def nms_process(shape_results: InstanceSegmentationResult, thresh_iou):
class ImageProcessor:
- def __init__(self, debug_path: str = None):
+ def __init__(self, debug_path: str = None, shape_batch_size = 3, letter_batch_size = 5):
'''
Initialize all models here
+
+ `shape_batch_size` is how many tiles we batch up for shape detection inference
+ `letter_batch_size` is how many bounding box crops we batch up for letter classification
'''
self.tile_size = 640
self.letter_size = 128
@@ -69,6 +72,8 @@ def __init__(self, debug_path: str = None):
self.debug_path = debug_path
self.thresh_iou = 0.5
self.num_processed = 0
+ self.shape_batch_size = shape_batch_size
+ self.letter_batch_size = letter_batch_size
def process_image(self, img: Image) -> list[FullPrediction]:
'''
@@ -83,9 +88,9 @@ def process_image(self, img: Image) -> list[FullPrediction]:
shape_results: list[InstanceSegmentationResult] = []
- TILES_BATCH_SIZE = 3
- for tiles in batched(img.generate_tiles(self.tile_size), TILES_BATCH_SIZE):
- temp = self.shape_detector.predict(tiles)
+ all_tiles = img.generate_tiles(self.tile_size)
+ for tiles_batch in batched(all_tiles, self.shape_batch_size):
+ temp = self.shape_detector.predict(tiles_batch)
if temp is not None: shape_results.extend(temp)
shape_results = nms_process(shape_results, self.thresh_iou)
@@ -102,11 +107,9 @@ def process_image(self, img: Image) -> list[FullPrediction]:
self.num_processed += 1
- SHAPES_BATCH_SIZE = 5 # these are small images so we can do a lot at once
-
total_results: list[FullPrediction] = []
# create debug directory for segmentation and classification
- for results in batched(shape_results, SHAPES_BATCH_SIZE):
+ for results in batched(shape_results, self.letter_batch_size):
results: list[InstanceSegmentationResult] = results # type hinting
letter_imgs = []
for shape_res in results: # These are all linear operations so not parallelized (yet)