Skip to content

Commit

Permalink
added suggested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
aspaul20 committed May 24, 2024
1 parent 7c902d0 commit d43088e
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 28 deletions.
14 changes: 2 additions & 12 deletions paddleocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,7 @@ def ocr(
bin: binarize image to black and white. Default is False.
inv: invert image colors. Default is False.
alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white.
slice: use sliding window inference for large images, det and rec must be True. Requires int values for slice["horizontal_stride"], slice["vertical_stride"], slice["merge_x_thres"], slice["merge_y_thres] (See doc/doc_en/slice_en.md). Default is {}.
"""
assert isinstance(img, (np.ndarray, list, str, bytes))
if isinstance(img, list) and det == True:
Expand Down Expand Up @@ -721,18 +722,7 @@ def preprocess_image(_image):
_image = binarize_img(_image)
return _image

if det and rec and not slice:
ocr_res = []
for idx, img in enumerate(imgs):
img = preprocess_image(img)
dt_boxes, rec_res, _ = self.__call__(img, cls)
if not dt_boxes and not rec_res:
ocr_res.append(None)
continue
tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
ocr_res.append(tmp_res)
return ocr_res
elif det and rec and slice:
if det and rec:
ocr_res = []
for idx, img in enumerate(imgs):
img = preprocess_image(img)
Expand Down
5 changes: 3 additions & 2 deletions tools/infer/predict_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,12 @@ def __call__(self, img, cls=True, slice={}):
for slice_crop, v_start, h_start in slice_gen:
dt_boxes, elapse = self.text_detector(slice_crop)
if dt_boxes.size:
dt_boxes[:, :, 0] = dt_boxes[:, :, 0] + h_start
dt_boxes[:, :, 1] = dt_boxes[:, :, 1] + v_start
dt_boxes[:, :, 0] += h_start
dt_boxes[:, :, 1] += v_start
dt_slice_boxes.append(dt_boxes)
elapsed.append(elapse)
dt_boxes = np.concatenate(dt_slice_boxes)

dt_boxes = merge_fragmented(
boxes=dt_boxes,
x_threshold=slice["merge_x_thres"],
Expand Down
29 changes: 15 additions & 14 deletions tools/infer/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,34 +732,36 @@ def slice_generator(image, horizontal_stride, vertical_stride, maximum_slices=50
yield (horizontal_slice, v_start, h_start)


def calculate_box_extents(box):
min_x = box[0][0]
max_x = box[1][0]
min_y = box[0][1]
max_y = box[2][1]
return min_x, max_x, min_y, max_y


def merge_fragmented(boxes, x_threshold=10, y_threshold=10):
merged_boxes = []
visited = set()
merged_counter = 0
for i in range(len(boxes)):
if i in visited:
continue

current_box = boxes[i]
merged_box = [current_box[0], current_box[1], current_box[2], current_box[3]]
min_x, max_x, min_y, max_y = (
current_box[0][0],
current_box[1][0],
current_box[0][1],
current_box[2][1],
)
min_x, max_x, min_y, max_y = calculate_box_extents(current_box)

for j in range(len(boxes)):
if i == j:
continue

compare_box = boxes[j]
compare_min_x, compare_max_x, compare_min_y, compare_max_y = (
compare_box[0][0],
compare_box[1][0],
compare_box[0][1],
compare_box[2][1],
)
(
compare_min_x,
compare_max_x,
compare_min_y,
compare_max_y,
) = calculate_box_extents(compare_box)
if (
abs(min_y - compare_min_y) <= y_threshold
and abs(max_y - compare_max_y) <= y_threshold
Expand All @@ -786,7 +788,6 @@ def merge_fragmented(boxes, x_threshold=10, y_threshold=10):

merged_box[3][0] = new_xmin
merged_box[3][1] = new_ymax
merged_counter += 1
visited.add(j)

merged_boxes.append(merged_box)
Expand Down

0 comments on commit d43088e

Please sign in to comment.