diff --git a/AMASSS/AMASSS.py b/AMASSS/AMASSS.py
index e1750a0..d66d002 100644
--- a/AMASSS/AMASSS.py
+++ b/AMASSS/AMASSS.py
@@ -14,9 +14,80 @@
import shutil
import vtk, qt, slicer
from slicer.ScriptedLoadableModule import *
-from slicer.util import VTKObservationMixin
+from slicer.util import VTKObservationMixin, pip_install
import webbrowser
+import pkg_resources
+
+def check_lib_installed(lib_name, required_version=None):
+    '''
+    Check whether a library is already installed in the Slicer environment, at the required version if one is given.
+    input:  lib_name (str) : name of the library
+            required_version (str) : required version of the library (if None, any version is accepted)
+    output: bool : True if the library is installed (at the required version when specified), False otherwise
+    '''
+
+ try:
+ installed_version = pkg_resources.get_distribution(lib_name).version
+        # A version mismatch only matters when required_version is given; if it is None, any installed version is accepted
+ if required_version and installed_version != required_version:
+ return False
+ else:
+ return True
+ except pkg_resources.DistributionNotFound:
+ return False
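+# Illustrative usage (not called directly; install_function below relies on it):
+#   check_lib_installed('monai', '0.7.0')  -> True only if monai 0.7.0 is installed
+#   check_lib_installed('itk')             -> True if any version of itk is installed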
+
+
+def install_function(list_libs:list):
+    '''
+    Check the required libraries and install them at the specified versions when needed.
+    The user is asked, via a pop-up window, whether to install/update them, since doing so modifies the Slicer environment.
+    '''
+ libs = list_libs
+ libs_to_install = []
+ libs_to_update = []
+ for lib, version in libs:
+ if not check_lib_installed(lib, version):
+            try:
+                # The library is present but not at the required version -> schedule an update
+                if pkg_resources.get_distribution(lib).version:
+                    libs_to_update.append((lib, version))
+            except pkg_resources.DistributionNotFound:
+                # The library is missing entirely -> schedule a fresh install
+                libs_to_install.append((lib, version))
+
+ if libs_to_install or libs_to_update:
+ message = "The following changes are required for the libraries:\n"
+
+ #Specify which libraries will be updated with a new version
+ #and which libraries will be installed for the first time
+ if libs_to_update:
+
+ message += "\nLibraries to update (version mismatch):\n"
+ message += "\n".join([f"{lib} (current: {pkg_resources.get_distribution(lib).version}) -> {version}" for lib, version in libs_to_update])
+
+ if libs_to_install:
+ message += "\nLibraries to install:\n"
+ message += "\n".join([f"{lib}=={version}" if version else lib for lib, version in libs_to_install])
+
+        message += "\n\nDo you agree to modify these libraries? Doing so could cause conflicts with other installed extensions."
+        message += "\n\n(If you rely on other extensions, consider using a separate Slicer installation dedicated to AutomatedDentalTools.)"
+
+ user_choice = slicer.util.confirmYesNoDisplay(message)
+
+ if user_choice:
+ for lib, version in libs_to_install:
+ lib_version = f'{lib}=={version}' if version else lib
+ pip_install(lib_version)
+
+ for lib, version in libs_to_update:
+ lib_version = f'{lib}=={version}' if version else lib
+ pip_install(lib_version)
+ return True
+        else:
+ return False
+ else:
+ return True
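+# Illustrative usage (mirrors the call made in onPredictButton below):
+#   if not install_function([('monai', '0.7.0'), ('nibabel', None)]):
+#       return  # the user declined the environment changes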
#region ========== FUNCTIONS ==========
def GetSegGroup(group_landmark):
@@ -273,6 +344,9 @@ def setup(self):
self.ui.SaveFolderLineEdit.setHidden(True)
self.ui.PredictFolderLabel.setHidden(True)
+ # Checkbox to enable usage of CPU memory in case GPU memory is not enough
+ self.ui.host_memory.setChecked(False)
+
self.ui.label_4.setVisible(False)
self.ui.horizontalSliderCPU.setVisible(False)
self.ui.spinBoxCPU.setVisible(False)
@@ -428,6 +502,10 @@ def isSegmentInputFunction(self,SegInput):
self.ui.horizontalSliderGPU.setVisible(not SegInput)
self.ui.spinBoxGPU.setVisible(not SegInput)
+ self.ui.label_CPU_use.setVisible(not SegInput)
+ self.ui.host_memory.setVisible(not SegInput)
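+        # The CPU-memory option is only relevant when predicting from a scan,
+        # so it is hidden when the input is an existing segmentation.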
+
+
# self.ui..setVisible(not self.isSegmentInput)
@@ -596,7 +674,18 @@ def onCPUSpinbox(self):
#region == RUN ==
def onPredictButton(self):
-
+ import platform
+        # First, make sure the required libraries (at the expected versions) are installed
+ list_libs = [('torch', None),('torchvision', None),('torchaudio',None),('itk', None), ('dicom2nifti', None), ('monai', '0.7.0'),('einops',None),('nibabel',None),('connected-components-3d','3.9.1')]
+
+ if platform.system() == "Windows":
+ list_libs= [('torch', 'cu118'),('torchvision', 'cu118'),('torchaudio','cu118'),('itk', None), ('dicom2nifti', None), ('monai', '0.7.0'),('einops',None),('nibabel',None),('connected-components-3d','3.9.1')]
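+            # The 'cu118' entries target the CUDA 11.8 builds of the PyTorch packages on Windows
+            # (matching the extra-index-url previously used in AMASSS_CLI).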
+
+ libs_installation = install_function(list_libs)
+ if not libs_installation:
+ qt.QMessageBox.warning(self.parent, 'Warning', 'The module will not work properly without the required libraries.\nPlease install them and try again.')
+ return # stop the function
+
ready = True
if self.folder_as_input:
@@ -616,9 +705,6 @@ def onPredictButton(self):
return
# scan_folder = self.ui.lineEditScanPath.text
-
-
-
# self.input_path = '/home/luciacev/Desktop/REQUESTED_SEG/BAMP_SegPred'
# self.input_path = '/home/luciacev/Desktop/TEST_SEG/TEMP/AnaJ_Scan_T1_OR.gipl.gz'
# self.model_folder = '/home/luciacev/Desktop/Maxime_Gillot/Data/AMASSS/FULL_FACE_MODELS'
@@ -687,7 +773,7 @@ def onPredictButton(self):
param["prediction_ID"] = self.ui.SaveId.text
param["gpu_usage"] = self.ui.spinBoxGPU.value
- param["cpu_usage"] = self.ui.spinBoxCPU.value
+ param["host_memory"] = self.ui.host_memory.isChecked()
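+        # Forwarded to AMASSS_CLI as its host_memory argument (sys.argv[13] there)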
documentsLocation = qt.QStandardPaths.DocumentsLocation
@@ -755,7 +841,6 @@ def UpdateRunBtn(self):
def UpdateProgressBar(self,progress):
# print("UpdateProgressBar")
-
if progress == 200:
self.prediction_step += 1
@@ -775,6 +860,7 @@ def UpdateProgressBar(self,progress):
if progress == 100:
+ # self.prediction_step += 1
if self.prediction_step == 1:
# self.progressBar.setValue(self.progress)
@@ -785,6 +871,7 @@ def UpdateProgressBar(self,progress):
self.ui.PredScanLabel.setText(f"Ouput generated for segmentation : {self.progress} / {self.scan_count}")
if self.prediction_step == 2:
+
# self.progressBar.setValue(self.progress)
self.ui.PredSegProgressBar.setValue(self.progress)
self.ui.PredSegLabel.setText(f"Segmented structures : {self.progress} / {self.total_seg_progress}")
@@ -794,7 +881,7 @@ def UpdateProgressBar(self,progress):
def onProcessUpdate(self,caller,event):
-
+
# print(caller.GetProgress(),caller.GetStatus())
# self.ui.TimerLabel.setText(f"Time : {self.startTime:.2f}s")
diff --git a/AMASSS/Resources/UI/AMASSS.ui b/AMASSS/Resources/UI/AMASSS.ui
index 435f347..86798c7 100644
--- a/AMASSS/Resources/UI/AMASSS.ui
+++ b/AMASSS/Resources/UI/AMASSS.ui
@@ -47,7 +47,7 @@
-
- Input Type
+ Input Modality
@@ -68,7 +68,7 @@
-
- Scan's folder
+ Select directory
@@ -539,6 +539,17 @@
+ -
+
+
-
+
+
+ Use CPU memory
+
+
+
+
+
-
@@ -622,6 +633,16 @@
+ -
+
+
+ Use host memory with a performance penalty. Enable if CUDA runs out of memory.
+
+
+
+
+
+
-
@@ -684,6 +705,9 @@
+ -
+
+
-
diff --git a/AMASSS_CLI/AMASSS_CLI.py b/AMASSS_CLI/AMASSS_CLI.py
index ae7442c..9468305 100644
--- a/AMASSS_CLI/AMASSS_CLI.py
+++ b/AMASSS_CLI/AMASSS_CLI.py
@@ -21,58 +21,28 @@
import sys
import platform
-# try:
-# import argparse
-# except ImportError:
-# pip_install('argparse')
-# import argparse
-
-
-# print(sys.argv)
+import torch
+import dicom2nifti
+import itk
+import cc3d
-from slicer.util import pip_install, pip_uninstall
-
-# from slicer.util import pip_uninstall
-# # pip_uninstall('torch torchvision torchaudio')
+import SimpleITK as sitk
+import vtk
+import numpy as np
-# pip_uninstall('monai')
-# try :
-# import logic
+# try:
+# import torch
# except ImportError:
+# if platform.system() == "Windows":
+# pip_install('torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 -q')
+# else:
+# pip_install('torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 -q')
+# import torch
-try:
- import torch
-except ImportError:
- if platform.system() == "Windows":
- pip_install('torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 -q')
- else:
- pip_install('torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 -q')
- import torch
-
-try:
- import nibabel
-except ImportError:
- pip_install('nibabel -q')
- import nibabel
-
-try:
- import einops
-except ImportError:
- pip_install('einops -q')
- import einops
-
-try:
- import dicom2nifti
-except ImportError:
- pip_install('dicom2nifti -q')
- import dicom2nifti
-
#region try import
-pip_uninstall('monai -q')
-pip_install('monai==0.7.0 -q')
from monai.networks.nets import UNETR
from monai.data import (
@@ -92,26 +62,8 @@
from monai.inferers import sliding_window_inference
-import SimpleITK as sitk
-import vtk
-import numpy as np
-try :
- import itk
-except ImportError:
- pip_install('itk -q')
- import itk
-
-
-try:
- import cc3d
-except ImportError:
- pip_install('connected-components-3d==3.9.1 -q')
- import cc3d
-
- #endregion
-
-
-
+# pip_install('connected-components-3d==3.9.1 -q')
+# TODO: could connected-components-3d be replaced with itk.connected_component_image_filter or itk.scalar_connected_component_image_filter?
+
# endregion
#region Global variables
@@ -952,7 +904,8 @@ def main(args):
prediction_segmentation = {}
-
+    # Free as much GPU memory as possible by emptying the CUDA cache before the second loop
+ torch.cuda.empty_cache()
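+    # Note: empty_cache() only releases blocks held by PyTorch's caching allocator;
+    # tensors that are still referenced keep their memory.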
for model_id,model_path in models_to_use.items():
@@ -972,9 +925,23 @@ def main(args):
net.load_state_dict(torch.load(model_path,map_location=DEVICE))
net.eval()
-
- val_outputs = sliding_window_inference(input_img, cropSize, args["nbr_GPU_worker"], net,overlap=args["precision"])
-
+        # Avoid "CUDA out of memory" errors by optionally keeping the stitched output on the CPU
+        # (sw_device=DEVICE, device=torch.device('cpu')); see the sliding_window_inference documentation.
+
+ if args["host_memory"]=="True":
+ device_memory = torch.device('cpu')
+ else:
+ device_memory = DEVICE
+
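+        # In sliding_window_inference, sw_device is where each window is computed (typically the GPU),
+        # while device is where the stitched full-volume output is accumulated; keeping that buffer
+        # on the CPU reduces GPU memory pressure at the cost of extra host<->device transfers.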
+        try:
+            val_outputs = sliding_window_inference(input_img, cropSize, args["nbr_GPU_worker"], net, overlap=args["precision"],
+                                                   sw_device=DEVICE, device=device_memory)
+        except RuntimeError as e:
+            if "CUDA out of memory" in str(e):
+                print("Error: CUDA out of memory. Try running again with the 'Use CPU memory' option enabled.")
+            # Re-raise so the prediction does not continue with an undefined result
+            raise
+
pred_data = torch.argmax(val_outputs, dim=1).detach().cpu().type(torch.int16)
segmentations = pred_data.permute(0,3,2,1)
@@ -1006,7 +973,8 @@ def main(args):
sys.stdout.flush()
time.sleep(0.5)
-
+ # Clear the cache of GPU memory after the loop
+ torch.cuda.empty_cache()
#endregion
# print(f"""{1}""")
@@ -1147,7 +1115,7 @@ def main(args):
"vtk_smooth": int(sys.argv[10]),
"prediction_ID": sys.argv[11],
"nbr_GPU_worker": int(sys.argv[12]),
- "nbr_CPU_worker": int(sys.argv[13]),
+ "host_memory": sys.argv[13],
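+        # Arrives as the string "True"/"False"; main() compares it against "True"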
"temp_fold" : sys.argv[14],
"isSegmentInput" : sys.argv[15] == "true",
"isDCMInput": sys.argv[16] == "true",
diff --git a/AMASSS_CLI/AMASSS_CLI.xml b/AMASSS_CLI/AMASSS_CLI.xml
index d0dedff..d50ab21 100644
--- a/AMASSS_CLI/AMASSS_CLI.xml
+++ b/AMASSS_CLI/AMASSS_CLI.xml
@@ -99,12 +99,12 @@
Number of GPU to use
-
- cpu_usage
-
+
+ host_memory
+
12
- Number of CPU to use
-
+ Switch to CPU memory if True
+