
Commit

Add main scripts
zhangfanmark committed Oct 15, 2020
0 parents commit c7624f0
Showing 9 changed files with 2,335 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -0,0 +1,3 @@
*.pyc
TestData*/
SegModels*/
164 changes: 164 additions & 0 deletions dlt_extract_tract_feat.py
@@ -0,0 +1,164 @@
import src.tract_feat as tract_feat

import whitematteranalysis as wma
import numpy as np

import argparse
import os
import h5py

#-----------------
# Parse arguments
#-----------------
parser = argparse.ArgumentParser(
    description="Compute FiberMap of input vtk file.",
    epilog="Written by Fan Zhang, [email protected]")

parser.add_argument(
    'inputVTK',
    help='Input tractography data as a vtkPolyData file.')
parser.add_argument(
    'outputDir',
    help='The output directory should be a new empty directory. It will be created if needed.')

parser.add_argument(
    '-outPrefix', type=str,
    help='A prefix string for all output files.')

parser.add_argument(
    '-feature', action="store", type=str,
    help='Name of the feature to compute. Supported: RAS, Orientation-3D, RAS-3D, RASF, RASF-3D, RAS-1D, RASCurvTors, CurvTors.')
parser.add_argument(
    # TODO: use an RNN so this is not needed.
    '-numPoints', action="store", type=int, default=15,
    help='Number of points per fiber at which the feature is extracted.')
parser.add_argument(
    # TODO: use an RNN so this is not needed.
    '-numRepeats', action="store", type=int, default=15,
    help='Number of repetitions.')

parser.add_argument(
    '-downsampleStep', action="store", type=int,
    help='Step size used to downsample the input.')
parser.add_argument(
    '-groundTruthLabel', action="store", type=str,
    help='Path to the ground truth label file. Should be provided when downsampling is used.')

args = parser.parse_args()

script_name = '<extract_tract_feat>'

if not os.path.exists(args.inputVTK):
    print(script_name, "Error: Input tractography", args.inputVTK, "does not exist.")
    exit()

if not os.path.exists(args.outputDir):
    print(script_name, "Output directory", args.outputDir, "does not exist, creating it.")
    os.makedirs(args.outputDir)

print(script_name, 'Reading input tractography:', args.inputVTK)
pd_tract = wma.io.read_polydata(args.inputVTK)

print(script_name, 'Computing feature:', args.feature)
if args.feature == 'RAS':

    feat_RAS = tract_feat.feat_RAS(pd_tract, number_of_points=args.numPoints)

    # Reshape from 3D (num of fibers, num of points, num of features) to 4D (num of fibers, num of points, num of features, 1).
    # The 4D array treats the input as having a single channel (depth = 1).
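    # For example (hypothetical sizes): 10,000 fibers x 15 points x 3 RAS coordinates
    # gives (10000, 15, 3) -> (10000, 15, 3, 1), the channels-last layout a Keras
    # Conv2D input expects.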
    feat_shape = np.append(feat_RAS.shape, 1)
    feat = np.reshape(feat_RAS, feat_shape)

elif args.feature == 'Orientation-3D':

    feat_orient = tract_feat.feat_orientation_3D(pd_tract, number_of_points=args.numPoints, repeat_time=args.numRepeats)

    feat = feat_orient

elif args.feature == 'RAS-3D':

    feat_RAS_3D = tract_feat.feat_RAS_3D(pd_tract, number_of_points=args.numPoints, repeat_time=args.numRepeats)

    feat = feat_RAS_3D

elif args.feature == 'RASF':

    feat_RAS_FS = tract_feat.feat_RASF(pd_tract, number_of_points=args.numPoints)

    # Reshape from 3D (num of fibers, num of points, num of features) to 4D (num of fibers, num of points, num of features, 1).
    # The 4D array treats the input as having a single channel (depth = 1).
    feat_shape = np.append(feat_RAS_FS.shape, 1)
    feat = np.reshape(feat_RAS_FS, feat_shape)

elif args.feature == 'RASF-3D':

    feat_RAS_FS = tract_feat.feat_RASF_3D(pd_tract, number_of_points=args.numPoints)

    feat = feat_RAS_FS

elif args.feature == 'RAS-1D':

    feat_RAS_1D = tract_feat.feat_1D(pd_tract, number_of_points=args.numPoints)

    feat_shape = np.append(feat_RAS_1D.shape, 1)
    feat_shape = np.append(feat_shape, 1)

    feat = np.reshape(feat_RAS_1D, feat_shape)

elif args.feature == 'RASCurvTors':

    feat_curv_tors = tract_feat.feat_RAS_curv_tors(pd_tract, number_of_points=args.numPoints)

    feat_shape = np.append(feat_curv_tors.shape, 1)

    feat = np.reshape(feat_curv_tors, feat_shape)

elif args.feature == 'CurvTors':

    feat_curv_tors = tract_feat.feat_curv_tors(pd_tract, number_of_points=args.numPoints)

    feat_shape = np.append(feat_curv_tors.shape, 1)

    feat = np.reshape(feat_curv_tors, feat_shape)

else:
    print(script_name, 'Error: Unknown feature:', args.feature)
    exit()
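
# Note (assumption about src/tract_feat): the '-3D' variants above presumably return
# arrays that already include a channel axis, so they are used as-is, while the other
# features get a trailing single-channel axis appended for the CNN input.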

print(script_name, 'Feature matrix shape:', feat.shape)

if args.groundTruthLabel is not None:
    with h5py.File(args.groundTruthLabel, "r") as f:
        label_array = f['label_array'].value.astype(int)
        label_values = f['label_values'].value
        label_names = f['label_names'].value
        # print(script_name, 'Input label_names:', label_names)
else:
    label_array = None
    label_values = None
    label_names = None

## Downsampling
if args.downsampleStep is not None:
    print(script_name, 'Downsampling the feature matrix with step size:', args.downsampleStep)
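    # Assumption about tract_feat.downsample (project code): step-based downsampling
    # typically keeps every Nth fiber (roughly feat[::step], label_array[::step]),
    # so features and labels stay aligned.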

    feat, label_array = tract_feat.downsample(args.downsampleStep, feat, label_array)

    print(script_name, 'Feature matrix shape (downsampled):', feat.shape)
    print(script_name, 'Label array shape (downsampled):', label_array.shape if label_array is not None else label_array)

## Save feat
with h5py.File(os.path.join(args.outputDir, args.outPrefix + '_featMatrix.h5'), "w") as f:
    f.create_dataset('feat', data=feat)

print(script_name, 'Feature matrix shape:', feat.shape)

## Save label
if args.groundTruthLabel is not None:
    with h5py.File(os.path.join(args.outputDir, args.outPrefix + '_label.h5'), "w") as f:
        f.create_dataset('label_array', data=label_array)
        f.create_dataset('label_values', data=label_values)
        f.create_dataset('label_names', data=label_names)

    print(script_name, 'Ground truth shape:', label_array.shape)
    print(script_name, 'Ground truth label names:', label_names)

print(script_name, 'Done! Find results in:', args.outputDir)
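
# Example invocation (hypothetical paths and prefix):
#   python dlt_extract_tract_feat.py input_tracts.vtk OutputDir \
#       -outPrefix sub01 -feature RAS -numPoints 15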

225 changes: 225 additions & 0 deletions dlt_test.py
@@ -0,0 +1,225 @@
import src.tract_feat as tract_feat
import src.nn_model as nn_model

import whitematteranalysis as wma
import numpy as np

import argparse
import os
import h5py

import keras
from keras.models import load_model

from sklearn.metrics import classification_report
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.utils import class_weight
from sklearn.metrics import confusion_matrix

CPU = True
if CPU:
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

    import tensorflow as tf

    num_cores = 4

    # TF1-style session setup: cap intra-/inter-op parallelism and run on CPU only
    # (no GPU is visible because CUDA_VISIBLE_DEVICES is emptied above).
    config = tf.ConfigProto(intra_op_parallelism_threads=num_cores,
                            inter_op_parallelism_threads=num_cores,
                            allow_soft_placement=True,
                            device_count={'CPU': num_cores})

    session = tf.Session(config=config)

    tf.keras.backend.set_session(session)
    keras.backend.set_session(session)
else:
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "3"

#-----------------
# Parse arguments
#-----------------
parser = argparse.ArgumentParser(
    description="Testing using a CNN model.",
    epilog="Written by Fan Zhang, [email protected]")

parser.add_argument(
    'inputModel',
    help='Pretrained network model as an h5 file.')
parser.add_argument(
    'inputFeat',
    help='Input tract feature data as an h5 file.')
parser.add_argument(
    'outputDir',
    help='The output directory should be a new empty directory. It will be created if needed.')

parser.add_argument(
    '-modelLabelName', type=str,
    help='Label names used by the model, stored as an h5 file.')

parser.add_argument(
    '-inputLabel', type=str,
    help='Input ground truth label as an h5 file.')

parser.add_argument(
    '-outPrefix', type=str,
    help='A prefix string for all output files.')

parser.add_argument(
    '-tractVTKfile', type=str,
    help='Tractography data as a vtkPolyData file. If given, the predicted tracts will be written out.')

args = parser.parse_args()

script_name = '<test>'

if not os.path.exists(args.inputModel):
    print(script_name, "Error: Input network model", args.inputModel, "does not exist.")
    exit()

if args.modelLabelName is None or not os.path.exists(args.modelLabelName):
    print(script_name, "Error: Input model label name", args.modelLabelName, "does not exist.")
    exit()

if not os.path.exists(args.inputFeat):
    print(script_name, "Error: Input feature", args.inputFeat, "does not exist.")
    exit()

if args.inputLabel is None:
    print(script_name, "No input label is provided. Will perform prediction only.")
elif not os.path.exists(args.inputLabel):
    print(script_name, "Error: Input label", args.inputLabel, "does not exist.")
    exit()

if not os.path.exists(args.outputDir):
    print(script_name, "Output directory", args.outputDir, "does not exist, creating it.")
    os.makedirs(args.outputDir)


''' Load data '''

# Load model parameters
print(script_name, 'Loading parameters saved when training the model.')
params = np.load(args.inputModel.replace('_model.h5', '_params.npy'), allow_pickle=True).item(0)
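# Note (assumption about project conventions): the params file is expected to sit next
# to the model as <name>_params.npy and to hold training-time settings, including the
# 'bilateral_feature' flag read below.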

# Load label names stored with the model
print(script_name, 'Loading tract names stored with the model.')
with h5py.File(args.modelLabelName, "r") as f:
    y_names_in_model = f['y_names'].value

# Load test data features
with h5py.File(args.inputFeat, "r") as f:
    print(script_name, 'Loading input features.')
    x_test = f['feat'].value

# Generate ground truth labels for evaluation
if args.inputLabel is not None:

    print(script_name, 'Loading input labels.')
    with h5py.File(args.inputLabel, "r") as f:
        y_test = f['label_array'].value.astype(int)
        # y_value = f['label_values'].value
        y_names = f['label_names'].value

    # Keep the originals used to generate the ground truth labels
    y_test_orig = y_test.copy()
    y_names_orig = y_names.copy()

    # Generate the final ground truth labels
    print(script_name, 'Generating FINAL ground truth labels for evaluation.')

    print(script_name, ' # Feat preprocessing - combine subdivisions of some tracts, including CBLM, SupT, and Others.')

    y_test, y_names, _ = tract_feat.combine_tract_subdiviations_and_merge_outliers(y_test, y_names, verbose=False)

    if params['bilateral_feature']:
        y_test, y_names, _ = tract_feat.bilateralize_feature(y_test, y_names, verbose=False)

    y_test_ground_truth_final = tract_feat.update_y_test_based_on_model_y_names(y_test, y_names, y_names_in_model)
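    # Assumption about update_y_test_based_on_model_y_names (project code): it remaps
    # ground-truth label indices to index into y_names_in_model, so the test labels
    # are directly comparable with the model's output classes.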

else:
    y_test_ground_truth_final = None


if params['bilateral_feature']:
    print(script_name, 'Make a bilateral copy for each fiber.')
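    # Assumption about bilateral_X_data (project code): it presumably adds a left-right
    # mirrored copy of each fiber's features so both hemispheres are handled together.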
    x_test, _ = tract_feat.bilateral_X_data(x_test)


# Perform prediction of multiple tracts

print('')
print('===================================')
print('')
print(script_name, 'Start multi-tract prediction.')

print(script_name, 'x_test shape:', x_test.shape)
print(script_name, 'tracts to predict:', y_names_in_model)
# print(script_name, 'tracts in the input data:', y_names)


output_multi_tract_prediction_mask_path = os.path.join(args.outputDir, args.outPrefix + '_multi_tract_specific_prediction_mask.h5')
output_multi_tract_prediction_report_path = os.path.join(args.outputDir, args.outPrefix + '_multi_tract_prediction_report.h5')
if not os.path.exists(output_multi_tract_prediction_mask_path):

    # Load model
    model = load_model(args.inputModel)

    y_prediction, prediction_report, con_matrix = nn_model.predict(model, x_test, y_data=y_test_ground_truth_final, y_name=y_names_in_model, verbose=True)
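    # Assumption about nn_model.predict (project code): it runs the Keras model on
    # x_test and returns per-fiber class indices (roughly np.argmax of the softmax
    # output), plus a classification report and confusion matrix when y_data is given.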

    if args.inputLabel is not None:
        if prediction_report is not None:
            with h5py.File(output_multi_tract_prediction_report_path, "w") as f:
                f.create_dataset('prediction_report', data=prediction_report)
                f.create_dataset('con_matrix', data=con_matrix)

    with h5py.File(output_multi_tract_prediction_mask_path, "w") as f:
        f.create_dataset('y_prediction', data=y_prediction)

    del model

else:
    print(script_name, 'Loading existing prediction result.')
    with h5py.File(output_multi_tract_prediction_mask_path, "r") as f:
        y_prediction = f['y_prediction'].value

if args.tractVTKfile is not None:

    print('')
    print('===================================')
    print('')
    print(script_name, 'Output fiber tracts.')

    tract_prediction_mask = y_prediction

    print(script_name, 'Load vtk:', args.tractVTKfile)
    pd_whole_tract = wma.io.read_polydata(args.tractVTKfile)

    print(script_name, ' # labels in mask:', np.unique(tract_prediction_mask))
    print(script_name, ' # y_names:', y_names_in_model)

    number_of_tracts = np.max(tract_prediction_mask) + 1
    pd_t_list = wma.cluster.mask_all_clusters(pd_whole_tract, tract_prediction_mask, number_of_tracts, preserve_point_data=False, preserve_cell_data=False, verbose=False)
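    # Assumption about wma.cluster.mask_all_clusters (whitematteranalysis): it splits the
    # whole-brain polydata into one polydata per predicted label, returning a list ordered
    # by label index, so pd_t_list[i] holds the fibers predicted as tract i.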

    output_tract_folder = os.path.join(args.outputDir, args.outPrefix + '_prediction_tracts_outlier_removed')
    if not os.path.exists(output_tract_folder):
        os.makedirs(output_tract_folder)

    for t_idx in range(len(pd_t_list)):
        pd_t = pd_t_list[t_idx]

        if y_names_in_model is not None:
            fname_t = os.path.join(output_tract_folder, y_names_in_model[t_idx].decode('UTF-8') + '.vtp')
        else:
            fname_t = os.path.join(output_tract_folder, 'tract_' + str(t_idx) + '.vtp')

        print(script_name, 'output', fname_t)
        wma.io.write_polydata(pd_t, fname_t)

    print(script_name, 'Done! Tracts are in:', output_tract_folder)
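
# Example invocation (hypothetical paths and prefix):
#   python dlt_test.py SegModels/model.h5 OutputDir/sub01_featMatrix.h5 PredDir \
#       -modelLabelName SegModels/label_names.h5 -outPrefix sub01 \
#       -tractVTKfile input_tracts.vtk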
