diff --git a/hug_model.py b/hug_model.py index 29e329f..7d3ee60 100644 --- a/hug_model.py +++ b/hug_model.py @@ -6,7 +6,6 @@ def symlink_model(data_dir, model_path): # Creating a symbolic link from destination to "model.bin" - data_dir = '.' model_bin = os.path.join(data_dir, "model.bin") if os.path.isfile(model_bin): os.remove(model_bin) # remove the existing link if any @@ -52,7 +51,7 @@ def download_file(url, destination, params): print('.', end='', flush=True) total_downloaded = 0 print("\nDownload complete.") - + symlink_model(params['datadir'], destination) else: print(f"Download failed with status code {response.status_code}") @@ -76,7 +75,7 @@ def get_user_choice(model_list): print("Invalid input. Please enter a number corresponding to a model.") except IndexError: print("Invalid choice. Index out of range.") - + return None def main(): @@ -96,7 +95,7 @@ def main(): help='HuggingFace model repository filename substring match') parser.add_argument('-d', '--datadir', type=str, default='/data', help='Data directory to store HuggingFace models') - + # Parse the arguments args = parser.parse_args()