Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: add option to choose host memory #75

Merged
merged 7 commits into from
Jan 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 95 additions & 8 deletions AMASSS/AMASSS.py
Original file line number Diff line number Diff line change
@@ -14,9 +14,80 @@
import shutil
import vtk, qt, slicer
from slicer.ScriptedLoadableModule import *
from slicer.util import VTKObservationMixin
from slicer.util import VTKObservationMixin, pip_install

import webbrowser
import pkg_resources

def check_lib_installed(lib_name, required_version=None):
'''
Check if the library with the good version (if needed) is already installed in the slicer environment
input: lib_name (str) : name of the library
required_version (str) : required version of the library (if None, any version is accepted)
output: bool : True if the library is installed with the good version, False otherwise
'''

try:
installed_version = pkg_resources.get_distribution(lib_name).version
# check if the version is the good one - if required_version != None it's considered as a True
if required_version and installed_version != required_version:
return False
else:
return True
except pkg_resources.DistributionNotFound:
return False

# import csv

def install_function(list_libs:list):
'''
Test the necessary libraries and install them with the specific version if needed
User is asked if he wants to install/update-by changing his environment- the libraries with a pop-up window
'''
libs = list_libs
libs_to_install = []
libs_to_update = []
for lib, version in libs:
if not check_lib_installed(lib, version):
try:
# check if the library is already installed
if pkg_resources.get_distribution(lib).version:
libs_to_update.append((lib, version))
except:
libs_to_install.append((lib, version))

if libs_to_install or libs_to_update:
message = "The following changes are required for the libraries:\n"

#Specify which libraries will be updated with a new version
#and which libraries will be installed for the first time
if libs_to_update:

message += "\nLibraries to update (version mismatch):\n"
message += "\n".join([f"{lib} (current: {pkg_resources.get_distribution(lib).version}) -> {version}" for lib, version in libs_to_update])

if libs_to_install:
message += "\nLibraries to install:\n"
message += "\n".join([f"{lib}=={version}" if version else lib for lib, version in libs_to_install])

message += "\n\nDo you agree to modify these libraries? Doing so could cause conflicts with other installed Extensions."
message += "\n\n (If you are using other extensions, consider downloading another Slicer to use AutomatedDentalTools exclusively.)"

user_choice = slicer.util.confirmYesNoDisplay(message)

if user_choice:
for lib, version in libs_to_install:
lib_version = f'{lib}=={version}' if version else lib
pip_install(lib_version)

for lib, version in libs_to_update:
lib_version = f'{lib}=={version}' if version else lib
pip_install(lib_version)
return True
else :
return False
else:
return True

#region ========== FUNCTIONS ==========
def GetSegGroup(group_landmark):
@@ -273,6 +344,9 @@ def setup(self):
self.ui.SaveFolderLineEdit.setHidden(True)
self.ui.PredictFolderLabel.setHidden(True)

# Checkbox to enable usage of CPU memory in case GPU memory is not enough
self.ui.host_memory.setChecked(False)

self.ui.label_4.setVisible(False)
self.ui.horizontalSliderCPU.setVisible(False)
self.ui.spinBoxCPU.setVisible(False)
@@ -428,6 +502,10 @@ def isSegmentInputFunction(self,SegInput):
self.ui.horizontalSliderGPU.setVisible(not SegInput)
self.ui.spinBoxGPU.setVisible(not SegInput)

self.ui.label_CPU_use.setVisible(not SegInput)
self.ui.host_memory.setVisible(not SegInput)



# self.ui..setVisible(not self.isSegmentInput)

@@ -596,7 +674,18 @@ def onCPUSpinbox(self):

#region == RUN ==
def onPredictButton(self):

import platform
# first, install the required libraries and their version
list_libs = [('torch', None),('torchvision', None),('torchaudio',None),('itk', None), ('dicom2nifti', None), ('monai', '0.7.0'),('einops',None),('nibabel',None),('connected-components-3d','3.9.1')]

if platform.system() == "Windows":
list_libs= [('torch', 'cu118'),('torchvision', 'cu118'),('torchaudio','cu118'),('itk', None), ('dicom2nifti', None), ('monai', '0.7.0'),('einops',None),('nibabel',None),('connected-components-3d','3.9.1')]

