Skip to content

Commit

Permalink
Augment pre-command function deletes
Browse files Browse the repository at this point in the history
  • Loading branch information
JakeWags committed Dec 19, 2023
1 parent 77f3c54 commit d110a07
Show file tree
Hide file tree
Showing 5 changed files with 310 additions and 70 deletions.
28 changes: 21 additions & 7 deletions shapeworks_cloud/core/deepssm_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from enum import Enum
from random import shuffle

from shapeworks_cloud.core import models


class DeepSSMFileType(Enum):
ID = 1
Expand All @@ -13,7 +15,9 @@ class DeepSSMSplitType(Enum):
TEST = 2


def get_list(project, file_type: DeepSSMFileType, split_type: DeepSSMSplitType):
def get_list(
project, file_type: DeepSSMFileType, split_type: DeepSSMSplitType, testing_split: float
):
"""
Get a list of subjects, ids, filenames, or particle filenames based on the given file type and split type.
Expand All @@ -24,21 +28,24 @@ def get_list(project, file_type: DeepSSMFileType, split_type: DeepSSMSplitType):
Returns:
list: A list of subjects, ids, filenames, or particle filenames based on the given file type and split type.
"""
subjects = project.subjects
dataset = models.Dataset.objects.get(id=project.dataset.id)
subjects = list(project.subjects)

num_subjects = len(list(project.subjects))
# make a list of ids (shuffled order of indices)
ids = shuffle(list(range(len(subjects))))
ids = list(range(num_subjects))
shuffle(ids)
# make a list of strings
output = []
# get start and end indices based on split values and type
start = 0

# TODO: determine how to get training and testing splits from project model
end = len(subjects) * (100.0 - project.training_split) / 100.0
end = round(num_subjects * (100.0 - testing_split) / 100.0)

# if the split type is TEST, use the second half of the list (start = end, end = num_subjects)
if split_type == DeepSSMSplitType.TEST:
start = end
end = len(subjects)
end = num_subjects

# NOTE: SINGLE DOMAIN ASSUMPTION
# currently, DeepSSM only supports a single domain
Expand All @@ -48,7 +55,14 @@ def get_list(project, file_type: DeepSSMFileType, split_type: DeepSSMSplitType):
output.append(ids[i])
# if the file type is IMAGE, add the subject filenames to the list
elif file_type == DeepSSMFileType.IMAGE:
output.append(subjects[ids[i]].image_filename)
images = dataset.get_contents('image')
subject = subjects[ids[i]]

print(subject)

print(list(filter(lambda x: x.name == subject.name, images)))

output.append(filter(lambda x: x.name == subject.name, images))
# if the file type is PARTICLE, add the first particle filename to the list
elif file_type == DeepSSMFileType.PARTICLE:
output.append(subjects[ids[i]].particles[0].filename)
Expand Down
108 changes: 108 additions & 0 deletions shapeworks_cloud/core/migrations/0041_get_contents_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Generated by Django 3.2.23 on 2023-12-18 22:19

from django.db import migrations

# Cached-* models dropped the TimeStampedModel mixin: reset each model's Meta
# options and remove the mixin's auto-managed timestamp columns.
_CACHED_MODELS = (
    'cachedaugmentation',
    'cachedaugmentationpair',
    'cacheddataloaders',
    'cachedexample',
    'cachedmodel',
    'cachedmodelexamples',
    'cachedprediction',
    'cachedtensors',
)


class Migration(migrations.Migration):
    dependencies = [
        ('core', '0040_partial_deepssm_model'),
    ]

    # Same operation sequence as the auto-generated form: first all the
    # AlterModelOptions, then created/modified RemoveField pairs per model.
    operations = [
        migrations.AlterModelOptions(name=model, options={})
        for model in _CACHED_MODELS
    ] + [
        migrations.RemoveField(model_name=model, name=field)
        for model in _CACHED_MODELS
        for field in ('created', 'modified')
    ]
44 changes: 28 additions & 16 deletions shapeworks_cloud/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,30 @@ class Dataset(TimeStampedModel, models.Model):
contributors = models.TextField(blank=True, default='')
publications = models.TextField(blank=True, default='')

def get_contents(self):
def get_contents(self, type='all'):
ret = []

def truncate_filename(filename):
return filename.split('/')[-1]

