diff --git a/diffraction/WISH/bragg-detect/cnn/BraggDetectCNN.py b/diffraction/WISH/bragg-detect/cnn/BraggDetectCNN.py index 6d9a64d..83b7794 100644 --- a/diffraction/WISH/bragg-detect/cnn/BraggDetectCNN.py +++ b/diffraction/WISH/bragg-detect/cnn/BraggDetectCNN.py @@ -168,7 +168,7 @@ def _select_device(self): def _load_cnn_model_from_weights(self, weights_path): model = self._get_fasterrcnn_resnet50_fpn(num_classes=2) - model.load_state_dict(tc.load(weights_path, map_location=self.device)) + model.load_state_dict(tc.load(weights_path, map_location=self.device, weights_only=True)) return model.to(self.device)