libs_installation = install_function(list_libs)
if not libs_installation:
qt.QMessageBox.warning(self.parent, 'Warning', 'The module will not work properly without the required libraries.\nPlease install them and try again.')
return # stop the function

ready = True

if self.folder_as_input:
@@ -616,9 +705,6 @@ def onPredictButton(self):
return

# scan_folder = self.ui.lineEditScanPath.text



# self.input_path = '/home/luciacev/Desktop/REQUESTED_SEG/BAMP_SegPred'
# self.input_path = '/home/luciacev/Desktop/TEST_SEG/TEMP/AnaJ_Scan_T1_OR.gipl.gz'
# self.model_folder = '/home/luciacev/Desktop/Maxime_Gillot/Data/AMASSS/FULL_FACE_MODELS'
@@ -687,7 +773,7 @@ def onPredictButton(self):
param["prediction_ID"] = self.ui.SaveId.text

param["gpu_usage"] = self.ui.spinBoxGPU.value
param["cpu_usage"] = self.ui.spinBoxCPU.value
param["host_memory"] = self.ui.host_memory.isChecked()


documentsLocation = qt.QStandardPaths.DocumentsLocation
@@ -755,7 +841,6 @@ def UpdateRunBtn(self):
def UpdateProgressBar(self,progress):

# print("UpdateProgressBar")

if progress == 200:
self.prediction_step += 1

@@ -775,6 +860,7 @@ def UpdateProgressBar(self,progress):


if progress == 100:
# self.prediction_step += 1

if self.prediction_step == 1:
# self.progressBar.setValue(self.progress)
@@ -785,6 +871,7 @@ def UpdateProgressBar(self,progress):
self.ui.PredScanLabel.setText(f"Ouput generated for segmentation : {self.progress} / {self.scan_count}")

if self.prediction_step == 2:

# self.progressBar.setValue(self.progress)
self.ui.PredSegProgressBar.setValue(self.progress)
self.ui.PredSegLabel.setText(f"Segmented structures : {self.progress} / {self.total_seg_progress}")
@@ -794,7 +881,7 @@ def UpdateProgressBar(self,progress):


def onProcessUpdate(self,caller,event):

# print(caller.GetProgress(),caller.GetStatus())

# self.ui.TimerLabel.setText(f"Time : {self.startTime:.2f}s")
28 changes: 26 additions & 2 deletions AMASSS/Resources/UI/AMASSS.ui
Original file line number Diff line number Diff line change
@@ -47,7 +47,7 @@
<item>
<widget class="QLabel" name="label_5">
<property name="text">
<string>Input Type</string>
<string>Input Modality</string>
</property>
</widget>
</item>
@@ -68,7 +68,7 @@
<item>
<widget class="QLabel" name="label_folder_select">
<property name="text">
<string>Scan's folder</string>
<string>Select directory</string>
</property>
</widget>
</item>
@@ -539,6 +539,17 @@
</property>
</widget>
</item>
<item>
<layout class="QVBoxLayout" name="verticalLayout_5">
<item>
<widget class="QLabel" name="label_CPU_use">
<property name="text">
<string>Use CPU memory</string>
</property>
</widget>
</item>
</layout>
</item>
<item>
<widget class="QLabel" name="label_4">
<property name="text">
@@ -622,6 +633,16 @@
</property>
</widget>
</item>
<item>
<widget class="QCheckBox" name="host_memory">
<property name="toolTip">
<string>Use host memory with a performance penalty. Enable if CUDA runs out of memory.</string>
</property>
<property name="text">
<string/>
</property>
</widget>
</item>
<item>
<widget class="QSlider" name="horizontalSliderCPU">
<property name="minimum">
@@ -684,6 +705,9 @@
</property>
</widget>
</item>
<item>
<layout class="QVBoxLayout" name="verticalLayout_10"/>
</item>
<item>
<widget class="QSpinBox" name="spinBoxCPU">
<property name="minimum">
108 changes: 38 additions & 70 deletions AMASSS_CLI/AMASSS_CLI.py
Original file line number Diff line number Diff line change
@@ -21,58 +21,28 @@
import sys
import platform

# try:
# import argparse
# except ImportError:
# pip_install('argparse')
# import argparse


# print(sys.argv)

import torch
import dicom2nifti
import itk
import cc3d

from slicer.util import pip_install, pip_uninstall

