Skip to content

Commit

Permalink
Download iso chunk folder (#470)
Browse files Browse the repository at this point in the history
* add more info on what chunk is downloading, and make chunk folder foe each file

* fix bug

* comments

* add .cache/sparsezoo/neuralmagic/
  • Loading branch information
horheynm authored Feb 29, 2024
1 parent 3b97dfc commit b463e83
Showing 1 changed file with 29 additions and 15 deletions.
44 changes: 29 additions & 15 deletions src/sparsezoo/utils/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import shutil
import threading
from dataclasses import dataclass, field
from pathlib import Path
from queue import Queue
from typing import Any, Callable, Dict, Optional

Expand Down Expand Up @@ -93,13 +94,26 @@ def __init__(
self.chunk_bytes = chunk_bytes
self.job_queues = Queue()
self._lock = threading.Lock()
self.chunk_folder = self.get_parent_chunk_folder(download_path)

def get_parent_chunk_folder(self, path: str) -> str:
"""Get the name of the file that is used as folder inside chunks"""
path = path.split(os.path.sep)[-1]
path = path.replace(".", "_")
return path
self.chunk_download_path = self.get_chunk_download_path(download_path)

def get_chunk_download_path(self, path: str) -> str:
"""Get the path where chunks will be downloaded"""

# make the folder name from the model name and file to be downloaded
stub = path.split(os.path.sep)[-3]
path = "_".join(path.split(os.path.sep)[-2:])
file_name_as_folder = path.replace(".", "_")

# save the chunks on a different folder than the root model folder
return os.path.join(
str(Path.home()),
".cache",
"sparsezoo",
"neuralmagic",
"chunks",
stub,
file_name_as_folder,
)

def is_range_header_supported(self) -> bool:
"""Check if chunck download is supported"""
Expand Down Expand Up @@ -171,8 +185,10 @@ def queue_chunk_download_jobs(self) -> None:
bytes_range = f"bytes={start_byte}-{end_byte}"

func_kwargs = {
"download_path": self.get_chunk_file_path(
os.path.join(self.chunk_folder, f"{job_id:05d}_{bytes_range}")
"download_path": (
os.path.join(
self.chunk_download_path, f"{job_id:05d}_{bytes_range}"
)
),
"headers": {
"Range": bytes_range,
Expand Down Expand Up @@ -376,12 +392,10 @@ def combine_chunks_and_delete(self, download_path: str, progress_bar: tqdm) -> N
:param progress_bar: tqdm object showing the progress of combining chunks
"""
parent_directory = os.path.dirname(download_path)
chunk_directory = os.path.join(parent_directory, "chunks", self.chunk_folder)
_LOGGER.debug("Combing and deleting ", chunk_directory)
_LOGGER.debug("Combing and deleting ", self.chunk_download_path)

pattern = re.compile(r"\d+_bytes=")
files = os.listdir(chunk_directory)
files = os.listdir(self.chunk_download_path)

chunk_files = [chunk_file for chunk_file in files if pattern.match(chunk_file)]

Expand All @@ -390,13 +404,13 @@ def combine_chunks_and_delete(self, download_path: str, progress_bar: tqdm) -> N
create_parent_dirs(self.download_path)
with open(self.download_path, "wb") as combined_file:
for file_path in sorted_chunk_files:
chunk_path = os.path.join(chunk_directory, file_path)
chunk_path = os.path.join(self.chunk_download_path, file_path)
with open(chunk_path, "rb") as infile:
data = infile.read()
combined_file.write(data)
progress_bar.update(len(data))

shutil.rmtree(chunk_directory)
shutil.rmtree(self.chunk_download_path)

def get_chunk_file_path(self, file_range: str) -> str:
"""
Expand Down

0 comments on commit b463e83

Please sign in to comment.