for shape_group in [
Segmentation.objects.filter(subject__dataset=self),
Mesh.objects.filter(subject__dataset=self),
Image.objects.filter(subject__dataset=self),
Contour.objects.filter(subject__dataset=self),
]:
if type == 'all':
group_list = [
Segmentation.objects.filter(subject__dataset=self),
Mesh.objects.filter(subject__dataset=self),
Image.objects.filter(subject__dataset=self),
Contour.objects.filter(subject__dataset=self),
]
elif type == 'shape':
group_list = [
Segmentation.objects.filter(subject__dataset=self),
Mesh.objects.filter(subject__dataset=self),
]
elif type == 'image':
group_list = [Image.objects.filter(subject__dataset=self)]
elif type == 'contour':
group_list = [Contour.objects.filter(subject__dataset=self)]

for shape_group in group_list:
for shape in shape_group:
ret.append(
{
Expand Down Expand Up @@ -108,24 +120,24 @@ class CachedAnalysis(TimeStampedModel, models.Model):
good_bad_angles = models.JSONField(default=list)


class CachedPrediction(TimeStampedModel, models.Model):
class CachedPrediction(models.Model):
    """Cached DeepSSM prediction result: a particles file stored in S3."""

    particles = S3FileField()


class CachedExample(TimeStampedModel, models.Model):
class CachedExample(models.Model):
    """Cached DeepSSM example: S3-backed particle and scalar files for the
    training and validation splits."""

    train_particles = S3FileField()
    train_scalars = S3FileField()
    validation_particles = S3FileField()
    validation_scalars = S3FileField()


class CachedModelExamples(TimeStampedModel, models.Model):
class CachedModelExamples(models.Model):
    """Groups three CachedExample rows — presumably the best/median/worst
    examples of a trained model (TODO confirm selection criteria)."""

    best = models.ForeignKey(CachedExample, on_delete=models.CASCADE, related_name='best')
    median = models.ForeignKey(CachedExample, on_delete=models.CASCADE, related_name='median')
    worst = models.ForeignKey(CachedExample, on_delete=models.CASCADE, related_name='worst')


class CachedModel(TimeStampedModel, models.Model):
class CachedModel(models.Model):
configuration = (
S3FileField()
) # this is a json file but we aren't reading it, only passing the path to DeepSSMUtils
Expand All @@ -143,25 +155,25 @@ class CachedModel(TimeStampedModel, models.Model):
final_model_ft = S3FileField(null=True)


class CachedTensors(TimeStampedModel, models.Model):
class CachedTensors(models.Model):
    """S3-backed tensor files, one per data split (train/validation/test)."""

    train = S3FileField()
    validation = S3FileField()
    test = S3FileField()


class CachedDataLoaders(TimeStampedModel, models.Model):
class CachedDataLoaders(models.Model):
    """Cached data-loader artifacts (PCA statistics, test names) plus a link
    to the split tensors."""

    mean_pca = S3FileField()
    std_pca = S3FileField()
    test_names = S3FileField()  # this is a .txt file but we only pass the path to DeepSSMUtils
    tensors = models.ForeignKey(CachedTensors, on_delete=models.CASCADE, related_name='tensors')


class CachedAugmentationPair(TimeStampedModel, models.Model):
class CachedAugmentationPair(models.Model):
    """One augmentation sample: an S3-backed data file and its particles file."""

    file = S3FileField()
    particles = S3FileField()


class CachedAugmentation(TimeStampedModel, models.Model):
class CachedAugmentation(models.Model):
pairs = models.ManyToManyField(CachedAugmentationPair)
total_data_csv = S3FileField() # CSV file but we don't read, only pass to DeepSSMUtils
violin_plot = S3FileField() # PNG file used for plot visualization
Expand Down Expand Up @@ -220,7 +232,7 @@ def get_download_paths(self):
'mesh': [(m.anatomy_type, m.file) for m in subject.meshes.all()],
'segmentation': [(s.anatomy_type, s.file) for s in subject.segmentations.all()],
'contour': [(c.anatomy_type, c.file) for c in subject.contours.all()],
'image': [(i.anatomy_type, i.file) for i in subject.images.all()],
'image': [(i.modality, i.file) for i in subject.images.all()],
'constraints': [(c.anatomy_type, c.file) for c in subject.constraints.all()],
'landmarks': [
(lm.anatomy_type, lm.file)
Expand Down
Loading

0 comments on commit d110a07

Please sign in to comment.