Skip to content

Commit

Permalink
Attempt to fix torch gpu CI
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Nov 9, 2024
1 parent b2af6d5 commit 8deee17
Showing 1 changed file with 5 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def transform_bounding_boxes(
transformation,
training=True,
):
if backend_utils.in_tf_graph():
self.backend.set_backend("tensorflow")

def _flip_boxes_horizontal(boxes):
x1, x2, x3, x4 = self.backend.numpy.split(boxes, 4, axis=-1)
outputs = self.backend.numpy.concatenate(
Expand All @@ -116,9 +119,6 @@ def _flip_boxes_vertical(boxes):
return outputs

def _transform_xyxy(boxes, box_flips):
if backend_utils.in_tf_graph():
self.backend.set_backend("tensorflow")

bboxes = boxes["boxes"]
if self.mode in {HORIZONTAL, HORIZONTAL_AND_VERTICAL}:
bboxes = self.backend.numpy.where(
Expand All @@ -132,9 +132,6 @@ def _transform_xyxy(boxes, box_flips):
_flip_boxes_vertical(bboxes),
bboxes,
)

self.backend.reset()

return bboxes

flips = self.backend.numpy.squeeze(transformation["flips"], axis=-1)
Expand Down Expand Up @@ -176,6 +173,8 @@ def _transform_xyxy(boxes, box_flips):
width=input_width,
)

self.backend.reset()

return bounding_boxes

def transform_segmentation_masks(
Expand Down

0 comments on commit 8deee17

Please sign in to comment.