Skip to content

Commit

Permalink
wip refactor main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
luiztauffer committed Oct 13, 2023
1 parent b9d9670 commit 44f9fff
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 28 deletions.
2 changes: 1 addition & 1 deletion containers/Dockerfile.ks2_5
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Spike sorters image
FROM spikeinterface/kilosort2_5-compiled-base:0.2.0 as ks25base
FROM spikeinterface/kilosort2_5-compiled-base:0.2.0

# # NVIDIA-ready Image
# FROM nvidia/cuda:11.6.2-base-ubuntu20.04
Expand Down
3 changes: 2 additions & 1 deletion containers/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ Basic infrastructure makes use of the following AWS services:

Build docker image:
```bash
$ DOCKER_BUILDKIT=1 docker build -t <image-name:version> -f <Dockerfile_name> .
$ DOCKER_BUILDKIT=1 docker build -t ghcr.io/catalystneuro/si-sorting-ks25:latest -f Dockerfile.ks2_5 .
$ DOCKER_BUILDKIT=1 docker build -t ghcr.io/catalystneuro/si-sorting-ks3:latest -f Dockerfile.ks3 .
```

Run locally:
Expand Down
68 changes: 44 additions & 24 deletions containers/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,10 +588,10 @@ def main(
run_at = validate_not_none(run_kwargs, "run_at")
run_identifier = run_kwargs.get("run_identifier", datetime.now().strftime("%Y%m%d%H%M%S"))
run_description = run_kwargs.get("run_description", "")
test_with_toy_recording = run_kwargs.get("test_with_toy_recording", "False").lower() in ("true", "1", "t")
test_with_subrecording = run_kwargs.get("test_with_subrecording", "False").lower() in ("true", "1", "t")
test_with_toy_recording = run_kwargs.get("test_with_toy_recording", "False")
test_with_subrecording = run_kwargs.get("test_with_subrecording", "False")
test_subrecording_n_frames = int(run_kwargs.get("test_subrecording_n_frames", 30000))
log_to_file = run_kwargs.get("log_to_file", "False").lower() in ("true", "1", "t")
log_to_file = run_kwargs.get("log_to_file", "False")

# Get source data kwargs from ENV variables
source_data_kwargs = json.loads(os.environ.get("SI_SOURCE_DATA_KWARGS", "{}"))
Expand Down Expand Up @@ -622,27 +622,47 @@ def main(
output_destination = validate_not_none(output_kwargs, "output_destination")
output_path = validate_not_none(output_kwargs, "output_path")

# Run main function
main(
run_at=run_at,
run_identifier=run_identifier,
run_description=run_description,
test_with_toy_recording=test_with_toy_recording,
test_with_subrecording=test_with_subrecording,
test_subrecording_n_frames=test_subrecording_n_frames,
log_to_file=log_to_file,
source_name=source_name,
source_data_type=source_data_type,
source_data_paths=source_data_paths,
recording_kwargs=recording_kwargs,
preprocessing_kwargs=preprocessing_kwargs,
sorter_kwargs=sorter_kwargs,
postprocessing_kwargs=postprocessing_kwargs,
curation_kwargs=curation_kwargs,
visualization_kwargs=visualization_kwargs,
output_destination=output_destination,
output_path=output_path,
)
# # Run main function
# main(
# run_at=run_at,
# run_identifier=run_identifier,
# run_description=run_description,
# test_with_toy_recording=test_with_toy_recording,
# test_with_subrecording=test_with_subrecording,
# test_subrecording_n_frames=test_subrecording_n_frames,
# log_to_file=log_to_file,
# source_name=source_name,
# source_data_type=source_data_type,
# source_data_paths=source_data_paths,
# recording_kwargs=recording_kwargs,
# preprocessing_kwargs=preprocessing_kwargs,
# sorter_kwargs=sorter_kwargs,
# postprocessing_kwargs=postprocessing_kwargs,
# curation_kwargs=curation_kwargs,
# visualization_kwargs=visualization_kwargs,
# output_destination=output_destination,
# output_path=output_path,
# )

print("\nRun at: ", run_at)
print("\nRun identifier: ", run_identifier)
print("\nRun description: ", run_description)
print("\nTest with toy recording: ", test_with_toy_recording)
print("\nTest with subrecording: ", test_with_subrecording)
print("\nTest subrecording n frames: ", test_subrecording_n_frames)
print("\nLog to file: ", log_to_file)
print("\nSource name: ", source_name)
print("\nSource data type: ", source_data_type)
print("\nSource data paths: ", source_data_paths)
print("\nRecording kwargs: ", recording_kwargs)
print("\nPreprocessing kwargs: ", preprocessing_kwargs)
print("\nSorter kwargs: ", sorter_kwargs)
print("\nPostprocessing kwargs: ", postprocessing_kwargs)
print("\nCuration kwargs: ", curation_kwargs)
print("\nVisualization kwargs: ", visualization_kwargs)
print("\nOutput destination: ", output_destination)
print("\nOutput path: ", output_path)



# Known issues:
Expand Down
File renamed without changes.
File renamed without changes.
13 changes: 11 additions & 2 deletions rest/clients/local_docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@
)


map_sorter_to_image = {
"kilosort2": "ghcr.io/catalystneuro/si-sorting-ks2:latest",
"kilosort25": "ghcr.io/catalystneuro/si-sorting-ks25:latest",
"kilosort3": "ghcr.io/catalystneuro/si-sorting-ks3:latest",
"ironclust": "ghcr.io/catalystneuro/si-sorting-ironclust:latest",
"spykingcircus": "ghcr.io/catalystneuro/si-sorting-spyking-circus:latest",
}


class LocalDockerClient:

def __init__(self, base_url: str = "tcp://docker-proxy:2375"):
Expand Down Expand Up @@ -57,8 +66,8 @@ def run_sorting(

container = self.client.containers.run(
name=f'si-sorting-run-{run_kwargs.run_identifier}',
image='python:slim',
command=['python', '-c', 'import os; print(os.environ.get("SI_RUN_KWARGS"))'],
image=map_sorter_to_image[sorter_kwargs.sorter_name],
command=['python', 'main.py'],
detach=True,
environment=env_vars,
volumes=volumes,
Expand Down

0 comments on commit 44f9fff

Please sign in to comment.