# from slicer.util import pip_uninstall
# # pip_uninstall('torch torchvision torchaudio')
import SimpleITK as sitk
import vtk
import numpy as np

# pip_uninstall('monai')

# try :
# import logic
# try:
# import torch
# except ImportError:
# if platform.system() == "Windows":
# pip_install('torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 -q')
# else:
# pip_install('torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 -q')
# import torch


try:
import torch
except ImportError:
if platform.system() == "Windows":
pip_install('torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 -q')
else:
pip_install('torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 -q')
import torch

try:
import nibabel
except ImportError:
pip_install('nibabel -q')
import nibabel

try:
import einops
except ImportError:
pip_install('einops -q')
import einops

try:
import dicom2nifti
except ImportError:
pip_install('dicom2nifti -q')
import dicom2nifti

#region try import
pip_uninstall('monai -q')
pip_install('monai==0.7.0 -q')
from monai.networks.nets import UNETR

from monai.data import (
@@ -92,26 +62,8 @@

from monai.inferers import sliding_window_inference

import SimpleITK as sitk
import vtk
import numpy as np
try :
import itk
except ImportError:
pip_install('itk -q')
import itk


try:
import cc3d
except ImportError:
pip_install('connected-components-3d==3.9.1 -q')
import cc3d

#endregion



# pip_install('connected-components-3d==3.9.1 -q') #Could connected-components-3d be replaced with itk.connected_component_image_filter or itk.scalar_connected_component_image_filter

# endregion

#region Global variables
@@ -952,7 +904,8 @@ def main(args):

prediction_segmentation = {}


#Get as much memory as possible by cleaning the cache before the 2nd loop
torch.cuda.empty_cache()

for model_id,model_path in models_to_use.items():

@@ -972,9 +925,23 @@ def main(args):
net.load_state_dict(torch.load(model_path,map_location=DEVICE))
net.eval()


val_outputs = sliding_window_inference(input_img, cropSize, args["nbr_GPU_worker"], net,overlap=args["precision"])

## Should avoid error "CUDA OUT OF MEMORY"
# thanks to sw_device = DEVICE, device=torch.device('cpu') - see the documentation of sliding_window_inference

if args["host_memory"]=="True":
device_memory = torch.device('cpu')
else:
device_memory = DEVICE

try:
val_outputs = sliding_window_inference(input_img, cropSize, args["nbr_GPU_worker"], net,overlap=args["precision"],
sw_device= DEVICE, device=device_memory)
except RuntimeError as e:
if "CUDA out of memory" in str(e):
print("Error: CUDA out of memory. You can try running again by enabling CPU usage.")
else:
raise

pred_data = torch.argmax(val_outputs, dim=1).detach().cpu().type(torch.int16)

segmentations = pred_data.permute(0,3,2,1)
@@ -1006,7 +973,8 @@ def main(args):
sys.stdout.flush()
time.sleep(0.5)


# Clear the cache of GPU memory after the loop
torch.cuda.empty_cache()
#endregion

# print(f"""<filter-progress>{1}</filter-progress>""")
@@ -1147,7 +1115,7 @@ def main(args):
"vtk_smooth": int(sys.argv[10]),
"prediction_ID": sys.argv[11],
"nbr_GPU_worker": int(sys.argv[12]),
"nbr_CPU_worker": int(sys.argv[13]),
"host_memory": sys.argv[13],
"temp_fold" : sys.argv[14],
"isSegmentInput" : sys.argv[15] == "true",
"isDCMInput": sys.argv[16] == "true",
10 changes: 5 additions & 5 deletions AMASSS_CLI/AMASSS_CLI.xml
Original file line number Diff line number Diff line change
@@ -99,12 +99,12 @@
<description>Number of GPU to use</description>
</integer>

<integer>
<name>cpu_usage</name>
<label>cpu_usage</label>
<boolean>
<name>host_memory</name>
<label>host_memory</label>
<index>12</index>
<description>Number of CPU to use</description>
</integer>
<description>Switch to CPU memory if True</description>
</boolean>


<string>
1 change: 1 addition & 0 deletions Testing/Temporary/CTestCostData.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
---
3 changes: 3 additions & 0 deletions Testing/Temporary/LastTest.log
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Start testing: Jan 11 08:38 EST
----------------------------------------------------------
End testing: Jan 11 08:38 EST